Example #1
0
def main():
    """Plot attention weights (and optionally cache weights) of an ASR model.

    The model configuration is read from ``conf.yml`` in the directory of the
    first checkpoint in ``args.recog_model``. Every utterance of each
    evaluation set in ``args.recog_sets`` is decoded, and one attention figure
    per utterance is written to
    ``<recog_dir>/att_weights/<speaker>/<utt_id>.png``. When
    ``args.recog_n_caches > 0``, cache-attention figures are also written to
    ``<recog_dir>/cache/<speaker>/<utt_id>.png``. References and hypotheses
    are logged to ``<recog_dir>/plot.log``.

    The model, the ensemble members and the external LMs are built only once,
    while the first evaluation set is being processed.
    """
    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf: training-time settings come from conf.yml, while every
    # key containing 'recog' keeps its command-line value.
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    # vars() returns the namespace's own __dict__, so later assignments to
    # ``args`` (e.g. ``args.recog_unit`` below) are also visible here.
    recog_params = vars(args)

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'),
                        key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset (the sub1 dictionary is optional)
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            nlsyms=args.nlsyms,
            wp_model=os.path.join(dir_name, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.recog_batch_size,
            is_test=True)

        if i == 0:
            # Load the ASR model
            model = Speech2Text(args, dir_name)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']

            # ensemble (different models): each member is configured from
            # the conf.yml stored next to its own checkpoint.
            ensemble_models = [model]
            for recog_model_e in args.recog_model[1:]:
                dir_name_e = os.path.dirname(recog_model_e)
                conf_e = load_config(os.path.join(dir_name_e, 'conf.yml'))
                args_e = copy.deepcopy(args)
                for k, v in conf_e.items():
                    if 'recog' not in k:
                        setattr(args_e, k, v)
                # Construct like the main model above, i.e. with the model
                # directory as the second argument.
                model_e = Speech2Text(args_e, dir_name_e)
                model_e, _ = load_checkpoint(model_e, recog_model_e)
                model_e.cuda()
                ensemble_models += [model_e]

            # Load the LM for shallow fusion
            if not args.lm_fusion:
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = select_lm(args_lm)
                    lm, _ = load_checkpoint(lm, args.recog_lm)
                    # The LM's own config decides which direction it serves.
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 and \
                        (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring):
                    # args.recog_lm_bwd is a checkpoint path (it is passed to
                    # load_checkpoint below), so its conf.yml lives in the
                    # checkpoint's directory, exactly as for the forward LM.
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_bwd),
                                     'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    lm_bwd = select_lm(args_lm_bwd)
                    lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            # Default the recognition unit to the training unit.
            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('LM path: %s' % args.recog_lm)
            logger.info('LM path (bwd): %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('reverse LM rescoring: %s' %
                        args.recog_reverse_lm_rescoring)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache type: %s' % (args.recog_cache_type))
            logger.info('cache word frequency threshold: %s' %
                        (args.recog_cache_word_freq))
            logger.info('cache theta (speech): %.3f' %
                        (args.recog_cache_theta_speech))
            logger.info('cache lambda (speech): %.3f' %
                        (args.recog_cache_lambda_speech))
            logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm))
            logger.info('cache lambda (lm): %.3f' %
                        (args.recog_cache_lambda_lm))

            # GPU setting
            model.cuda()
            # TODO(hirofumi): move this

        save_path = mkdir_join(args.recog_dir, 'att_weights')
        if args.recog_n_caches > 0:
            save_path_cache = mkdir_join(args.recog_dir, 'cache')

        # Clean directory: drop figures left over from an earlier run/set.
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)
            if args.recog_n_caches > 0:
                shutil.rmtree(save_path_cache)
                os.mkdir(save_path_cache)

        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, aws, (cache_attn_hist, cache_id_hist) = model.decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:]
                if len(ensemble_models) > 1 else [],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'])

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b],
                                              return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][:len(tokens)],
                    tokens,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim]
                    if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                if (args.recog_n_caches > 0 and cache_id_hist is not None
                        and cache_attn_hist is not None):
                    n_keys, n_queries = cache_attn_hist[0].shape
                    # No cell is masked out.
                    mask = np.zeros((n_keys, n_queries))

                    plot_cache_weights(
                        cache_attn_hist[0],
                        # FIFO cache: keys are taken from the last entry of
                        # cache_id_hist (a dict-type cache would use the
                        # whole cache_id_hist instead).
                        keys=dataset.idx2token[0](cache_id_hist[-1],
                                                  return_list=True),
                        queries=tokens,
                        save_path=mkdir_join(save_path_cache, spk,
                                             batch['utt_ids'][b] + '.png'),
                        figsize=(40, 16),
                        mask=mask)

                # NOTE(review): when bwd_weight > 0.5, best_hyps_id was
                # already reversed above, so this reverses the tokens a
                # second time -- confirm which order is intended.
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
Example #2
0
def main():
    """Plot the cache-attention weights of a trained LM on evaluation sets.

    The model configuration is read from ``conf.yml`` in the directory of the
    first checkpoint in ``args.recog_model``. Each evaluation set in
    ``args.recog_sets`` is fed to the LM one step at a time. Once
    ``model.cache_attn`` is non-empty, a figure is written to
    ``<recog_dir>/cache/<fig_count>.png`` on every
    (``args.recog_n_caches`` + 1)-th step. The figure shows the attention of
    the last ``recog_n_caches`` queries over the cache keys.

    Raises:
        ValueError: if ``args.recog_n_caches`` is not positive.
    """
    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf: keys containing 'recog' keep their command-line value.
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Nothing can be plotted without a cache. Validate before the (costly)
    # model loading, and raise rather than assert so the check survives -O.
    if args.recog_n_caches <= 0:
        raise ValueError('recog_n_caches must be positive, got %r'
                         % (args.recog_n_caches,))

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    # NOTE(review): the return value is not kept; ``logger`` below is not
    # bound in this function (presumably a module-level logger) -- confirm.
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          backward=args.backward,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM (only once, for the first evaluation set)
            model = build_lm(args, dir_name)
            topk_list = load_checkpoint(model, args.recog_model[0])
            # The epoch number is parsed from the checkpoint file name suffix.
            epoch = int(args.recog_model[0].split('-')[-1])

            # Model averaging for Transformer
            if conf['lm_type'] == 'transformer':
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average,
                                            topk_list=topk_list)

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            model.cuda()

        # NOTE: every evaluation set shares this directory and it is wiped
        # below, so only the figures of the last set are kept.
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        hidden = None
        fig_count = 0
        token_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                # Feed a two-token window (presumably input t and target
                # t + 1), carrying the hidden state across steps; the loss
                # itself is not needed here.
                _, hidden = model(ys[:, t:t + 2],
                                  hidden,
                                  is_eval=True,
                                  n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    if token_count == n_tokens:
                        tokens_keys = dataset.idx2token[0](
                            model.cache_ids[:args.recog_n_caches],
                            return_list=True)
                        tokens_query = dataset.idx2token[0](
                            model.cache_ids[-n_tokens:], return_list=True)

                        # Slide attention matrix: query column q only sees
                        # the first n_valid keys; the remaining cells are
                        # masked.
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros(
                            (n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        for q, aw in enumerate(model.cache_attn[-n_tokens:]):
                            n_valid = n_keys - n_queries + q + 1
                            cache_probs[:n_valid, q] = aw[0, -n_valid:]
                            mask[n_valid:, q] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(
                                               save_path,
                                               str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        token_count = 0
                        fig_count += 1
                    else:
                        token_count += 1

            if is_new_epoch:
                break
Example #3
0
def main():
    """Plot the cache-attention weights of a trained LM on evaluation sets.

    The configuration and model directory come from ``parse_args_eval``.
    Each evaluation set in ``args.recog_sets`` is fed to the LM one step at
    a time. Once ``model.cache_attn`` is non-empty, a figure is written to
    ``<recog_dir>/cache/<fig_count>.png`` on every
    (``args.recog_n_caches`` + 1)-th step. The figure shows the attention of
    the last ``recog_n_caches`` queries over the cache keys.

    Raises:
        ValueError: if ``args.recog_n_caches`` is not positive.
    """
    # Load configuration
    args, _, dir_name = parse_args_eval(sys.argv[1:])

    # Nothing can be plotted without a cache. Validate before the (costly)
    # model loading, and raise rather than assert so the check survives -O.
    if args.recog_n_caches <= 0:
        raise ValueError('recog_n_caches must be positive, got %r'
                         % (args.recog_n_caches,))

    # Setting for logging (start from a fresh log file)
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    # Load the LM
    model = build_lm(args, dir_name)
    load_checkpoint(args.recog_model[0], model)
    # NOTE: model averaging is not helpful for LM

    logger.info('batch size: %d' % args.recog_batch_size)
    logger.info('BPTT: %d' % (args.bptt))
    logger.info('cache size: %d' % (args.recog_n_caches))
    logger.info('cache theta: %.3f' % (args.recog_cache_theta))
    logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))

    model.cache_theta = args.recog_cache_theta
    model.cache_lambda = args.recog_cache_lambda

    # GPU setting
    if args.recog_n_gpus > 0:
        model.cuda()

    for s in args.recog_sets:
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          backward=args.backward,
                          serialize=args.serialize,
                          is_test=True)

        # NOTE: every evaluation set shares this directory and it is wiped
        # below, so only the figures of the last set are kept.
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        hidden = None
        fig_count = 0
        token_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                # Feed a two-token window (presumably input t and target
                # t + 1), carrying the hidden state across steps; the loss
                # itself is not needed here.
                _, hidden = model(ys[:, t:t + 2],
                                  hidden,
                                  is_eval=True,
                                  n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    if token_count == n_tokens:
                        tokens_keys = dataset.idx2token[0](
                            model.cache_ids[:args.recog_n_caches],
                            return_list=True)
                        tokens_query = dataset.idx2token[0](
                            model.cache_ids[-n_tokens:], return_list=True)

                        # Slide attention matrix: query column q only sees
                        # the first n_valid keys; the remaining cells are
                        # masked.
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros(
                            (n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        for q, aw in enumerate(model.cache_attn[-n_tokens:]):
                            n_valid = n_keys - n_queries + q + 1
                            cache_probs[:n_valid, q] = aw[0, -n_valid:]
                            mask[n_valid:, q] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(
                                               save_path,
                                               str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        token_count = 0
                        fig_count += 1
                    else:
                        token_count += 1

            if is_new_epoch:
                break