Пример #1
0
def main():
    """Decode the evaluation sets and plot attention weights per utterance.

    Relies on an ``args`` namespace defined outside this function. The
    model's ``config.yml`` is merged into ``args`` (attributes already set
    on ``args`` are kept), the ASR model is built and restored once while
    processing the first evaluation set, and every set in ``args.eval_sets``
    is then decoded batch by batch. For each utterance an attention plot is
    saved below the directory returned by
    ``mkdir_join(args.plot_dir, 'att_weights')`` and the reference and
    hypothesis are written to the logger.

    Raises:
        NotImplementedError: if ``args.unit`` is not one of ``'word'``,
            ``'wp'``, ``'char'`` or ``'phone'``.
    """

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    # NOTE: vars() hands back the attribute dict of ``args`` itself, not a
    # copy, so values merged into ``args`` below are visible here as well.
    decode_params = vars(args)

    # Merge config with args (values already present on args win)
    for key, value in config.items():
        if not hasattr(args, key):
            setattr(args, key, value)

    # Setting for logging
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding')

    # ``eval_csv`` instead of ``set`` so the builtin is not shadowed
    for i, eval_csv in enumerate(args.eval_sets):
        # Load dataset
        dict_sub_path = os.path.join(args.model, 'dict_sub.txt')
        eval_set = Dataset(csv_path=eval_csv,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=dict_sub_path if os.path.isfile(dict_sub_path) else None,
                           wp_model=os.path.join(args.model, 'wp.model'),
                           unit=args.unit,
                           batch_size=args.batch_size,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=True)

        # The model is created only once, using the vocabulary and input
        # dimension reported by the first evaluation set.
        if i == 0:
            args.vocab = eval_set.vocab
            args.vocab_sub = eval_set.vocab_sub
            args.input_dim = eval_set.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Build a separate namespace holding only the RNNLM config
                args_rnnlm = argparse.Namespace()
                for key, value in config_rnnlm.items():
                    setattr(args_rnnlm, key, value)

                assert args.unit == args_rnnlm.unit
                args_rnnlm.vocab = eval_set.vocab

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.cuda()

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        # NOTE(review): this runs for every evaluation set with the same
        # save_path, so plots saved for an earlier set appear to be removed
        # before the next set is processed -- confirm this is intended.
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params,
                                                    exclude_eos=False)
            # A distinct index name: in Python 2 the comprehension variable
            # leaks into this scope and must not clobber the loop index ``i``.
            ys = [batch['ys'][idx] for idx in perm_idx]

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps = [hyp[::-1] for hyp in best_hyps]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                # A single if/elif chain: previously 'word' was handled by a
                # stand-alone ``if`` and then fell into the ``else`` branch of
                # the following chain, raising NotImplementedError.
                if args.unit == 'word':
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = eval_set.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]

                # Utterance id with '-' mapped to '_' and its last two
                # '_'-separated fields dropped; used as the sub-directory name.
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                # The spectrogram is only passed for speech input
                if args.input_type == 'speech':
                    spectrogram = batch['xs'][b][:, :eval_set.input_dim]
                else:
                    spectrogram = None

                plot_attention_weights(aws[b][:len(token_list)],
                                       label_list=token_list,
                                       spectrogram=spectrogram,
                                       save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'),
                                       figsize=(20, 8))

                ref = ys[b]
                # NOTE(review): best_hyps was already reversed above when
                # bwd_weight > 0.5, so the tokens are reversed a second time
                # for the log line -- confirm this is intended.
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
Пример #2
0
def main():
    """Decode the evaluation sets and plot CTC posteriors per utterance.

    Relies on an ``args`` namespace defined outside this function. The
    model's ``config.yml`` is merged into ``args`` (attributes already set
    on ``args`` are kept), any previous ``plot.log`` in ``args.plot_dir`` is
    deleted, the ASR model is built and restored once while processing the
    first evaluation set, and every set in ``args.eval_sets`` is then
    decoded batch by batch. For each utterance a plot of the CTC posteriors
    is saved below the directory returned by
    ``mkdir_join(args.plot_dir, 'att_weights')`` and the reference and
    hypothesis are written to the logger.

    Raises:
        NotImplementedError: if ``args.unit`` is not one of ``'word'``,
            ``'wp'``, ``'char'`` or ``'phone'``.
    """

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    # NOTE: vars() hands back the attribute dict of ``args`` itself, not a
    # copy, so values merged into ``args`` below are visible here as well.
    decode_params = vars(args)

    # Merge config with args (values already present on args win)
    for key, value in config.items():
        if not hasattr(args, key):
            setattr(args, key, value)

    # Setting for logging (start from a fresh log file)
    log_path = os.path.join(args.plot_dir, 'plot.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    logger = set_logger(log_path, key='decoding')

    # Subsampling factor handed to plot_ctc_probs: the product of every conv
    # pooling factor larger than 1 and of all entries of args.subsample.
    # It depends only on args, so it is computed once rather than per set.
    subsample_factor = 1
    if args.conv_poolings:
        for pooling in args.conv_poolings.split('_'):
            # Only the first number of each entry is used; a leading '(' is
            # stripped (presumably entries look like '(a,b)' -- TODO confirm).
            factor = int(pooling.split(',')[0].replace('(', ''))
            if factor > 1:
                subsample_factor *= factor
    subsample_factor *= np.prod([int(s) for s in args.subsample.split('_')])

    # ``eval_csv`` instead of ``set`` so the builtin is not shadowed
    for i, eval_csv in enumerate(args.eval_sets):
        # Load dataset
        dict_sub_path = os.path.join(args.model, 'dict_sub.txt')
        dataset = Dataset(
            csv_path=eval_csv,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub1=dict_sub_path if os.path.isfile(dict_sub_path) else None,
            wp_model=os.path.join(args.model, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.batch_size,
            is_test=True)

        # The model is created only once, using the vocabulary and input
        # dimension reported by the first evaluation set.
        if i == 0:
            args.vocab = dataset.vocab
            args.vocab_sub1 = dataset.vocab_sub1
            args.input_dim = dataset.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model,
                                                   epoch=args.epoch)

            model.save_path = args.model

            # GPU setting
            model.cuda()

            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        # NOTE(review): this runs for every evaluation set with the same
        # save_path, so plots saved for an earlier set appear to be removed
        # before the next set is processed -- confirm this is intended.
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'],
                                                    decode_params,
                                                    exclude_eos=False)
            # A distinct index name: in Python 2 the comprehension variable
            # leaks into this scope and must not clobber the loop index ``i``.
            ys = [batch['ys'][idx] for idx in perm_idx]

            # Get CTC probs
            ctc_probs, indices_topk, x_lens = model.get_ctc_posteriors(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                if args.unit == 'word':
                    token_list = dataset.idx2word(best_hyps[b],
                                                  return_list=True)
                elif args.unit == 'wp':
                    token_list = dataset.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = dataset.idx2char(best_hyps[b],
                                                  return_list=True)
                elif args.unit == 'phone':
                    token_list = dataset.idx2phone(best_hyps[b],
                                                   return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]

                # Utterance id with '-' mapped to '_' and its last two
                # '_'-separated fields dropped; used as the sub-directory name.
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])

                plot_ctc_probs(
                    ctc_probs[b, :x_lens[b]],
                    indices_topk[b],
                    nframes=x_lens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, speaker,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                ref = ys[b]
                hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break