Esempio n. 1
0
 def plot_ctc(self):
     """Save CTC posterior plots during training.

     The main forward decoder ``self.dec_fwd`` is always plotted; the
     auxiliary decoders ``dec_fwd_sub1`` / ``dec_fwd_sub2`` are plotted only
     when they exist on ``self`` and are not ``None``.
     """
     self.dec_fwd._plot_ctc(mkdir_join(self.save_path, 'ctc'))
     optional_decoders = (('dec_fwd_sub1', 'ctc_sub1'),
                          ('dec_fwd_sub2', 'ctc_sub2'))
     for attr_name, sub_dir in optional_decoders:
         decoder = getattr(self, attr_name, None)
         if decoder is not None:
             decoder._plot_ctc(mkdir_join(self.save_path, sub_dir))
Esempio n. 2
0
 def plot_attention(self):
     """Save attention-weight plots during training.

     Plots the encoder (``self.enc``) and the main forward decoder
     (``self.dec_fwd``) unconditionally, then the auxiliary decoders
     ``dec_fwd_sub1`` / ``dec_fwd_sub2`` when present and not ``None``.
     """
     root = self.save_path
     # encoder
     self.enc._plot_attention(mkdir_join(root, 'enc_att_weights'))
     # decoder(s)
     self.dec_fwd._plot_attention(mkdir_join(root, 'dec_att_weights'))
     optional_decoders = (('dec_fwd_sub1', 'dec_att_weights_sub1'),
                          ('dec_fwd_sub2', 'dec_att_weights_sub2'))
     for attr_name, sub_dir in optional_decoders:
         decoder = getattr(self, attr_name, None)
         if decoder is not None:
             decoder._plot_attention(mkdir_join(root, sub_dir))
Esempio n. 3
0
    def _plot_attention(self, save_path, n_cols=2):
        """Plot attention for each head in all encoder layers.

        For every entry ``k -> aw`` of ``self.aws_dict`` a grid with one
        subplot per head (``aw.shape[1]`` heads) is saved to
        ``<save_path>/enc_att_weights/<k>.png``. The output directory is wiped
        first. Index ``-1`` of the first axis of ``aw`` is plotted, cropped to
        ``self.data_dict['elens'][-1]`` on both remaining axes (presumably
        the last utterance of the mini-batch).

        Args:
            save_path (str): parent directory of the output sub-directory
            n_cols (int): number of subplot columns when there is more than
                one head

        """
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        _save_path = mkdir_join(save_path, 'enc_att_weights')

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        elens = self.data_dict['elens']

        for k, aw in self.aws_dict.items():
            n_heads = aw.shape[1]
            n_cols_tmp = 1 if n_heads == 1 else n_cols
            # Ceil division: with floor division a head count that is not a
            # multiple of n_cols overflowed the grid and raised IndexError.
            n_rows = max(1, -(-n_heads // n_cols_tmp))
            fig, axes = plt.subplots(n_rows, n_cols_tmp,
                                     figsize=(20, 8), squeeze=False)
            for h in range(n_heads):
                ax = axes[h // n_cols_tmp, h % n_cols_tmp]
                ax.imshow(aw[-1, h, :elens[-1], :elens[-1]], aspect="auto")
                ax.grid(False)
                ax.set_xlabel("Input (head%d)" % h)
                ax.set_ylabel("Output (head%d)" % h)
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            # Hide grid cells that have no head assigned.
            for h in range(n_heads, n_rows * n_cols_tmp):
                axes[h // n_cols_tmp, h % n_cols_tmp].axis('off')

            fig.tight_layout()
            # `dpi` is the savefig resolution keyword (`dvi` was a typo that
            # is ignored or rejected depending on the matplotlib version).
            fig.savefig(os.path.join(_save_path, '%s.png' % k), dpi=500)
            # Close this exact figure; no stray figure is created via clf().
            plt.close(fig)
Esempio n. 4
0
    def plot_attention(self, n_cols=4):
        """Plot attention for each head in all layers.

        For every layer index ``lth`` in ``range(self.n_layers)`` that has a
        ``yy_aws_layer<lth>`` attribute, a grid with one subplot per head
        (``self.n_heads`` heads) is saved to
        ``<self.save_path>/att_weights/layer<lth>.png``. The output directory
        is wiped first. Index ``-1`` of the first axis is plotted (presumably
        the last utterance of the mini-batch).

        Args:
            n_cols (int): number of subplot columns

        """
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        save_path = mkdir_join(self.save_path, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for lth in range(self.n_layers):
            if not hasattr(self, 'yy_aws_layer%d' % lth):
                continue

            yy_aws = getattr(self, 'yy_aws_layer%d' % lth)

            # Ceil division with at least one row. The former
            # `n_heads // n_cols` gave 0 rows (ValueError) when
            # n_heads < n_cols. The squeeze-dependent indexing also raised
            # IndexError when n_cols < n_heads < 2 * n_cols.
            # squeeze=False keeps `axes` 2-D for every grid shape.
            n_rows = max(1, -(-self.n_heads // n_cols))
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8),
                                     squeeze=False)
            for h in range(self.n_heads):
                ax = axes[h // n_cols, h % n_cols]
                ax.imshow(yy_aws[-1, h, :, :], aspect="auto")
                ax.grid(False)
                ax.set_xlabel("Input (head%d)" % h)
                ax.set_ylabel("Output (head%d)" % h)
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            # Hide grid cells that have no head assigned.
            for h in range(self.n_heads, n_rows * n_cols):
                axes[h // n_cols, h % n_cols].axis('off')

            fig.tight_layout()
            # `dpi` is the savefig resolution keyword (`dvi` was a typo).
            fig.savefig(os.path.join(save_path, 'layer%d.png' % (lth)),
                        dpi=500)
            plt.close(fig)
Esempio n. 5
0
    def _plot_attention(self, save_path, n_cols=1):
        """Plot attention.

        Wipes ``<save_path>/dec_att_weights`` and, if ``self.aws`` is set,
        saves one subplot per head (``self.score.n_heads`` heads) of
        ``self.aws[-1]`` to ``<save_path>/dec_att_weights/attention.png``.
        Index ``-1`` of the first axis is presumably the last utterance of
        the mini-batch.

        Args:
            save_path (str): parent directory of the output sub-directory
            n_cols (int): number of subplot columns

        """
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        _save_path = mkdir_join(save_path, 'dec_att_weights')

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        if not hasattr(self, 'aws'):
            return

        n_heads = self.score.n_heads
        # Ceil division: floor division dropped the last row when n_heads is
        # not a multiple of n_cols, raising IndexError for the extra heads.
        n_rows = max(1, -(-n_heads // n_cols))
        fig, axes = plt.subplots(n_rows,
                                 n_cols,
                                 figsize=(20, 8),
                                 squeeze=False)
        for h in range(n_heads):
            ax = axes[h // n_cols, h % n_cols]
            ax.imshow(self.aws[-1, h, :, :], aspect="auto")
            ax.grid(False)
            ax.set_xlabel("Input (head%d)" % h)
            ax.set_ylabel("Output (head%d)" % h)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Hide grid cells that have no head assigned.
        for h in range(n_heads, n_rows * n_cols):
            axes[h // n_cols, h % n_cols].axis('off')

        fig.tight_layout()
        # `dpi` is the savefig resolution keyword (`dvi` was a typo).
        fig.savefig(os.path.join(_save_path, 'attention.png'), dpi=500)
        plt.close(fig)
Esempio n. 6
0
    def _plot_attention(self, save_path, n_cols=2):
        """Plot attention for each head in all layers.

        For each attention type ``attn`` in ``('yy', 'xy')`` the directory
        ``<save_path>/dec_<attn>_att_weights`` is wiped. Then, for every layer
        ``lth`` in ``range(self.n_layers)`` that has an
        ``<attn>_aws_layer<lth>`` attribute, ``layer<lth>.png`` is written with
        one subplot per head (``self.n_heads`` heads). Index ``-1`` of the
        first axis is plotted (presumably the last utterance of the
        mini-batch).

        Args:
            save_path (str): parent directory of the output sub-directories
            n_cols (int): number of subplot columns

        """
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        for attn in ['yy', 'xy']:
            _save_path = mkdir_join(save_path, 'dec_%s_att_weights' % attn)

            # Clean directory
            if _save_path is not None and os.path.isdir(_save_path):
                shutil.rmtree(_save_path)
                os.mkdir(_save_path)

            for lth in range(self.n_layers):
                attr_name = '%s_aws_layer%d' % (attn, lth)
                if not hasattr(self, attr_name):
                    continue
                aws = getattr(self, attr_name)

                # Ceil division: floor division dropped the last row when
                # n_heads is not a multiple of n_cols (IndexError).
                n_rows = max(1, -(-self.n_heads // n_cols))
                fig, axes = plt.subplots(n_rows,
                                         n_cols,
                                         figsize=(20, 8),
                                         squeeze=False)
                for h in range(self.n_heads):
                    ax = axes[h // n_cols, h % n_cols]
                    ax.imshow(aws[-1, h, :, :], aspect="auto")
                    ax.grid(False)
                    ax.set_xlabel("Input (head%d)" % h)
                    ax.set_ylabel("Output (head%d)" % h)
                    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
                # Hide grid cells that have no head assigned.
                for h in range(self.n_heads, n_rows * n_cols):
                    axes[h // n_cols, h % n_cols].axis('off')

                fig.tight_layout()
                # `dpi` is the savefig resolution keyword (`dvi` was a typo).
                fig.savefig(os.path.join(_save_path, 'layer%d.png' % (lth)),
                            dpi=500)
                plt.close(fig)
Esempio n. 7
0
    def _plot_attention(self, save_path, n_cols=2):
        """Plot attention for each head in all decoder layers.

        Returns immediately when ``self.att_weight`` is missing or 0.
        Otherwise wipes ``<save_path>/dec_att_weights`` and, for every entry
        ``k -> aw`` of ``self.aws_dict``, saves ``<k>.png`` with one subplot
        per head (``aw.shape[1]`` heads). The last utterance of the
        mini-batch is shown. Rows are cropped to ``ylens[-1]``; columns are
        cropped to ``ylens[-1]`` when ``'yy'`` is in the key and to
        ``elens[-1]`` otherwise.

        Args:
            save_path (str): parent directory of the output sub-directory
            n_cols (int): base number of subplot columns; it is multiplied by
                ``max(1, n_heads // 4)`` for larger head counts

        """
        if getattr(self, 'att_weight', 0) == 0:
            return
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        _save_path = mkdir_join(save_path, 'dec_att_weights')

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        elens = self.data_dict['elens']
        ylens = self.data_dict['ylens']

        for k, aw in self.aws_dict.items():
            n_heads = aw.shape[1]
            # Widen both the grid and the figure for every 4 heads.
            width_scale = max(1, n_heads // 4)
            n_cols_tmp = 1 if n_heads == 1 else n_cols * width_scale
            # Ceil division: floor division dropped the last row when n_heads
            # is not a multiple of n_cols_tmp (IndexError for extra heads).
            n_rows = max(1, -(-n_heads // n_cols_tmp))
            fig, axes = plt.subplots(n_rows,
                                     n_cols_tmp,
                                     figsize=(20 * width_scale, 8),
                                     squeeze=False)
            # NOTE: show the last utterance in a mini-batch
            n_keys = ylens[-1] if 'yy' in k else elens[-1]
            for h in range(n_heads):
                ax = axes[h // n_cols_tmp, h % n_cols_tmp]
                ax.imshow(aw[-1, h, :ylens[-1], :n_keys], aspect="auto")
                ax.grid(False)
                ax.set_xlabel("Input (head%d)" % h)
                ax.set_ylabel("Output (head%d)" % h)
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            # Hide grid cells that have no head assigned.
            for h in range(n_heads, n_rows * n_cols_tmp):
                axes[h // n_cols_tmp, h % n_cols_tmp].axis('off')

            fig.tight_layout()
            # `dpi` is the savefig resolution keyword (`dvi` was a typo).
            fig.savefig(os.path.join(_save_path, '%s.png' % k), dpi=500)
            plt.close(fig)
Esempio n. 8
0
    def _plot_ctc(self, save_path, topk=10):
        """Plot CTC posteriors.

        Returns immediately when ``self.ctc_weight`` is 0. Otherwise wipes
        ``<save_path>/ctc`` and saves ``prob.png``. For the last utterance of
        the mini-batch (cropped to its encoder length), the plot contains the
        posterior curve over time of every token that is among the ``topk``
        most probable tokens in at least one frame. Index 0 (blank) is drawn
        as a dotted grey line.

        Args:
            save_path (str): parent directory of the output sub-directory
            topk (int): number of most probable tokens kept per frame

        """
        if self.ctc_weight == 0:
            return
        from matplotlib import pyplot as plt

        _save_path = mkdir_join(save_path, 'ctc')

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        elen = self.ctc.data_dict['elens'][-1]
        probs = self.ctc.prob_dict['probs'][-1, :elen]  # `[T, vocab]`
        # NOTE: show the last utterance in a mini-batch

        # Per-frame indices of the `topk` most probable tokens (descending).
        # The full argsort was used before, so `topk` was ignored and every
        # vocabulary entry got its own curve.
        topk_ids = np.argsort(-probs, axis=1)[:, :max(topk, 0)]

        n_frames = probs.shape[0]
        times_probs = np.arange(n_frames)

        fig, ax = plt.subplots()
        # NOTE: index 0 is reserved for blank
        for idx in np.unique(topk_ids).tolist():
            if idx == 0:
                ax.plot(times_probs,
                        probs[:, 0],
                        ':',
                        label='<blank>',
                        color='grey')
            else:
                ax.plot(times_probs, probs[:, idx])
        ax.set_xlabel(u'Time [frame]', fontsize=12)
        ax.set_ylabel('Posteriors', fontsize=12)
        ax.set_xticks(list(range(0, int(n_frames) + 1, 10)))
        ax.set_yticks(list(range(0, 2, 1)))

        fig.tight_layout()
        # `dpi` is the savefig resolution keyword (`dvi` was a typo).
        fig.savefig(os.path.join(_save_path, '%s.png' % 'prob'), dpi=500)
        plt.close(fig)
Esempio n. 9
0
def main():
    """Plot the cache attention weights of a trained language model.

    Reads ``conf.yml`` from the directory of the first checkpoint in
    ``args.recog_model``. Every conf key that does not contain ``'recog'``
    is copied onto ``args``. The LM is built once (on the first evaluation
    set) and moved to the GPU with ``model.cuda()``. The model is then run
    over consecutive two-token windows of every set in ``args.recog_sets``.
    Cache-attention figures are saved as ``<recog_dir>/cache/<fig_count>.png``
    through ``plot_cache_weights``.

    NOTE(review): the ``cache`` directory is wiped and ``fig_count`` is reset
    for every evaluation set, so with several sets only the figures of the
    last one survive. Confirm this is intended.
    """

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    # (every training-time setting; keys containing 'recog' keep the CLI value)
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          backward=args.backward,
                          serialize=args.serialize,
                          is_test=True)

        # The LM is built only for the first set and reused for the others.
        if i == 0:
            # Load the LM
            model = build_lm(args, dir_name)
            topk_list = load_checkpoint(model, args.recog_model[0])
            # Epoch is parsed from the integer suffix after the last '-' in
            # the checkpoint path.
            epoch = int(args.recog_model[0].split('-')[-1])

            # Model averaging for Transformer
            if conf['lm_type'] == 'transformer':
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average,
                                            topk_list=topk_list)

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            # (unconditional: this script has no CPU fallback)
            model.cuda()

        assert args.recog_n_caches > 0
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        # `hidden` is carried over across batches of the same set.
        hidden = None
        fig_count = 0
        toknen_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                # Feed tokens t and t+1; `loss` is not used by this script.
                loss, hidden = model(ys[:, t:t + 2],
                                     hidden,
                                     is_eval=True,
                                     n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    # A figure is emitted when the counter reaches n_tokens,
                    # i.e. once every n_tokens + 1 steps.
                    # NOTE(review): assuming cache_attn grows by one entry per
                    # step, the step right after each figure never appears in
                    # any figure. Confirm this is intended.
                    if toknen_count == n_tokens:
                        tokens_keys = dataset.idx2token[0](
                            model.cache_ids[:args.recog_n_caches],
                            return_list=True)
                        tokens_query = dataset.idx2token[0](
                            model.cache_ids[-n_tokens:], return_list=True)

                        # Slide attention matrix
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros(
                            (n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        # Column i takes the last (n_keys - n_queries + i + 1)
                        # weights of aw[0] in its leading rows; the rows
                        # below them are flagged in `mask`.
                        # NOTE: this `i` shadows the outer set index. That is
                        # harmless only because the outer `i` is re-bound by
                        # enumerate before its next use.
                        for i, aw in enumerate(model.cache_attn[-n_tokens:]):
                            cache_probs[:(n_keys - n_queries + i + 1),
                                        i] = aw[0,
                                                -(n_keys - n_queries + i + 1):]
                            mask[(n_keys - n_queries + i + 1):, i] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(
                                               save_path,
                                               str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        toknen_count = 0
                        fig_count += 1
                    else:
                        toknen_count += 1

            if is_new_epoch:
                break
Esempio n. 10
0
def main(args):
    """Train an ASR (``Speech2Text``) model and return its save directory.

    Args:
        args: training configuration. Besides attribute access it is used
            through ``args.get(...)`` and ``args.items()``.
            NOTE(review): presumably a dict-like namespace rather than a plain
            ``argparse.Namespace``; confirm.
    Returns:
        str: ``args.save_path``, the directory receiving ``train.log``,
            ``conf.yml`` and the checkpoints saved by the scheduler.

    """

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Copies taken before any conf overwrite; they are later filled from the
    # init / teacher model conf files.
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k not in ['resume', 'local_rank']:
                setattr(args, k, v)

    args = compute_subsampling_factor(args)
    # Resume epoch is the integer suffix after the last '-' of the checkpoint.
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0

    # Load dataset
    train_set = build_dataloader(args=args,
                                 tsv_path=args.train_set,
                                 tsv_path_sub1=args.train_set_sub1,
                                 tsv_path_sub2=args.train_set_sub2,
                                 batch_size=args.batch_size,
                                 batch_size_type=args.batch_size_type,
                                 max_n_frames=args.max_n_frames,
                                 resume_epoch=resume_epoch,
                                 sort_by=args.sort_by,
                                 short2long=args.sort_short2long,
                                 sort_stop_epoch=args.sort_stop_epoch,
                                 num_workers=args.workers,
                                 pin_memory=args.pin_memory,
                                 distributed=args.distributed,
                                 word_alignment_dir=args.train_word_alignment,
                                 ctc_alignment_dir=args.train_ctc_alignment)
    # Transducer decoders are validated one utterance at a time.
    dev_set = build_dataloader(
        args=args,
        tsv_path=args.dev_set,
        tsv_path_sub1=args.dev_set_sub1,
        tsv_path_sub2=args.dev_set_sub2,
        batch_size=1 if 'transducer' in args.dec_type else args.batch_size,
        batch_size_type='seq'
        if 'transducer' in args.dec_type else args.batch_size_type,
        max_n_frames=1600,
        word_alignment_dir=args.dev_word_alignment,
        ctc_alignment_dir=args.dev_ctc_alignment)
    eval_sets = [
        build_dataloader(args=args, tsv_path=s, batch_size=1, is_test=True)
        for s in args.eval_sets
    ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            assert args.asr_init
            args.save_path = mkdir_join(os.path.dirname(args.asr_init),
                                        dir_name)
        else:
            args.save_path = mkdir_join(
                args.model_save_dir,
                '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
                dir_name)
        # NOTE(review): presumably lets rank 0 create the directory first.
        if args.local_rank > 0:
            time.sleep(1)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), args.stdout,
               args.local_rank)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(
            os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, args.save_path, train_set.idx2token[0])

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path,
                                                  'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if args.get('dict' + sub):
                shutil.copy(
                    args.get('dict' + sub),
                    os.path.join(args.save_path, 'dict' + sub + '.txt'))
            if args.get('unit' + sub) == 'wp':
                shutil.copy(
                    args.get('wp_model' + sub),
                    os.path.join(args.save_path, 'wp' + sub + '.model'))

        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load ASR model (full model)
            conf_init = load_config(
                os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters
            # (only same-name, same-shape tensors; encoder-only if requested)
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer
    # (resuming past convert_to_sgd_epoch restarts directly with SGD)
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    # NOTE(review): `in args.dec_type_sub1` raises TypeError if it is None.
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type or 'former' in args.dec_type_sub1
    scheduler = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        lower_better=args.metric not in ['accuracy', 'bleu'],
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        peak_lr=0.05 / (args.get('transformer_enc_d_model', 0)**0.5)
        if 'conformer' in args.enc_type else 1e6,
        model_size=args.get('transformer_enc_d_model',
                            args.get('transformer_dec_d_model', 0)),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # Load teacher ASR model
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        # NOTE(review): the student model is already built at this point;
        # whether `args.lsm_prob = 0` takes effect depends on where lsm_prob
        # is read later. Confirm.
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=((not is_transformer) and (not args.cudnn_benchmark))
            or args.cudnn_deterministic,
            benchmark=(not is_transformer) and args.cudnn_benchmark)

        # Mixed precision training setting
        # (native torch.cuda.amp on torch >= 1.6.0, NVIDIA apex otherwise)
        if args.use_apex:
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                from neural_sp.models.seq2seq.decoders.ctc import CTC
                amp.register_float_function(CTC, "loss_fn")
                # NOTE: see https://github.com/espnet/espnet/pull/1779
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)

        # Visible GPUs are split evenly among the local processes; this
        # process takes the contiguous slice for its local_rank.
        n = torch.cuda.device_count() // args.local_world_size
        device_ids = list(range(args.local_rank * n,
                                (args.local_rank + 1) * n))

        torch.cuda.set_device(device_ids[0])
        model.cuda(device_ids[0])
        scheduler.cuda(device_ids[0])
        if args.distributed:
            model = DDP(model, device_ids=device_ids)
        else:
            model = CustomDataParallel(model,
                                       device_ids=list(range(args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model, args.local_rank)
    args.wandb_id = reporter.wandb_id
    if args.resume:
        n_steps = scheduler.n_steps * max(
            1, args.accum_grad_n_steps // args.local_world_size)
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if args.local_rank == 0:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf,
                        os.path.join(args.save_path, 'conf_lm.yml'))
        # NOTE: save after reporter for wandb ID

    # Define tasks
    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if args.total_weight - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if args.get('train_set_' + sub) is not None:
                if args.get(sub + '_weight', 0) - args.get(
                        'ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub] + tasks
                if args.get('ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    # Enable features whose start epoch has already been reached (with the
    # default of 0 they are enabled from the very beginning).
    if args.get('ss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_scheduled_sampling()
    if args.get('mocha_quantity_loss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_quantity_loss()

    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        train_one_epoch(model, train_set, dev_set, eval_sets, scheduler,
                        reporter, logger, args, amp, scaler, tasks, teacher,
                        teacher_lm)

        # Save checkpoint and validate model per epoch
        # (before eval_start_epoch: checkpoint only, no validation)
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            if args.local_rank == 0:
                scheduler.save_checkpoint(model,
                                          args.save_path,
                                          amp=amp,
                                          remove_old=(not is_transformer)
                                          and args.remove_old_checkpoints)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = validate([model.module], dev_set, args,
                                  reporter.n_epochs + 1, logger)
            scheduler.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot
            reporter.add_scalar('dev/' + args.metric, metric_dev)

            if scheduler.is_topk or is_transformer:
                # Save model
                if args.local_rank == 0:
                    scheduler.save_checkpoint(model,
                                              args.save_path,
                                              amp=amp,
                                              remove_old=(not is_transformer)
                                              and args.remove_old_checkpoints)

                # test
                if scheduler.is_topk:
                    for eval_set in eval_sets:
                        validate([model.module], eval_set, args,
                                 reporter.n_epochs, logger)

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model,
                                         args.lr,
                                         args.weight_decay,
                                         decay_type='always',
                                         decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break
        # Enable features scheduled to start at the next epoch.
        if args.get('ss_start_epoch', 0) == (ep + 1):
            model.module.trigger_scheduled_sampling()
        if args.get('mocha_quantity_loss_start_epoch', 0) == (ep + 1):
            model.module.trigger_quantity_loss()

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
Esempio n. 11
0
def eval_word(models,
              dataset,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): models to evaluate; ``models[0]`` decodes and
            ``models[1:]`` (if any) are passed as ``ensemble_models``
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): epoch number, only used to name the output directory
        recog_dir (str): directory to save ref.trn/hyp.trn; when None it is
            derived from ``recog_params`` under ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate (left at 0 when no word was scored)
        cer (float): Character error rate, computed only over utterances whose
            ``<unk>`` tokens were resolved by the character-level decoder
        n_oov_total (int): total number of ``<unk>`` tokens in the hypotheses

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Write explicitly as UTF-8: transcripts may be non-ASCII (e.g. CSJ) and
    # the platform default encoding is not guaranteed to handle them.
    with open(hyp_trn_save_path, 'w', encoding='utf-8') as f_hyp, \
            open(ref_trn_save_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            speakers = batch['sessions' if dataset.corpus ==
                             'swbd' else 'speakers']
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataset.idx2token[0],
                    exclude_eos=True)
                # No attention weights are kept on this path
                aws = None
            else:
                best_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                n_oov_total += hyp.count('<unk>')

                # Resolving UNK
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    # OOV resolution needs the word-level attention weights,
                    # which only the non-streaming decode path returns. Check
                    # before spending time on the char-level decode.
                    assert aws is not None, \
                        'OOV resolution requires non-streaming decoding'

                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataset.idx2token[1]
                        if progressbar else None,
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3

                    hyp = resolve_unk(
                        hyp,
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataset.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s' % hyp)
                    hyp = hyp.replace('*', '')

                    # Compute CER (only for utterances resolved above)
                    ref_char = ref
                    hyp_char = hyp
                    if dataset.corpus == 'csj':
                        ref_char = ref.replace(' ', '')
                        hyp_char = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char),
                        hyp=list(hyp_char),
                        normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words,
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Guard against an empty evaluation set
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                 (n_sub_c, n_ins_c, n_del_c))
    logger.debug('OOV (total): %d' % (n_oov_total))

    return wer, cer, n_oov_total
Esempio n. 12
0
def main():
    """Dump CTC forced alignments for every utterance of each recog set.

    For each utterance a text file
    ``<recog_dir>/ctc_forced_alignments/<speaker>/<utt_id>.txt`` is written,
    holding one ``<token> <trigger point>`` line per reference token plus a
    final ``<eos>`` line.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    log_path = os.path.join(args.recog_dir, 'align.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    set_logger(log_path, stdout=args.recog_stdout)

    # Every recog set writes into the same directory, so clean it only once.
    # Cleaning it inside the per-set loop would delete the alignments of all
    # previously processed sets.
    save_path = mkdir_join(args.recog_dir, 'ctc_forced_alignments')
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for set_idx, s in enumerate(args.recog_sets):
        # Align all utterances
        args.min_n_frames = 0
        args.max_n_frames = 1e5

        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=args.recog_batch_size)

        if set_idx == 0:
            # Load ASR model (shared by all recog sets)
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        pbar = tqdm(total=len(dataloader))
        try:
            while True:
                batch, is_new_epoch = dataloader.next()
                trigger_points = model.ctc_forced_align(
                    batch['xs'], batch['ys'])  # `[B, L]`

                for b in range(len(batch['xs'])):
                    save_path_spk = mkdir_join(save_path,
                                               batch['speakers'][b])
                    save_path_utt = mkdir_join(save_path_spk,
                                               batch['utt_ids'][b] + '.txt')

                    tokens = dataloader.idx2token[0](batch['ys'][b],
                                                     return_list=True)
                    with codecs.open(save_path_utt, 'w',
                                     encoding="utf-8") as f:
                        for tok_idx, tok in enumerate(tokens):
                            f.write('%s %d\n' %
                                    (tok, trigger_points[b, tok_idx]))
                        # The trigger point after the last token is <eos>
                        f.write('%s %d\n' %
                                ('<eos>', trigger_points[b, len(tokens)]))

                pbar.update(len(batch['xs']))

                if is_new_epoch:
                    break
        finally:
            pbar.close()
Esempio n. 13
0
def eval_word(models,
              dataloader,
              recog_params,
              epoch,
              recog_dir=None,
              streaming=False,
              progressbar=False,
              edit_distance=True,
              fine_grained=False,
              oracle=False,
              teacher_force=False):
    """Evaluate a word-level model by WER.

    Args:
        models (List): models to evaluate; ``models[0]`` decodes and
            ``models[1:]`` (if any) are passed as ``ensemble_models``
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER over the n-best list
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (not read by this function body)
    Returns:
        wer (float): Word error rate (left at 0 when no word was scored)
        cer (float): Character error rate, computed only over utterances whose
            ``<unk>`` tokens were resolved by the character-level decoder
        n_oov_total (int): total number of OOV

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(recog_params.get('recog_beam_width'))
        recog_dir += '_lp' + str(recog_params.get('recog_length_penalty'))
        recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty'))
        recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
            str(recog_params.get('recog_max_len_ratio'))
        recog_dir += '_lm' + str(recog_params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # input-length bin -> list of per-utterance errors
    n_oov_total = 0

    wer_oracle = 0
    n_oracle_hit = 0  # number of utterances whose 1-best is the oracle
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params.get('recog_batch_size'))

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        for batch in dataloader:
            speakers = batch['sessions' if dataloader.corpus ==
                             'swbd' else 'speakers']
            if streaming or recog_params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True,
                    speaker=speakers[0])[0]
                # No attention weights are kept on this path
                aws = None
            else:
                nbest_hyps_id, aws = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]
                n_oov_total += nbest_hyps[0].count('<unk>')

                # Resolving UNK
                if recog_params.get(
                        'recog_resolving_unk') and '<unk>' in nbest_hyps[0]:
                    # OOV resolution needs the word-level attention weights,
                    # which only the non-streaming decode path returns.
                    assert aws is not None, \
                        'OOV resolution requires non-streaming decoding'

                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        idx2token=dataloader.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2

                    nbest_hyps[0] = resolve_unk(
                        nbest_hyps[0],
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataloader.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.debug('Hyp (after OOV resolution): %s' %
                                 nbest_hyps[0])
                    nbest_hyps[0] = nbest_hyps[0].replace('*', '')

                    # Compute CER (only for utterances resolved above)
                    ref_char = ref
                    hyp_char = nbest_hyps[0]
                    if dataloader.corpus == 'csj':
                        ref_char = ref_char.replace(' ', '')
                        hyp_char = hyp_char.replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char), hyp=list(hyp_char))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                # n_utt is only advanced per batch, so add the in-batch index
                logger.debug('utt-id (%d/%d): %s' %
                             (n_utt + b + 1, len(dataloader), utt_id))
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words, hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                    # Compute oracle WER. The 1-best is always a candidate, so
                    # a single-hypothesis n-best list still contributes.
                    # NOTE: OOV resolution is not considered
                    if oracle:
                        wers_b = [err_b] + [
                            compute_wer(ref=ref_words,
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = int(np.argmin(np.array(wers_b)))
                        if oracle_idx == 0:
                            # One utterance, one hit (not one per batch item)
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        wer_dist.setdefault(xlen_bin, []).append(err_b / 100)

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # Guard against an empty evaluation set
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if recog_params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))
            logger.info('OOV (total): %d' % (n_oov_total))

        if oracle and n_word > 0 and n_utt > 0:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' %
                        (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

    return wer, cer, n_oov_total
Esempio n. 14
0
def eval_wordpiece(models, dataset, recog_params, epoch,
                   recog_dir=None, streaming=False, progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate; ``models[0]`` decodes and
            ``models[1:]`` (if any) are passed as ``ensemble_models``
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): epoch number, only used to name the output directory
        recog_dir (str): directory to save ref.trn/hyp.trn; when None it is
            derived from ``recog_params`` under ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate (left at 0 when no word was scored)
        cer (float): Character error rate (left at 0 when no char was scored)

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Write explicitly as UTF-8: transcripts may be non-ASCII (e.g. CSJ) and
    # the platform default encoding is not guaranteed to handle them.
    with open(hyp_trn_save_path, 'w', encoding='utf-8') as f_hyp, \
            open(ref_trn_save_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Drop a leading "<...>" tag. startswith() is safe on an empty
                # reference, and maxsplit=1 keeps any later '>' in the text.
                if ref.startswith('<'):
                    ref = ref.split('>', 1)[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref_words,
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                    # Compute CER (spaces are not scored for CSJ)
                    if dataset.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Guard against an empty evaluation set
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
Esempio n. 15
0
def eval_phone(models,
               dataloader,
               recog_params,
               epoch,
               recog_dir=None,
               streaming=False,
               progressbar=False,
               fine_grained=False,
               oracle=False,
               teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate; ``models[0]`` decodes and
            ``models[1:]`` (if any) are passed as ``ensemble_models``
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): epoch number, only used to name the output directory
        recog_dir (str): directory to save ref.trn/hyp.trn; when None it is
            derived from ``recog_params`` under ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        oracle (bool): calculate oracle PER over the n-best list
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (not read by this function body)
    Returns:
        per (float): Phone error rate (left at 0 when no phone was scored)

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    per_dist = {}  # input-length bin -> list of per-utterance errors

    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id)
                    for hyp_id in nbest_hyps_id[b]
                ]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute PER
                    ref_phones = ref.split(' ')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_phones, hyp=nbest_hyps[0].split(' '))
                    per += err_b
                    n_sub += sub_b
                    n_ins += ins_b
                    n_del += del_b
                    n_phone += len(ref_phones)

                    # Compute oracle PER. The 1-best is always a candidate, so
                    # a single-hypothesis n-best list still contributes.
                    if oracle:
                        pers_b = [err_b] + [
                            compute_wer(ref=ref_phones,
                                        hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = int(np.argmin(np.array(pers_b)))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        per_oracle += pers_b[oracle_idx]

                    # Bin by input length (200-frame buckets); assumes the
                    # batch provides 'xlens' input lengths — TODO confirm.
                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        per_dist.setdefault(xlen_bin, []).append(err_b / 100)

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Guard against an empty evaluation set
        if n_phone > 0:
            per /= n_phone
            n_sub /= n_phone
            n_ins /= n_phone
            n_del /= n_phone

        if recog_params['recog_beam_width'] > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub, n_ins, n_del))

        if oracle and n_phone > 0 and n_utt > 0:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' %
                        (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info('  PER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(pers) / len(pers), len_bin))

    return per
Esempio n. 16
0
def main():
    """Train a language model from the command line.

    Parses the training arguments, builds the train/dev/eval datasets, the
    model, the optimizer and its learning-rate scheduler, and then runs the
    training loop. The dev loss is logged every ``args.print_step`` steps and
    a checkpoint is saved (and dev/eval perplexity computed, once
    ``args.eval_start_epoch`` is reached) at the end of every epoch.
    When ``args.resume`` is set, the saved ``conf.yml`` next to the checkpoint
    overrides the command-line arguments and training restarts from it.

    Returns:
        save_path (str): directory where the config, logs, and checkpoints are saved

    """
    args = parse_args_train(sys.argv[1:])

    # Load a conf file; when resuming, the saved settings override the
    # command-line arguments (except `resume` itself)
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    # Scale the mini-batch size by the number of GPUs
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                dict_path=args.dict,
                nlsyms=args.nlsyms,
                unit=args.unit,
                wp_model=args.wp_model,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, save_path)

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        # The checkpoint path ends with '-<epoch>'. Keep the parsed epoch in
        # `resume_epoch` so the SGD-conversion check below can use it.
        resume_epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
            args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    optimizer = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=getattr(args, 'transformer_d_model', 0),
        factor=args.lr_factor,
        noam=is_transformer,
        save_checkpoints_topk=1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch - 1 and convert_to_sgd_epoch:
        # the in-loop conversion only runs right after an evaluated epoch,
        # so convert here instead.
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=not (is_transformer or args.cudnn_benchmark),
            benchmark=args.cudnn_benchmark)
        model.cuda()

        # Mix precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_steps = 0
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_steps += 1

        loss, hidden, observation = model(ys_train, hidden)
        reporter.add(observation)
        if use_apex:
            with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # Update parameters once every `accum_grad_n_steps` mini-batches
        # (gradient accumulation)
        if args.accum_grad_n_steps == 1 or accum_n_steps >= args.accum_grad_n_steps:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_steps = 0
        loss_train = loss.item()
        del loss
        # Carry the recurrent state over to the next BPTT segment
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))
        n_steps += 1
        # NOTE: n_steps is different from the step counter in Noam Optimizer

        if n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next(bptt=args.bptt)[0]
            loss, _, observation = model(ys_dev, None, is_eval=True)
            reporter.add(observation, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                 loss_train, loss_dev, optimizer.lr, ys_train.shape[0],
                 duration_step / 60))
            start_time_step = time.time()

        # Save figures of loss and accuracy
        if n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                optimizer.save_checkpoint(model,
                                          save_path,
                                          remove_old=not is_transformer,
                                          amp=amp)
            else:
                start_time_eval = time.time()
                # dev
                model.module.reset_length(args.bptt)
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                model.module.reset_length(args.bptt)
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot
                logger.info('PPL (%s, ep:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))

                if optimizer.is_topk or is_transformer:
                    # Save the model
                    optimizer.save_checkpoint(model,
                                              save_path,
                                              remove_old=not is_transformer,
                                              amp=amp)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        model.module.reset_length(args.bptt)
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        model.module.reset_length(args.bptt)
                        logger.info(
                            'PPL (%s, ep:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., ep:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model,
                                             args.lr,
                                             args.weight_decay,
                                             decay_type='always',
                                             decay_rate=0.5)

            # Close the finished epoch's progress bar before starting a new one
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs >= args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Esempio n. 17
0
def eval_wordpiece(models, dataset, recog_params, epoch,
                   recog_dir=None, streaming=False, progressbar=False,
                   fine_grained=False):
    """Evaluate the wordpiece-level model by WER.

    Decodes every utterance in ``dataset`` with ``models[0]`` and writes the
    references and 1-best hypotheses to ``ref.trn`` / ``hyp.trn``.

    Args:
        models (list): models to evaluate; ``models[0]`` decodes and the
            remaining ones are passed to it as ``ensemble_models``
            (non-streaming decoding only)
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): epoch number, used only to name the output directory
        recog_dir (str): directory to save ref.trn/hyp.trn; when None, a name
            derived from ``recog_params`` is created under ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
    Returns:
        wer (float): Word error rate (0 when ``streaming`` is True, since
            scores are not computed in that mode)
        cer (float): Character error rate (0 when ``streaming`` is True)

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # WER distribution bucketed by input length (bins of width 200 in `xlens` units)
    wer_dist = {}

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Strip a leading '<...>' tag; startswith() also copes with an
                # empty reference, where ref[0] would raise IndexError
                if ref.startswith('<'):
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        wer_dist.setdefault(xlen_bin, []).append(wer_b / 100)

                    # Compute CER
                    if dataset.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()
                    n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Normalize by the reference length; the guards avoid
        # ZeroDivisionError on an empty set or all-empty references
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        if n_utt > 0:
            n_streamable /= n_utt
            quantity_rate /= n_utt

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' % (dataset.set, sum(wers) / len(wers), len_bin))

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    logger.info('Streamablility (%s): %.2f %%' % (dataset.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataset.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataset.set, last_success_frame_ratio))

    return wer, cer
Esempio n. 18
0
def eval_wordpiece_bleu(models,
                        dataloader,
                        params,
                        epoch=-1,
                        rank=0,
                        save_dir=None,
                        streaming=False,
                        progressbar=False,
                        edit_distance=True,
                        fine_grained=False,
                        oracle=False,
                        teacher_force=False):
    """Evaluate a wordpiece-level model by corpus-level BLEU.

    Decodes every utterance with ``models[0]`` and writes references and
    1-best hypotheses to ``ref.trn`` / ``hyp.trn`` (rank 0 only).

    Args:
        models (List): models to evaluate; ``models[0]`` decodes and the
            remaining ones are passed to it as ``ensemble_models``
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch (used only to name the output directory)
        rank (int): rank of current process group; only rank 0 writes trn files
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate BLEU scores (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained corpus-level BLEU distributions based on input lengths
        oracle (bool): calculate oracle corpus-level BLEU
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        c_bleu (float): corpus-level 4-gram BLEU

    """
    if save_dir is None:
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'ref.trn',
                                  rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'hyp.trn',
                                  rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    # References/hypotheses bucketed by input length, for fine-grained BLEU
    list_of_references_dist = {}
    hypotheses_dist = {}

    hypotheses_oracle = []
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    list_of_references = []
    hypotheses = []

    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    # try/finally so the trn files and progress bar are closed even if decoding raises
    try:
        for batch in dataloader:
            if streaming or params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'],
                    params,
                    dataloader.idx2token[0],
                    exclude_eos=True,
                    speaker=batch['speakers'][0])[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'],
                    params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [],
                    teacher_force=teacher_force)[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Strip a leading '<...>' tag; startswith() also copes with an
                # empty reference, where ref[0] would raise IndexError
                if ref.startswith('<'):
                    ref = ref.split('>')[1]
                nbest_hyps = [
                    dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]
                ]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                if rank == 0:
                    f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                    f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                                ')\n')
                # n_utt counts finished batches, so add the in-batch offset
                logger.debug('utt-id (%d/%d): %s' %
                             (n_utt + b + 1, len(dataloader), utt_id))
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    list_of_references += [[ref.split(' ')]]
                    hypotheses += [nbest_hyps[0].split(' ')]

                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        list_of_references_dist.setdefault(xlen_bin, []).append([ref.split(' ')])
                        hypotheses_dist.setdefault(xlen_bin, []).append(hypotheses[-1])

                    # Compute oracle corpus-level BLEU (hypothesis selected by
                    # sentence-level BLEU). Always append exactly one oracle
                    # hypothesis per utterance so that `hypotheses_oracle`
                    # stays aligned with `list_of_references`.
                    if oracle:
                        if len(nbest_hyps) > 1:
                            # sentence_bleu takes a *list* of reference token lists
                            s_bleus_b = [
                                sentence_bleu([ref.split(' ')], hyp_n.split(' '))
                                for hyp_n in nbest_hyps
                            ]
                            oracle_idx = int(np.argmax(np.array(s_bleus_b)))
                        else:
                            # A single candidate is trivially the oracle
                            oracle_idx = 0
                        if oracle_idx == 0:
                            n_oracle_hit += 1  # per utterance, not per batch
                        hypotheses_oracle += [nbest_hyps[oracle_idx].split(' ')]

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))
    finally:
        if rank == 0:
            f_hyp.close()
            f_ref.close()
        if progressbar:
            pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    c_bleu = corpus_bleu(list_of_references, hypotheses) * 100

    if edit_distance and not streaming:
        if oracle:
            c_bleu_oracle = corpus_bleu(list_of_references,
                                        hypotheses_oracle) * 100
            oracle_hit_rate = n_oracle_hit * 100 / n_utt if n_utt > 0 else 0.
            logger.info('Oracle corpus-level BLEU (%s): %.2f %%' %
                        (dataloader.set, c_bleu_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, hypotheses_bin in sorted(hypotheses_dist.items(),
                                                  key=lambda x: x[0]):
                c_bleu_bin = corpus_bleu(list_of_references_dist[len_bin],
                                         hypotheses_bin) * 100
                logger.info('  corpus-level BLEU (%s): %.2f %% (%d)' %
                            (dataloader.set, c_bleu_bin, len_bin))

    logger.info('Corpus-level BLEU (%s): %.2f %%' % (dataloader.set, c_bleu))

    return c_bleu
Esempio n. 19
0
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = make_model_name(args)
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training')

    # Model setting
    if 'gated_conv' in args.lm_type:
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    model.save_path = save_path

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd'
            if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        model, checkpoint = load_checkpoint(model, args.resume, resume=True)
        lr_controller = checkpoint['lr_controller']
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        ppl_dev_best = checkpoint['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the nlsyms, dictionar, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path,
                                                    'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Trancate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                           args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if 'gated_conv' not in args.lm_type:
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   np.exp(loss_train), np.exp(loss_dev), lr_controller.lr,
                   ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save fugures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                save_checkpoint(model.module,
                                model.module.save_path,
                                lr_controller,
                                epoch,
                                step - 1,
                                ppl_dev_best,
                                remove_old_checkpoints=True)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    save_checkpoint(model.module,
                                    model.module.save_path,
                                    lr_controller,
                                    epoch,
                                    step - 1,
                                    ppl_dev_best,
                                    remove_old_checkpoints=True)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' %
                                    (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' %
                                    (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
Esempio n. 20
0
def main():
    """Decode each recognition set and plot the frame-level CTC posteriors.

    Settings are read from ``conf.yml`` located next to the first checkpoint
    in ``args.recog_model``; every config key that does not contain
    ``recog`` overwrites the corresponding command-line argument.  For each
    utterance a figure is saved to
    ``<recog_dir>/ctc_probs/<speaker>/<utt_id>.png`` and the
    reference/hypothesis pair is logged to ``<recog_dir>/plot.log``.
    """
    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf (keys containing 'recog' keep their command-line values)
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    # NOTE: vars() returns args.__dict__ itself, so later assignments to
    # ``args`` attributes are also visible through ``recog_params``.
    recog_params = vars(args)

    # Setting for logging
    log_path = os.path.join(args.recog_dir, 'plot.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    logger = set_logger(log_path, key='decoding', stdout=args.recog_stdout)

    # Total time-axis subsampling: product of the first factor of every conv
    # pooling spec (e.g. "(2,1)") and of the explicit ``subsample`` factors.
    # It depends only on the config, so it is computed once for all sets.
    subsample = [int(factor) for factor in args.subsample.split('_')]
    subsample_factor = 1
    if args.conv_poolings:
        for pooling in args.conv_poolings.split('_'):
            p = int(pooling.split(',')[0].replace('(', ''))
            if p > 1:
                subsample_factor *= p
    subsample_factor *= np.prod(subsample)

    save_path = mkdir_join(args.recog_dir, 'ctc_probs')

    # Clean directory once. Every recognition set writes into this same
    # directory, so cleaning it inside the per-set loop would delete the
    # figures of the sets processed before.
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for i, tsv_path in enumerate(args.recog_sets):
        # Load dataset
        dict_path_sub1 = os.path.join(dir_name, 'dict_sub1.txt')
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=tsv_path,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          dict_path_sub1=dict_path_sub1 if os.path.isfile(
                              dict_path_sub1) else False,
                          nlsyms=args.nlsyms,
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          unit_sub1=args.unit_sub1,
                          batch_size=args.recog_batch_size,
                          is_test=True)

        if i == 0:
            # Load the ASR model once; it is reused for every recognition set
            model = Speech2Text(args, dir_name)
            model = load_checkpoint(model, args.recog_model[0])[0]
            # The epoch is parsed from the checkpoint path suffix after the last '-'
            epoch = int(args.recog_model[0].split('-')[-1])

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            model.cuda()

        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, _, _ = model.decode(batch['xs'], recog_params,
                                              exclude_eos=False)

            # Get CTC probs
            ctc_probs, indices_topk, xlens = model.get_ctc_probs(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    indices_topk[b],
                    n_frames=xlens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
Esempio n. 21
0
def eval_char(models,
              dataloader,
              params,
              epoch=-1,
              rank=0,
              save_dir=None,
              streaming=False,
              progressbar=False,
              task_idx=0,
              edit_distance=True,
              fine_grained=False,
              oracle=False,
              teacher_force=False):
    """Evaluate a character-level model by WER & CER.

    References and best hypotheses are written to ``ref.trn`` / ``hyp.trn``
    by the rank-0 process only; per-utterance results are logged at DEBUG
    level and the aggregated metrics at INFO level.

    Args:
        models (List): models to evaluate; in non-streaming decoding
            ``models[1:]`` are passed as an ensemble to ``models[0]``
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch (only used to name the default save_dir)
        rank (int): rank of current process group
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        task_idx (int): index of target task in interest
            0: main task
            1: sub task
            2: sub sub task
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained CER distributions based on input lengths
        oracle (bool): calculate oracle CER over the n-best hypotheses
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate (0 when the unit has no word boundaries
            or edit distances are not computed)
        cer (float): Character error rate (0 when edit distances are not computed)

    """
    if save_dir is None:
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'ref.trn',
                                  rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path,
                                  save_dir,
                                  'hyp.trn',
                                  rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    cer_dist = {}  # calculate CER distribution based on input lengths

    cer_oracle = 0
    n_oracle_hit = 0  # number of utterances whose 1-best is also the oracle

    # Per-utterance counters for the streaming statistics
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    # Batch key of the target labels: 'ys' for the main task and
    # 'ys_sub<N>' for the N-th sub task.
    task = 'ys' if task_idx == 0 else 'ys_sub' + str(task_idx)

    # Word-level errors are only meaningful for character units that keep
    # word boundaries ('nowb' = no word boundary).
    compute_word_errors = (
        ('char' in dataloader.unit and 'nowb' not in dataloader.unit)
        or (task_idx > 0 and dataloader.unit_sub1 == 'char'))

    for batch in dataloader:
        speakers = batch['sessions' if dataloader.corpus ==
                         'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(batch['xs'],
                                                       params,
                                                       dataloader.idx2token[0],
                                                       exclude_eos=True,
                                                       speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'],
                params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch[task],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            nbest_hyps_tmp = [
                dataloader.idx2token[0](hyp_id) for hyp_id in nbest_hyps_id[b]
            ]
            # Truncate the first and last spaces for the char_space unit
            nbest_hyps = []
            for hyp in nbest_hyps_tmp:
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]
                nbest_hyps.append(hyp)

            # Write to trn
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id +
                            ')\n')
            logger.debug('utt-id (%d/%d): %s' %
                         (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                if compute_word_errors:
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))
                    # NOTE: sentence error rate for Chinese

                # Compute CER
                if dataloader.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                err_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(
                                                             nbest_hyps[0]))
                cer += err_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                # Compute oracle CER
                if oracle and len(nbest_hyps) > 1:
                    cers_b = [err_b] + [
                        compute_wer(ref=list(ref), hyp=list(hyp_n))[0]
                        for hyp_n in nbest_hyps[1:]
                    ]
                    oracle_idx = np.argmin(np.array(cers_b))
                    if oracle_idx == 0:
                        # Count this utterance once (this loop is per utterance)
                        n_oracle_hit += 1
                    cer_oracle += cers_b[oracle_idx]

                if fine_grained:
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in cer_dist.keys():
                        cer_dist[xlen_bin] += [err_b / 100]
                    else:
                        cer_dist[xlen_bin] = [err_b / 100]

                # Streaming statistics are accumulated once per utterance so
                # that they are consistent with the n_utt denominator.
                if models[0].streamable():
                    n_streamable += 1
                else:
                    last_success_frame_ratio += models[
                        0].last_success_frame_ratio()
                quantity_rate += models[0].quantity_rate()

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # Denominators are clamped to 1 so that an empty evaluation set
        # yields 0 instead of raising ZeroDivisionError.
        if compute_word_errors:
            denom_w = max(n_word, 1)
            wer /= denom_w
            n_sub_w /= denom_w
            n_ins_w /= denom_w
            n_del_w /= denom_w
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        denom_c = max(n_char, 1)
        cer /= denom_c
        n_sub_c /= denom_c
        n_ins_c /= denom_c
        n_del_c /= denom_c

        denom_u = max(n_utt, 1)
        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= denom_u
        quantity_rate /= denom_u

        if params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            cer_oracle /= denom_c
            oracle_hit_rate = n_oracle_hit * 100 / denom_u
            logger.info('Oracle CER (%s): %.2f %%' %
                        (dataloader.set, cer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, cers in sorted(cer_dist.items(), key=lambda x: x[0]):
                logger.info('  CER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(cers) / len(cers), len_bin))

        logger.info('Streamability (%s): %.2f %%' %
                    (dataloader.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' %
                    (dataloader.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' %
                    (dataloader.set, last_success_frame_ratio))

    return wer, cer
Esempio n. 22
0
def eval_phone(models,
               dataset,
               recog_params,
               epoch,
               recog_dir=None,
               progressbar=False):
    """Evaluate a phone-level model by PER.

    References and best hypotheses are written (UTF-8) to ``ref.trn`` /
    ``hyp.trn`` in the decoding directory, and every utterance is logged at
    INFO level.

    Args:
        models (list): models to evaluate; ``models[1:]`` are passed as an
            ensemble to ``models[0]``
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): current epoch (only used to name the default recog_dir)
        recog_dir (str): directory to save hypotheses; when None, a directory
            named after the decoding settings is created under
            ``models[0].save_path``
        progressbar (bool): visualize the progressbar
    Returns:
        per (float): Phone error rate (0 for an empty dataset)

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Explicit UTF-8 (as the other evaluators do) instead of the
    # platform-dependent default encoding.
    with open(hyp_trn_save_path, 'w', encoding='utf-8') as f_hyp, \
            open(ref_trn_save_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataset.getitem(
                recog_params['recog_batch_size'])
            best_hyps_id, _, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute PER (phones are space-separated tokens)
                per_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                per += per_b
                n_sub += sub_b
                n_ins += ins_b
                n_del += del_b
                n_phone += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Clamp the denominator so that an empty dataset yields 0 instead of
    # raising ZeroDivisionError.
    denom = max(n_phone, 1)
    per /= denom
    n_sub /= denom
    n_ins /= denom
    n_del /= denom

    logger.info('PER (%s): %.2f %%' % (dataset.set, per))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del))

    return per
Esempio n. 23
0
def eval_wordpiece_bleu(models,
                        dataset,
                        recog_params,
                        epoch,
                        recog_dir=None,
                        streaming=False,
                        progressbar=False,
                        fine_grained=False):
    """Evaluate the wordpiece-level model by BLEU.

    References and best hypotheses are written one per line to
    ``ref.trn`` / ``hyp.trn`` in the decoding directory.

    Args:
        models (list): models to evaluate; in non-streaming decoding
            ``models[1:]`` are passed as an ensemble to ``models[0]``
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): current epoch (only used to name the default recog_dir)
        recog_dir (str): directory to save hypotheses; when None, a directory
            named after the decoding settings is created under
            ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation
            (BLEU is not accumulated in this mode)
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained BLEU distributions based on input lengths
    Returns:
        bleu (float): 4-gram corpus-level BLEU in percent (0 when no
            hypothesis was scored)

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    s_bleu = 0  # sum of sentence-level BLEU (%)
    n_sentence = 0
    # Sentence-level BLEU (%) grouped by input-length bins
    s_bleu_dist = {}

    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataset))

    list_of_references = []
    hypotheses = []

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            # .get: treat a missing 'recog_chunk_sync' option as disabled
            if streaming or recog_params.get('recog_chunk_sync'):
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'],
                    recog_params,
                    dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'],
                    recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus ==
                                   'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Keep only the text after a leading '<...>' prefix
                # (startswith also copes with an empty reference).
                if ref.startswith('<'):
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + '\n')
                f_hyp.write(hyp + '\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    list_of_references += [[ref.split(' ')]]
                    hypotheses += [hyp.split(' ')]
                    n_sentence += 1

                    # Compute sentence-level BLEU
                    if fine_grained:
                        s_bleu_b = sentence_bleu([ref.split(' ')],
                                                 hyp.split(' '))
                        s_bleu += s_bleu_b * 100

                        # Store in percent, like s_bleu and the log message
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in s_bleu_dist.keys():
                            s_bleu_dist[xlen_bin] += [s_bleu_b * 100]
                        else:
                            s_bleu_dist[xlen_bin] = [s_bleu_b * 100]

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # No hypotheses are collected in streaming mode (or for an empty
    # dataset); avoid calling corpus_bleu on empty lists.
    if hypotheses:
        c_bleu = corpus_bleu(list_of_references, hypotheses) * 100
    else:
        c_bleu = 0.
    if not streaming and fine_grained:
        s_bleu /= max(n_sentence, 1)
        logger.info('Sentence-level BLEU (%s): %.2f %%' % (dataset.set, s_bleu))
        for len_bin, s_bleus in sorted(s_bleu_dist.items(),
                                       key=lambda x: x[0]):
            logger.info('  sentence-level BLEU (%s): %.2f %% (%d)' %
                        (dataset.set, sum(s_bleus) / len(s_bleus), len_bin))

    logger.info('Corpus-level BLEU (%s): %.2f %%' % (dataset.set, c_bleu))

    return c_bleu
Esempio n. 24
0
def main():
    """Train an ASR model (``Speech2Text``) configured from the command line.

    Parses ``sys.argv``; when resuming, restores the saved ``conf.yml`` of that
    run. Builds the train/dev/eval datasets, the model, the optimizer (wrapped
    by ``LRScheduler``) and optional teacher ASR model / teacher LM. Then runs
    the training loop:
    - periodic dev-loss reporting;
    - per-epoch evaluation and checkpointing;
    - early stopping;
    - optional conversion to SGD fine-tuning.

    Returns:
        save_path (str): directory holding the configuration, logs and
            checkpoints of this run

    """
    args = parse_args_train(sys.argv[1:])
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    # NOTE: vars() returns the live __dict__ of args, so later assignments to
    # args (e.g. args.vocab below) are also visible through recog_params.
    recog_params = vars(args)

    args = compute_susampling_factor(args)

    # Load dataset
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        shuffle_bucket=args.shuffle_bucket,
                        sort_by='input',
                        short2long=args.sort_short2long,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=args.subsample_factor,
                        subsample_factor_sub1=args.subsample_factor_sub1,
                        subsample_factor_sub2=args.subsample_factor_sub2,
                        discourse_aware=args.discourse_aware)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=batch_size,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=args.subsample_factor,
                      subsample_factor_sub1=args.subsample_factor_sub1,
                      subsample_factor_sub2=args.subsample_factor_sub2)
    eval_sets = [Dataset(corpus=args.corpus,
                         tsv_path=s,
                         dict_path=args.dict,
                         nlsyms=args.nlsyms,
                         unit=args.unit,
                         wp_model=args.wp_model,
                         batch_size=1,
                         is_test=True) for s in args.eval_sets]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            assert args.asr_init
            save_path = mkdir_join(os.path.dirname(args.asr_init), dir_name)
        else:
            save_path = mkdir_join(args.model_save_dir, '_'.join(
                os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, save_path, train_set.idx2token[0])

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub), os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub), os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load the ASR model (full model)
            conf_init = load_config(os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters whose name and shape match
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        # The checkpoint path is expected to end with "-<epoch>"
        resume_epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
                                  args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type
    optimizer = LRScheduler(optimizer, args.lr,
                            decay_type=args.lr_decay_type,
                            decay_start_epoch=args.lr_decay_start_epoch,
                            decay_rate=args.lr_decay_rate,
                            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                            lower_better=args.metric not in ['accuracy', 'bleu'],
                            warmup_start_lr=args.warmup_start_lr,
                            warmup_n_steps=args.warmup_n_steps,
                            model_size=getattr(args, 'transformer_d_model', 0),
                            factor=args.lr_factor,
                            noam=args.optimizer == 'noam',
                            save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # Load the teacher ASR model
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation: no scheduled sampling in the
        # teacher and no label smoothing in the student
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(deterministic=not (is_transformer or args.cudnn_benchmark),
                            benchmark=not is_transformer and args.cudnn_benchmark)
        model.cuda()

        # Mix precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
                                                        opt_level=args.train_dtype)
            from neural_sp.models.seq2seq.decoders.ctc import CTC
            amp.register_float_function(CTC, "loss_fn")
            # NOTE: see https://github.com/espnet/espnet/pull/1779
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model, device_ids=list(range(0, args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the node (host) name, not the user name
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']
    # Guard against an empty task list when averaging losses over tasks
    n_tasks = max(len(tasks), 1)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    accum_n_steps = 0
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    epoch_detail_prev = 0
    for ep in range(resume_epoch, args.n_epochs):
        pbar_epoch = tqdm(total=len(train_set))
        session_prev = None

        for batch_train, is_new_epoch in train_set:
            # Compute loss in the training set
            if args.discourse_aware and batch_train['sessions'][0] != session_prev:
                model.module.reset_session()
            session_prev = batch_train['sessions'][0]
            accum_n_steps += 1

            if accum_n_steps == 1:
                loss_train = 0  # moving average over gradient accumulation
            # Decide once per mini-batch whether parameters are updated. The
            # accumulation counter is reset only after every task has run, so it
            # can never reach 0 in the middle of `tasks`. Previously that
            # happened and caused a ZeroDivisionError with several tasks.
            # NOTE: parameters are forcibly updated at the end of every epoch
            do_update = accum_n_steps >= args.accum_grad_n_steps or is_new_epoch
            loss_batch = 0

            # Change mini-batch depending on task
            for task in tasks:
                loss, observation = model(batch_train, task,
                                          teacher=teacher, teacher_lm=teacher_lm)
                reporter.add(observation)
                if use_apex:
                    with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss_batch += loss.item()
                if do_update:
                    if args.clip_grad_norm > 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(
                            model.module.parameters(), args.clip_grad_norm)
                        reporter.add_tensorboard_scalar('total_norm', total_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                del loss  # release the graph as early as possible
            loss_batch /= n_tasks
            loss_train = (loss_train * (accum_n_steps - 1) + loss_batch) / accum_n_steps
            if do_update:
                accum_n_steps = 0

            pbar_epoch.update(len(batch_train['utt_ids']))
            reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
            # NOTE: loss/acc/ppl are already added in the model
            reporter.step()
            n_steps += 1
            # NOTE: n_steps is different from the step counter in Noam Optimizer

            if n_steps % args.print_step == 0:
                # Compute loss in the dev set
                batch_dev = iter(dev_set).next(batch_size=1 if 'transducer' in args.dec_type else None)[0]
                # Change mini-batch depending on task
                loss_dev = 0
                for task in tasks:
                    loss, observation = model(batch_dev, task, is_eval=True)
                    reporter.add(observation, is_eval=True)
                    loss_dev += loss.item()
                    del loss
                loss_dev /= n_tasks
                reporter.step(is_eval=True)

                duration_step = time.time() - start_time_step
                if args.input_type == 'speech':
                    xlen = max(len(x) for x in batch_train['xs'])
                    ylen = max(len(y) for y in batch_train['ys'])
                elif args.input_type == 'text':
                    xlen = max(len(x) for x in batch_train['ys'])
                    ylen = max(len(y) for y in batch_train['ys_sub1'])
                logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)" %
                            (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                             loss_train, loss_dev,
                             optimizer.lr, len(batch_train['utt_ids']),
                             xlen, ylen, duration_step / 60))
                start_time_step = time.time()

            # Save figures of loss and accuracy
            if n_steps % (args.print_step * 10) == 0:
                reporter.snapshot()
                model.module.plot_attention()
                model.module.plot_ctc()

            # Evaluate model every 0.1 epoch during MBR training
            if args.mbr_training:
                if int(train_set.epoch_detail * 10) != int(epoch_detail_prev * 10):
                    # dev
                    evaluate([model.module], dev_set, recog_params, args,
                             int(train_set.epoch_detail * 10) / 10, logger)
                    # Save the model
                    optimizer.save_checkpoint(
                        model, save_path, remove_old=False, amp=amp,
                        epoch_detail=train_set.epoch_detail)
                epoch_detail_prev = train_set.epoch_detail

            if is_new_epoch:
                break

        # One progress bar per epoch: close it before any early exit below
        pbar_epoch.close()

        # Save checkpoint and evaluate model per epoch
        duration_epoch = time.time() - start_time_epoch
        logger.info('========== EPOCH:%d (%.2f min) ==========' %
                    (optimizer.n_epochs + 1, duration_epoch / 60))

        if optimizer.n_epochs + 1 < args.eval_start_epoch:
            optimizer.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save the model
            optimizer.save_checkpoint(
                model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = evaluate([model.module], dev_set, recog_params, args,
                                  optimizer.n_epochs + 1, logger)
            optimizer.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot

            if optimizer.is_topk or is_transformer:
                # Save the model
                optimizer.save_checkpoint(
                    model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)

                # test
                if optimizer.is_topk:
                    for eval_set in eval_sets:
                        evaluate([model.module], eval_set, recog_params, args,
                                 optimizer.n_epochs, logger)

            duration_eval = time.time() - start_time_eval
            logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

            # Early stopping
            if optimizer.is_early_stop:
                break

            # Convert to fine-tuning stage
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                         decay_type='always', decay_rate=0.5)

            if optimizer.n_epochs >= args.n_epochs:
                break

        # Restart the timers for every epoch, including those before
        # eval_start_epoch. Previously those epochs reported cumulative times.
        start_time_step = time.time()
        start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()

    return save_path
Esempio n. 25
0
def eval_char(models, dataloader, recog_params, epoch,
              recog_dir=None, streaming=False, progressbar=False, task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Decodes every mini-batch of ``dataloader`` with ``models[0]``; any further
    models are passed as ``ensemble_models``. Each reference/hypothesis pair is
    written to ``ref.trn`` / ``hyp.trn``, and edit-distance statistics are
    accumulated.

    Args:
        models (list): models to evaluate; ``models[0]`` drives decoding
        dataloader: evaluation data loader (project dataset object exposing
            ``next``, ``reset``, ``idx2token``, ``set``, ``corpus``, ``unit``)
        recog_params (dict): decoding hyperparameters (``recog_*`` keys)
        epoch (int): epoch number, only used to name the output directory
        recog_dir (str): output directory; when None it is derived from
            ``recog_params`` under ``models[0].save_path``
        streaming (bool): streaming decoding for the session-level evaluation;
            error rates are not accumulated in this mode
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
            (any n >= 1 selects the ``ys_sub{n}`` references)
    Returns:
        wer (float): Word error rate (0 when word-level scoring does not apply)
        cer (float): Character error rate

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Target key in the batch dict. It is also used as the reference key.
    # Previously task_idx > 3 left `task` undefined.
    task = 'ys' if task_idx == 0 else 'ys_sub' + str(task_idx)

    # Word-level scoring only applies when hypotheses keep word boundaries.
    # Only evaluated for non-streaming decoding, as in the original checks.
    # NOTE(review): for task_idx == 2 this still inspects unit_sub1 — confirm
    # whether unit_sub2 was intended.
    word_level = (not streaming) and (
        ('char' in dataloader.unit and 'nowb' not in dataloader.unit)
        or (task_idx > 0 and dataloader.unit_sub1 == 'char'))

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[task_idx] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch[task],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    task=task,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataloader.idx2token[task_idx](best_hyps_id[b])

                # Truncate the first and last spaces for the char_space unit
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    if word_level:
                        # Compute WER
                        wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                                 hyp=hyp.split(' '),
                                                                 normalize=False)
                        wer += wer_b
                        n_sub_w += sub_b
                        n_ins_w += ins_b
                        n_del_w += del_b
                        n_word += len(ref.split(' '))
                        # NOTE: sentence error rate for Chinese

                    # Compute CER (spaces are ignored for CSJ)
                    if dataloader.corpus == 'csj':
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()
                    n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Normalize by the reference length. The guards avoid ZeroDivisionError
        # on an empty evaluation set.
        if word_level and n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        if n_utt > 0:
            n_streamable /= n_utt
            quantity_rate /= n_utt

    logger.debug('WER (%s): %.2f %%' % (dataloader.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataloader.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataloader.set, last_success_frame_ratio))

    return wer, cer
Esempio n. 26
0
def main():
    """Train an ASR (or skip-thought) model configured from the command line.

    Parses arguments with ``parse()``; when ``--resume`` is given, reloads the
    saved ``conf.yml`` and checkpoint. Builds the train/dev/eval datasets, the
    model, the optimizer and its learning-rate scheduler, optional teacher
    ASR / teacher LM for knowledge distillation, then runs the training loop:
    the dev loss is logged every ``print_step`` steps, figures are snapshotted
    every ``10 * print_step`` steps, and at every epoch boundary a checkpoint
    is saved and (from ``eval_start_epoch`` on) the dev/eval sets are scored.

    Returns:
        The directory holding ``conf.yml``, ``train.log`` and checkpoints.
    """
    args = parse()
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    # NOTE: vars() returns args.__dict__ itself (not a copy), so attributes
    # assigned to args later (vocab, input_dim, ...) are visible here too.
    recog_params = vars(args)

    # Compute subsampling factor
    # NOTE(review): with conv poolings the main factor comes only from
    # conv_poolings, while without them the sub-task factors equal the
    # full-encoder factor; confirm this asymmetry is intended.
    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings and 'conv' in args.enc_type:
        for p in args.conv_poolings.split('_'):
            subsample_factor *= int(p.split(',')[0].replace('(', ''))
    else:
        subsample_factor = np.prod(subsample)
    if args.train_set_sub1:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub1 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub1 - 1])
        else:
            subsample_factor_sub1 = subsample_factor
    if args.train_set_sub2:
        if args.conv_poolings and 'conv' in args.enc_type:
            subsample_factor_sub2 = subsample_factor * np.prod(
                subsample[:args.enc_n_layers_sub2 - 1])
        else:
            subsample_factor_sub2 = subsample_factor

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args, subsample_factor)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # for multi-GPUs
    if args.n_gpus > 1:
        logger.info("Batch size is automatically reduced from %d to %d" %
                    (args.batch_size, args.batch_size // 2))
        args.batch_size //= 2

    skip_thought = 'skip' in args.enc_type

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by='input',
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        discourse_aware=args.discourse_aware,
                        skip_thought=skip_thought)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      discourse_aware=args.discourse_aware,
                      skip_thought=skip_thought)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    discourse_aware=args.discourse_aware,
                    skip_thought=skip_thought,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and (args.lm_fusion or args.lm_init):
        if args.lm_fusion:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.lm_init:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_init), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, save_path) if not skip_thought else SkipThought(
        args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else
            conf['optimizer'], conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in conf['enc_type'] or conf[
            'dec_type'] == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=noam)

        # Restore the last saved model
        model, optimizer = load_checkpoint(model,
                                           args.resume,
                                           optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            optimizer.convert_to_sgd(model,
                                     'sgd',
                                     args.lr,
                                     conf['weight_decay'],
                                     decay_type='always',
                                     decay_rate=0.5)
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub),
                            os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub),
                            os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init and os.path.isfile(args.asr_init):
            # Load the ASR model
            conf_init = load_config(
                os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            model_init = load_checkpoint(model_init, args.asr_init)[0]

            # Overwrite parameters; copy only encoder weights when the
            # architecture/unit differs or the init model is CTC-only.
            only_enc = (args.enc_n_layers != args_init.enc_n_layers) or (
                args.unit != args_init.unit) or args_init.ctc_weight == 1
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        noam = 'transformer' in args.enc_type or args.dec_type == 'transformer'
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=noam)

    # Load the teacher ASR model
    teacher = None
    if args.teacher and os.path.isfile(args.teacher):
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        # NOTE(review): the student model was already built above, so this
        # only has an effect if the model reads args.lsm_prob later; confirm.
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        teacher = load_checkpoint(teacher, args.teacher)[0]

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm and os.path.isfile(args.teacher_lm):
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        teacher_lm = load_checkpoint(teacher_lm, args.teacher_lm)[0]

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()
        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    # The underlying (unwrapped) model: `.module` only exists when the model
    # was wrapped by CustomDataParallel above, so fall back to the model itself
    # on CPU instead of raising AttributeError.
    # NOTE(review): save_checkpoint may still assume a wrapped model on CPU.
    model_core = getattr(model, 'module', model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the host (node) name, not the user name.
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in batch_train['ys']])

        # Change mini-batch depending on task
        # NOTE: only the loss of the last task is kept in loss_train.
        for task in tasks:
            if skip_thought:
                loss, reporter = model(batch_train['ys'],
                                       ys_prev=batch_train['ys_prev'],
                                       ys_next=batch_train['ys_next'],
                                       reporter=reporter)
            else:
                loss, reporter = model(batch_train,
                                       reporter,
                                       task,
                                       teacher=teacher,
                                       teacher_lm=teacher_lm)
            loss.backward()
            # Gradients are accumulated until enough tokens have been seen.
            if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
                if args.clip_grad_norm > 0:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model_core.parameters(), args.clip_grad_norm)
                    reporter.add_tensorboard_scalar('total_norm', total_norm)
                optimizer.step()
                optimizer.zero_grad()
                accum_n_tokens = 0
            loss_train = loss.item()
            # Drop the reference so the autograd graph can be freed
            # (Tensor.detach() is not in-place, so it would not do this).
            del loss
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Change mini-batch depending on task
            for task in tasks:
                if skip_thought:
                    loss, reporter = model(batch_dev['ys'],
                                           ys_prev=batch_dev['ys_prev'],
                                           ys_next=batch_dev['ys_next'],
                                           reporter=reporter,
                                           is_eval=True)
                else:
                    loss, reporter = model(batch_dev,
                                           reporter,
                                           task,
                                           is_eval=True)
                loss_dev = loss.item()
                del loss
            # NOTE: WER/CER on the dev batch (greedy decoding) is not computed
            # here because it makes training slow.
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
                ylen = max(len(y) for y in batch_train['ys'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
                ylen = max(len(y) for y in batch_train['ys_sub1'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)"
                %
                (optimizer.n_steps, optimizer.n_epochs +
                 train_set.epoch_detail, loss_train, loss_dev, optimizer.lr,
                 len(batch_train['utt_ids']), xlen, ylen, duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model_core.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(model,
                                save_path,
                                optimizer,
                                optimizer.n_epochs,
                                remove_old_checkpoints=not noam)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev = eval_epoch([model_core], dev_set, recog_params,
                                        args, optimizer.n_epochs + 1, logger)
                optimizer.epoch(metric_dev)  # lr decay
                reporter.epoch(metric_dev)  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(model,
                                    save_path,
                                    optimizer,
                                    optimizer.n_epochs,
                                    remove_old_checkpoints=not noam)

                    # test
                    for eval_set in eval_sets:
                        eval_epoch([model_core], eval_set, recog_params,
                                   args, optimizer.n_epochs, logger)

                    # start scheduled sampling
                    if args.ss_prob > 0:
                        model_core.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model,
                                             'sgd',
                                             args.lr,
                                             args.weight_decay,
                                             decay_type='always',
                                             decay_rate=0.5)

            # Close the finished epoch's progress bar before starting a new
            # one; otherwise every epoch leaks an open tqdm instance.
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Esempio n. 27
0
def main():
    """Decode each recognition set and plot CTC posteriors per utterance.

    Loads the evaluation configuration, loads (or averages) the ASR model
    once, then for every utterance of every set in ``args.recog_sets``
    decodes the hypothesis, plots the top-k CTC probabilities over the
    input spectrogram to ``<recog_dir>/ctc_probs/<speaker>/<utt_id>.png``,
    and logs the reference and hypothesis to ``plot.log``.
    """
    # Load configuration
    args, recog_params, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    # All recognition sets write into the same directory, so clean it once
    # up front. Cleaning it inside the per-set loop would delete the plots
    # produced for every earlier set.
    save_path = mkdir_join(args.recog_dir, 'ctc_probs')
    if save_path is not None and os.path.isdir(save_path):
        shutil.rmtree(save_path)
        os.mkdir(save_path)

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            nlsyms=args.nlsyms,
            wp_model=os.path.join(dir_name, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.recog_batch_size,
            is_test=True)

        # The model is built only once and reused for every set.
        if i == 0:
            # Load the ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                # NOTE(review): return value is discarded; presumably this
                # loads the weights into `model` in place — confirm.
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _ = model.decode(batch['xs'], recog_params)

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(batch['xs'],
                                                             temperature=1,
                                                             topk=min(
                                                                 100,
                                                                 model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b],
                                              return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    topk_ids[b],
                    subsample_factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
Esempio n. 28
0
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Model setting
    model = build_lm(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else
            conf['optimizer'], conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=conf['lm_type'] == 'transformer')

        # Restore the last saved model
        model, optimizer = load_checkpoint(model,
                                           args.resume,
                                           optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer,
                                    args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionar, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        loss.backward()
        loss.detach()  # Trancate the graph
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (optimizer.n_steps,
                   optimizer.n_epochs + train_set.epoch_detail, loss_train,
                   loss_dev, np.exp(loss_train), np.exp(loss_dev),
                   optimizer.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save fugures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(
                    model,
                    save_path,
                    optimizer,
                    optimizer.n_epochs,
                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s, epoch:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(
                        model,
                        save_path,
                        optimizer,
                        optimizer.n_epochs,
                        remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info(
                            'PPL (%s, epoch:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., epoch:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer,
                                            args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Esempio n. 29
0
def main():
    """Plot CTC posteriors of a trained ASR model for every utterance of the eval sets.

    The evaluation configuration is parsed from ``sys.argv``. The model is built from
    ``args.recog_model[0]``; checkpoints are averaged when ``args.recog_n_average > 1``.
    For each utterance, the function does the following:

    * decodes it;
    * computes the top-k CTC posteriors (k <= 100);
    * hands them to ``plot_ctc_probs`` with the target path
      ``<recog_dir>/ctc_probs/<speaker>/<utt_id>.png``.

    The reference and best hypothesis are logged to ``<recog_dir>/plot.log``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (remove a stale log so each run starts fresh)
    log_path = os.path.join(args.recog_dir, 'plot.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    set_logger(log_path, stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader (one utterance per batch)
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model (only once; it is reused for every set)
            model = Speech2Text(args, dir_name)
            # Epoch parsed from the checkpoint name suffix after the last '-'
            # (truncated to one decimal place).
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory once, before the first set is plotted.
        # NOTE: every set writes into this same directory; cleaning it on each
        # iteration would delete the figures of all previously plotted sets.
        if i == 0 and save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        topk = min(100, model.vocab)
        for batch in dataloader:
            nbest_hyps_id, _ = model.decode(batch['xs'], args,
                                            dataloader.idx2token[0])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(batch['xs'],
                                                             temperature=1,
                                                             topk=topk)
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b],
                                                 return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],  # drop padded frames
                    topk_ids[b],
                    factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
Esempio n. 30
0
def main():
    """Plot attention weights (plus optional CTC posteriors) of a trained ASR model.

    The evaluation configuration is parsed from ``sys.argv``, and the main model is
    built from ``args.recog_model[0]``. Optionally, the function also:

    * builds extra ensemble members from ``args.recog_model[1:]``;
    * attaches an external LM for shallow fusion (first pass only).

    Every utterance of every set in ``args.recog_sets`` is decoded. The attention
    weights of the best hypothesis are passed to ``plot_attention_weights`` with the
    target path ``<recog_dir>/att_weights/<speaker>/<utt_id>.png``. When the model
    was trained with a CTC weight, the main and sub1 CTC posteriors are passed too.
    The reference and hypothesis are logged to ``<recog_dir>/plot.log``.
    """
    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging (remove a stale log so each run starts fresh)
    log_path = os.path.join(args.recog_dir, 'plot.log')
    if os.path.isfile(log_path):
        os.remove(log_path)
    set_logger(log_path, stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader (one utterance per batch)
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=1,
                                      is_test=True,
                                      first_n_utterances=args.recog_first_n_utt,
                                      longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model (only once; it is reused for every set)
            model = Speech2Text(args, dir_name)
            # Epoch parsed from the checkpoint name suffix after the last '-'
            # (truncated to one decimal place).
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models). Each member is built from the
            # training config stored next to its checkpoint; only the
            # recognition-time options (keys containing 'recog') are inherited
            # from the current command line.
            ensemble_models = [model]
            for recog_model_e in args.recog_model[1:]:
                conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))
                args_e = copy.deepcopy(args)
                for k, v in conf_e.items():
                    if 'recog' not in k:
                        setattr(args_e, k, v)
                model_e = Speech2Text(args_e)
                load_checkpoint(recog_model_e, model_e)
                if args.recog_n_gpus >= 1:
                    model_e.cuda()
                ensemble_models.append(model_e)

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first pass
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = build_lm(args_lm)
                    load_checkpoint(args.recog_lm, lm)
                    # Attach as backward or forward LM depending on how it was trained
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm
                # NOTE: only the first pass is supported

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('first LM path: %s' % args.recog_lm)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' % (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'att_weights')

        # Clean directory once, before the first set is plotted.
        # NOTE: every set writes into this same directory; cleaning it on each
        # iteration would delete the figures of all previously plotted sets.
        if i == 0 and save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for batch in dataloader:
            nbest_hyps_id, aws = model.decode(
                batch['xs'], args, dataloader.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:],  # [] when no ensemble
                speakers=batch['sessions'] if dataloader.corpus == 'swbd' else batch['speakers'])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs (only for tasks trained with a CTC loss)
            ctc_probs, topk_ids, xlens = None, None, None
            if args.ctc_weight > 0:
                ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                    batch['xs'], task='ys', temperature=1, topk=min(100, model.vocab))
                # NOTE: ctc_probs: '[B, T, topk]'
            ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = None, None, None
            if args.ctc_weight_sub1 > 0:
                ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = model.get_ctc_probs(
                    batch['xs'], task='ys_sub1', temperature=1, topk=min(100, model.vocab_sub1))

            if model.bwd_weight > 0.5:
                # Reverse the order of the hypothesis tokens and the attention
                # weights along the axis sliced by len(tokens) below.
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [[aw[0][:, ::-1]] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][0][:, :len(tokens)], tokens,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim] if args.input_type == 'speech' else None,
                    factor=args.subsample_factor,
                    ref=batch['text'][b].lower(),
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8),
                    ctc_probs=ctc_probs[b, :xlens[b]] if ctc_probs is not None else None,
                    ctc_topk_ids=topk_ids[b] if topk_ids is not None else None,
                    ctc_probs_sub1=ctc_probs_sub1[b, :xlens_sub1[b]] if ctc_probs_sub1 is not None else None,
                    ctc_topk_ids_sub1=topk_ids_sub1[b] if topk_ids_sub1 is not None else None)

                # Undo the reversal above so the logged hypothesis follows the
                # token order returned by model.decode.
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)