Example #1
def eval_epoch(models, dataset, recog_params, args, epoch, logger):
    if args.metric == 'edit_distance':
        if args.unit in ['word', 'word_char']:
            metric = eval_word(models, dataset, recog_params, epoch=epoch)[0]
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataset.set, epoch, metric))
        elif args.unit == 'wp':
            metric, cer = eval_wordpiece(models,
                                         dataset,
                                         recog_params,
                                         epoch=epoch)
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataset.set, epoch, metric))
            logger.info('CER (%s, ep:%d): %.2f %%' % (dataset.set, epoch, cer))
        elif 'char' in args.unit:
            metric, cer = eval_char(models, dataset, recog_params, epoch=epoch)
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataset.set, epoch, metric))
            logger.info('CER (%s, ep:%d): %.2f %%' % (dataset.set, epoch, cer))
        elif 'phone' in args.unit:
            metric = eval_phone(models, dataset, recog_params, epoch=epoch)
            logger.info('PER (%s, ep:%d): %.2f %%' %
                        (dataset.set, epoch, metric))
    elif args.metric == 'ppl':
        metric = eval_ppl(models, dataset, batch_size=args.batch_size)[0]
        logger.info('PPL (%s, ep:%d): %.2f' % (dataset.set, epoch, metric))
    elif args.metric == 'loss':
        metric = eval_ppl(models, dataset, batch_size=args.batch_size)[1]
        logger.info('Loss (%s, ep:%d): %.2f' % (dataset.set, epoch, metric))
    else:
        raise NotImplementedError(args.metric)
    return metric
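
A minimal usage sketch (not part of the original example), assuming hypothetical train_one_epoch and save_checkpoint helpers and the models, train_set, dev_set, recog_params, args and logger objects set up elsewhere:

best_metric = float('inf')
for ep in range(1, args.n_epochs + 1):
    train_one_epoch(models, train_set)  # hypothetical training helper
    # Every metric eval_epoch can return (WER/CER/PER/PPL/loss) is lower-is-better
    metric = eval_epoch(models, dev_set, recog_params, args, epoch=ep, logger=logger)
    if metric < best_metric:
        best_metric = metric
        save_checkpoint(models, ep)  # hypothetical checkpoint helper
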
Example #2
def evaluate(models, dataloader, recog_params, args, epoch, logger):

    if args.metric == 'edit_distance':
        if args.unit in ['word', 'word_char']:
            metric = eval_word(models, dataloader, recog_params,
                               epoch=epoch)[0]
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, metric))

        elif args.unit == 'wp':
            metric, cer = eval_wordpiece(models,
                                         dataloader,
                                         recog_params,
                                         epoch=epoch)
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, metric))
            logger.info('CER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, cer))

        elif 'char' in args.unit:
            wer, cer = eval_char(models, dataloader, recog_params, epoch=epoch)
            logger.info('WER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, wer))
            logger.info('CER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, cer))
            if dataloader.corpus in ['aishell1']:
                metric = cer
            else:
                metric = wer

        elif 'phone' in args.unit:
            metric = eval_phone(models, dataloader, recog_params, epoch=epoch)
            logger.info('PER (%s, ep:%d): %.2f %%' %
                        (dataloader.set, epoch, metric))

    elif args.metric == 'ppl':
        metric = eval_ppl(models, dataloader, batch_size=args.batch_size)[0]
        logger.info('PPL (%s, ep:%d): %.3f' % (dataloader.set, epoch, metric))

    elif args.metric == 'loss':
        metric = eval_ppl(models, dataloader, batch_size=args.batch_size)[1]
        logger.info('Loss (%s, ep:%d): %.5f' % (dataloader.set, epoch, metric))

    elif args.metric == 'accuracy':
        metric = eval_accuracy(models, dataloader, batch_size=args.batch_size)
        logger.info('Accuracy (%s, ep:%d): %.3f' %
                    (dataloader.set, epoch, metric))

    elif args.metric == 'bleu':
        metric = eval_wordpiece_bleu(models,
                                     dataloader,
                                     recog_params,
                                     epoch=epoch)
        logger.info('BLEU (%s, ep:%d): %.3f' % (dataloader.set, epoch, metric))

    else:
        raise NotImplementedError(args.metric)

    return metric
Example #3
def main():

    # Load configuration
    args, _, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    set_logger(os.path.join(args.recog_dir, 'decode.log'),
               stdout=args.recog_stdout)

    ppl_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          backward=args.backward,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            model = build_lm(args)
            load_checkpoint(args.recog_model[0], model)
            epoch = int(args.recog_model[0].split('-')[-1])
            # NOTE: model averaging is not helpful for LM

            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            logger.info('model average (Transformer): %d' %
                        (args.recog_n_average))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            if args.recog_n_gpus > 0:
                model.cuda()

        start_time = time.time()

        ppl, _ = eval_ppl([model],
                          dataset,
                          batch_size=1,
                          bptt=args.bptt,
                          n_caches=args.recog_n_caches,
                          progressbar=True)
        ppl_avg += ppl
        print('PPL (%s): %.2f' % (dataset.set, ppl))
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
Example #4
def main():

    args = parse_args_train(sys.argv[1:])

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                dict_path=args.dict,
                nlsyms=args.nlsyms,
                unit=args.unit,
                wp_model=args.wp_model,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, save_path)

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        resume_epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model,
            'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
            args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    optimizer = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=getattr(args, 'transformer_d_model', 0),
        factor=args.lr_factor,
        noam=is_transformer,
        save_checkpoints_topk=1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=not (is_transformer or args.cudnn_benchmark),
            benchmark=args.cudnn_benchmark)
        model.cuda()

        # Mixed-precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_steps = 0
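    # n_steps counts processed mini-batches; when resuming, it is recovered from
    # the number of optimizer updates times the gradient-accumulation interval.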
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_steps += 1

        loss, hidden, observation = model(ys_train, hidden)
        reporter.add(observation)
        if use_apex:
            with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.accum_grad_n_steps == 1 or accum_n_steps >= args.accum_grad_n_steps:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_steps = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))
        n_steps += 1
        # NOTE: n_steps is different from the step counter in Noam Optimizer

        if n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next(bptt=args.bptt)[0]
            loss, _, observation = model(ys_dev, None, is_eval=True)
            reporter.add(observation, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)" %
                (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                 loss_train, loss_dev, optimizer.lr, ys_train.shape[0],
                 duration_step / 60))
            start_time_step = time.time()

        # Save figures of loss and accuracy
        if n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                optimizer.save_checkpoint(model,
                                          save_path,
                                          remove_old=not is_transformer,
                                          amp=amp)
            else:
                start_time_eval = time.time()
                # dev
                model.module.reset_length(args.bptt)
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                model.module.reset_length(args.bptt)
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot
                logger.info('PPL (%s, ep:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))

                if optimizer.is_topk or is_transformer:
                    # Save the model
                    optimizer.save_checkpoint(model,
                                              save_path,
                                              remove_old=not is_transformer,
                                              amp=amp)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        model.module.reset_length(args.bptt)
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        model.module.reset_length(args.bptt)
                        logger.info(
                            'PPL (%s, ep:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., ep:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    optimizer.convert_to_sgd(model,
                                             args.lr,
                                             args.weight_decay,
                                             decay_type='always',
                                             decay_rate=0.5)

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs >= args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Example #5
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'decode.log'),
                        key='decoding')

    skip_thought = 'skip' in args.enc_type

    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            dict_path_sub2=os.path.join(dir_name, 'dict_sub2.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub2.txt')) else False,
            nlsyms=os.path.join(dir_name, 'nlsyms.txt'),
            wp_model=os.path.join(dir_name, 'wp.model'),
            wp_model_sub1=os.path.join(dir_name, 'wp_sub1.model'),
            wp_model_sub2=os.path.join(dir_name, 'wp_sub2.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            unit_sub2=args.unit_sub2,
            batch_size=args.recog_batch_size,
            skip_thought=skip_thought,
            is_test=True)

        if i == 0:
            # Load the ASR model
            if skip_thought:
                model = SkipThought(args, dir_name)
            else:
                model = Speech2Text(args, dir_name)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']

            # ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    model_e, _ = load_checkpoint(model_e, recog_model_e)
                    model_e.cuda()
                    ensemble_models += [model_e]

            # Load the LM for shallow fusion
            if not args.lm_fusion:
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = select_lm(args_lm)
                    lm, _ = load_checkpoint(lm, args.recog_lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 \
                        and (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring):
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_bwd),
                                     'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    lm_bwd = select_lm(args_lm_bwd)
                    lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('LM path: %s' % args.recog_lm)
            logger.info('LM path (bwd): %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('reverse LM rescoring: %s' %
                        args.recog_reverse_lm_rescoring)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache type: %s' % (args.recog_cache_type))
            logger.info('cache word frequency threshold: %s' %
                        (args.recog_cache_word_freq))
            logger.info('cache theta (speech): %.3f' %
                        (args.recog_cache_theta_speech))
            logger.info('cache lambda (speech): %.3f' %
                        (args.recog_cache_lambda_speech))
            logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm))
            logger.info('cache lambda (lm): %.3f' %
                        (args.recog_cache_lambda_lm))

            # GPU setting
            model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models,
                                        dataset,
                                        recog_params,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models,
                                          dataset,
                                          recog_params,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models,
                                     dataset,
                                     recog_params,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0)
                #  task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models,
                                 dataset,
                                 recog_params,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True)
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric == 'acc':
            raise NotImplementedError
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models,
                                 dataset,
                                 recog_params=recog_params,
                                 progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'bleu':
            raise NotImplementedError
        else:
            raise NotImplementedError
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' %
                        (per_avg / len(args.recog_sets)))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / len(args.recog_sets),
                         cer_avg / len(args.recog_sets)))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
        print('PPL (avg.): %.2f' % (ppl_avg / len(args.recog_sets)))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets)))
        print('Loss (avg.): %.2f' % (loss_avg / len(args.recog_sets)))
Example #6
def main():

    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    set_logger(os.path.join(args.recog_dir, 'decode.log'),
               stdout=args.recog_stdout)

    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
    acc_avg = 0
    bleu_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
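            # Parse the (possibly fractional) epoch from the suffix after the last
            # '-' in the checkpoint path, truncated to one decimal place.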
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                # topk_list = load_checkpoint(args.recog_model[0], model)
                model = average_checkpoints(
                    model,
                    args.recog_model[0],
                    # topk_list=topk_list,
                    n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first pass
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    args_lm.recog_mem_len = args.recog_mem_len
                    lm = build_lm(args_lm,
                                  wordlm=args.recog_wordlm,
                                  lm_dict_path=os.path.join(
                                      os.path.dirname(args.recog_lm),
                                      'dict.txt'),
                                  asr_dict_path=os.path.join(
                                      dir_name, 'dict.txt'))
                    load_checkpoint(args.recog_lm, lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                # second pass (forward)
                if args.recog_lm_second is not None and args.recog_lm_second_weight > 0:
                    conf_lm_second = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_second),
                                     'conf.yml'))
                    args_lm_second = argparse.Namespace()
                    for k, v in conf_lm_second.items():
                        setattr(args_lm_second, k, v)
                    args_lm_second.recog_mem_len = args.recog_mem_len
                    lm_second = build_lm(args_lm_second)
                    load_checkpoint(args.recog_lm_second, lm_second)
                    model.lm_second = lm_second

                # second pass (backward)
                if args.recog_lm_bwd is not None and args.recog_lm_bwd_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_bwd),
                                     'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    args_lm_bwd.recog_mem_len = args.recog_mem_len
                    lm_bwd = build_lm(args_lm_bwd)
                    load_checkpoint(args.recog_lm_bwd, lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('first LM path: %s' % args.recog_lm)
            logger.info('second LM path: %s' % args.recog_lm_second)
            logger.info('backward LM path: %s' % args.recog_lm_bwd)
            logger.info('LM weight (first-pass): %.3f' % args.recog_lm_weight)
            logger.info('LM weight (second-pass): %.3f' %
                        args.recog_lm_second_weight)
            logger.info('LM weight (backward): %.3f' %
                        args.recog_lm_bwd_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' %
                        (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models,
                                        dataloader,
                                        args,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True,
                                        fine_grained=True,
                                        oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models,
                                          dataloader,
                                          args,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          streaming=args.recog_streaming,
                                          progressbar=True,
                                          fine_grained=True,
                                          oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models,
                                     dataloader,
                                     args,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0,
                                     fine_grained=True,
                                     oracle=True)
                #  task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models,
                                 dataloader,
                                 args,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True,
                                 fine_grained=True,
                                 oracle=True)
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models, dataloader, progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'accuracy':
            acc_avg += eval_accuracy(ensemble_models,
                                     dataloader,
                                     progressbar=True)
        elif args.recog_metric == 'bleu':
            bleu = eval_wordpiece_bleu(ensemble_models,
                                       dataloader,
                                       args,
                                       epoch=epoch - 1,
                                       recog_dir=args.recog_dir,
                                       streaming=args.recog_streaming,
                                       progressbar=True,
                                       fine_grained=True,
                                       oracle=True)
            bleu_avg += bleu
        else:
            raise NotImplementedError(args.recog_metric)
        elapsed_time = time.time() - start_time
        logger.info('Elapsed time: %.3f [sec]' % elapsed_time)
        logger.info('RTF: %.3f' % (elapsed_time /
                                   (dataloader.n_frames * 0.01)))

    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' %
                        (per_avg / len(args.recog_sets)))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / len(args.recog_sets),
                         cer_avg / len(args.recog_sets)))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
        print('PPL (avg.): %.3f' % (ppl_avg / len(args.recog_sets)))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets)))
        print('Loss (avg.): %.3f' % (loss_avg / len(args.recog_sets)))
    elif args.recog_metric == 'accuracy':
        logger.info('Accuracy (avg.): %.2f\n' %
                    (acc_avg / len(args.recog_sets)))
        print('Accuracy (avg.): %.3f' % (acc_avg / len(args.recog_sets)))
    elif args.recog_metric == 'bleu':
        logger.info('BLEU (avg.): %.2f\n' % (bleu_avg / len(args.recog_sets)))
        print('BLEU (avg.): %.3f' % (bleu_avg / len(args.recog_sets)))
Example #7
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'decode.log'),
                        key='decoding')

    ppl_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          backward=args.backward,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            model = select_lm(args)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']
            model.save_path = dir_name

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            model.cuda()

        start_time = time.time()

        # TODO(hirofumi): ensemble
        ppl, _ = eval_ppl([model],
                          dataset,
                          batch_size=1,
                          bptt=args.bptt,
                          n_caches=args.recog_n_caches,
                          progressbar=True)
        ppl_avg += ppl
        print('PPL (%s): %.2f' % (dataset.set, ppl))
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

    logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
Example #8
def main():

    args = parse()
    args_pt = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # Automatically reduce batch size in multi-GPU setting
    if args.n_gpus > 1:
        args.batch_size -= 10
        args.print_step //= args.n_gpus

    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample_factor_sub3 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings:
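        # Each pooling pair, e.g. '(2,2)', contributes its second value
        # (presumably the time axis) to the overall frame-subsampling factor.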
        for p in args.conv_poolings.split('_'):
            p = int(p.split(',')[1].replace(')', ''))
            if p > 1:
                subsample_factor *= p
    if args.train_set_sub1:
        subsample_factor_sub1 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub1 - 1])
    if args.train_set_sub2:
        subsample_factor_sub2 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub2 - 1])
    if args.train_set_sub3:
        subsample_factor_sub3 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub3 - 1])
    subsample_factor *= np.prod(subsample)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        tsv_path_sub3=args.train_set_sub3,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        dict_path_sub3=args.dict_sub3,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        unit_sub3=args.unit_sub3,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        wp_model_sub3=args.wp_model_sub3,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        ctc_sub3=args.ctc_weight_sub3 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        subsample_factor_sub3=subsample_factor_sub3,
                        concat_prev_n_utterances=args.concat_prev_n_utterances,
                        n_caches=args.n_caches)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      tsv_path_sub3=args.dev_set_sub3,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      dict_path_sub3=args.dict_sub3,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      unit_sub3=args.unit_sub3,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      wp_model_sub3=args.wp_model_sub3,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      shuffle=True if args.n_caches == 0 else False,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      ctc_sub3=args.ctc_weight_sub3 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      subsample_factor_sub3=subsample_factor_sub3,
                      n_caches=args.n_caches)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    n_caches=args.n_caches,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.vocab_sub3 = train_set.vocab_sub3
    args.input_dim = train_set.input_dim

    # Load a LM conf file for cold fusion & LM initialization
    if args.lm_fusion:
        if args.model:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.resume:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.resume), 'conf_lm.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    if args.enc_type == 'transformer':
        args.decay_type = 'warmup'

    # Model setting
    model = Seq2seq(args)
    dir_name = make_model_name(args, subsample_factor)

    if args.resume:
        # Set save path
        model.save_path = os.path.dirname(args.resume)

        # Setting for logging
        logger = set_logger(os.path.join(os.path.dirname(args.resume),
                                         'train.log'),
                            key='training')

        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd'
            if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        checkpoints = model.load_checkpoint(args.resume, resume=True)
        lr_controller = checkpoints['lr_controller']
        epoch = checkpoints['epoch']
        step = checkpoints['step']
        metric_dev_best = checkpoints['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf,
                        os.path.join(model.save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path,
                                                  'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2', '_sub3']:
            if getattr(args, 'dict' + sub):
                shutil.copy(
                    getattr(args, 'dict' + sub),
                    os.path.join(model.save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(
                    getattr(args, 'wp_model' + sub),
                    os.path.join(model.save_path, 'wp' + sub + '.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.pretrained_model and os.path.isfile(args.pretrained_model):
            # Load a conf file
            conf_pt = load_config(
                os.path.join(os.path.dirname(args.pretrained_model),
                             'conf.yml'))

            # Merge conf with args
            for k, v in conf_pt.items():
                setattr(args_pt, k, v)

            # Load the ASR model
            model_pt = Seq2seq(args_pt)
            model_pt.load_checkpoint(args.pretrained_model)

            # Overwrite parameters
            only_enc = (args.enc_n_layers !=
                        args_pt.enc_n_layers) or (args.unit != args_pt.unit)
            param_dict = dict(model_pt.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            transformer=True if args.enc_type == 'transformer'
                            or args.dec_type == 'transformer' else False)

        epoch, step = 1, 1
        metric_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=metric_dev_best,
            model_size=args.d_model,
            warmup_start_learning_rate=args.warmup_start_learning_rate,
            warmup_n_steps=args.warmup_n_steps,
            factor=1)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight - args.sub3_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.lmobj_weight > 0:
            tasks = ['ys.lmobj'] + tasks
        if args.lm_fusion is not None and 'mtl' in args.lm_fusion_type:
            tasks = ['ys.lm'] + tasks
        for sub in ['sub1', 'sub2', 'sub3']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'bwd_weight_' + sub) - getattr(
                            args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'bwd_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.bwd'] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
                if getattr(args, 'lmobj_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.lmobj'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_n_epochs = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()

        # Iterate over the training tasks (multi-task learning)
        for task in tasks:
            model.module.optimizer.zero_grad()
            loss, reporter = model(batch_train, reporter=reporter, task=task)
            if len(model.device_ids) > 1:
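                # NOTE: the data-parallel wrapper returns one loss per GPU, so
                # backward() on the non-scalar loss needs a vector of ones.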
                loss.backward(torch.ones(len(model.device_ids)))
            else:
                loss.backward()
            loss.detach()  # Truncate the graph
            if args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                               args.clip_grad_norm)
            model.module.optimizer.step()
            loss_train = loss.item()
            del loss

        reporter.step(is_eval=False)

        # Update learning rate
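        # (per-step warmup only; epoch-level decay is applied after evaluation)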
        if args.decay_type == 'warmup' and step < args.warmup_n_steps:
            model.module.optimizer = lr_controller.warmup(
                model.module.optimizer, step=step)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Iterate over the same tasks on the dev batch
            for task in tasks:
                loss, reporter = model(batch_dev,
                                       reporter=reporter,
                                       task=task,
                                       is_eval=True)
                loss_dev = loss.item()
                del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   lr_controller.lr, len(
                       batch_train['utt_ids']), xlen, duration_step / 60))
            start_time_step = time.time()
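        # NOTE: each update spans all GPUs, so the step counter advances by n_gpus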
        step += args.n_gpus
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path,
                                             lr_controller, epoch, step - 1,
                                             metric_dev_best)
                reporter._epoch += 1
                # TODO(hirofumi): fix later
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'edit_distance':
                    if args.unit in ['word', 'word_char']:
                        metric_dev = eval_word([model.module],
                                               dev_set,
                                               recog_params,
                                               epoch=epoch)[0]
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                    elif args.unit == 'wp':
                        metric_dev, cer_dev = eval_wordpiece([model.module],
                                                             dev_set,
                                                             recog_params,
                                                             epoch=epoch)
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                        logger.info('CER (%s): %.2f %%' %
                                    (dev_set.set, cer_dev))
                    elif 'char' in args.unit:
                        metric_dev, cer_dev = eval_char([model.module],
                                                        dev_set,
                                                        recog_params,
                                                        epoch=epoch)
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                        logger.info('CER (%s): %.2f %%' %
                                    (dev_set.set, cer_dev))
                    elif 'phone' in args.unit:
                        metric_dev = eval_phone([model.module],
                                                dev_set,
                                                recog_params,
                                                epoch=epoch)
                        logger.info('PER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                elif args.metric == 'ppl':
                    metric_dev = eval_ppl([model.module], dev_set,
                                          recog_params)[0]
                    logger.info('PPL (%s): %.2f' %
                                (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_ppl([model.module], dev_set,
                                          recog_params)[1]
                    logger.info('Loss (%s): %.2f' %
                                (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError(args.metric)
                reporter.epoch(metric_dev)

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=metric_dev)

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_n_epochs = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path,
                                                 lr_controller, epoch,
                                                 step - 1, metric_dev_best)

                    # test
                    for s in eval_sets:
                        if args.metric == 'edit_distance':
                            if args.unit in ['word', 'word_char']:
                                wer_test = eval_word([model.module],
                                                     s,
                                                     recog_params,
                                                     epoch=epoch)[0]
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                            elif args.unit == 'wp':
                                wer_test, cer_test = eval_wordpiece(
                                    [model.module],
                                    s,
                                    recog_params,
                                    epoch=epoch)
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                                logger.info('CER (%s): %.2f %%' %
                                            (s.set, cer_test))
                            elif 'char' in args.unit:
                                wer_test, cer_test = eval_char([model.module],
                                                               s,
                                                               recog_params,
                                                               epoch=epoch)
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                                logger.info('CER (%s): %.2f %%' %
                                            (s.set, cer_test))
                            elif 'phone' in args.unit:
                                per_test = eval_phone([model.module],
                                                      s,
                                                      recog_params,
                                                      epoch=epoch)
                                logger.info('PER (%s): %.2f %%' %
                                            (s.set, per_test))
                        elif args.metric == 'ppl':
                            ppl_test = eval_ppl([model.module], s,
                                                recog_params)[0]
                            logger.info('PPL (%s): %.2f' %
                                        (s.set, ppl_test))
                        elif args.metric == 'loss':
                            loss_test = eval_ppl([model.module], s,
                                                 recog_params)[1]
                            logger.info('Loss (%s): %.2f' %
                                        (s.set, loss_test))
                        else:
                            raise NotImplementedError(args.metric)
                else:
                    not_improved_n_epochs += 1

                    # start scheduled sampling
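                    # (feed back model predictions instead of ground-truth
                    # tokens once the dev score stops improving)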
                    if args.ss_prob > 0:
                        model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_n_epochs == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
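                # NOTE: from this epoch on, training continues with plain SGD
                # and the learning rate is halved every epoch (decay_rate=0.5)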
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
Beispiel #9
0
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = make_model_name(args)
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'), key='training')

    # Model setting
    if 'gated_conv' in args.lm_type:
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    model.save_path = save_path

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd'
            if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        model, checkpoint = load_checkpoint(model, args.resume, resume=True)
        lr_controller = checkpoint['lr_controller']
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        ppl_dev_best = checkpoint['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path,
                                                    'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                           args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if 'gated_conv' not in args.lm_type:
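            # detach the recurrent state so gradients do not flow across BPTT
            # segments (the gated conv LM keeps no recurrent state)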
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   np.exp(loss_train), np.exp(loss_dev), lr_controller.lr,
                   ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                save_checkpoint(model.module,
                                model.module.save_path,
                                lr_controller,
                                epoch,
                                step - 1,
                                ppl_dev_best,
                                remove_old_checkpoints=True)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    save_checkpoint(model.module,
                                    model.module.save_path,
                                    lr_controller,
                                    epoch,
                                    step - 1,
                                    ppl_dev_best,
                                    remove_old_checkpoints=True)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' %
                                    (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' %
                                    (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
Beispiel #10
0
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_lm_name(args)
        save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    logger = set_logger(os.path.join(save_path, 'train.log'),
                        key='training',
                        stdout=args.stdout)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    nlsyms=args.nlsyms,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    backward=args.backward,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Model setting
    model = build_lm(args, save_path)

    if args.resume:
        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(
            model, 'sgd' if epoch > conf['convert_to_sgd_epoch'] else
            conf['optimizer'], conf['lr'], conf['weight_decay'])

        # Wrap optimizer by learning rate scheduler
        optimizer = LRScheduler(
            optimizer,
            conf['lr'],
            decay_type=conf['lr_decay_type'],
            decay_start_epoch=conf['lr_decay_start_epoch'],
            decay_rate=conf['lr_decay_rate'],
            decay_patient_n_epochs=conf['lr_decay_patient_n_epochs'],
            early_stop_patient_n_epochs=conf['early_stop_patient_n_epochs'],
            warmup_start_lr=conf['warmup_start_lr'],
            warmup_n_steps=conf['warmup_n_steps'],
            model_size=conf['d_model'],
            factor=conf['lr_factor'],
            noam=conf['lm_type'] == 'transformer')

        # Restore the last saved model
        model, optimizer = load_checkpoint(model,
                                           args.resume,
                                           optimizer,
                                           resume=True)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if epoch == conf['convert_to_sgd_epoch']:
            n_epochs = optimizer.n_epochs
            n_steps = optimizer.n_steps
            optimizer = set_optimizer(model, 'sgd', args.lr,
                                      conf['weight_decay'])
            optimizer = LRScheduler(optimizer,
                                    args.lr,
                                    decay_type='always',
                                    decay_start_epoch=0,
                                    decay_rate=0.5)
            optimizer._epoch = n_epochs
            optimizer._step = n_steps
            logger.info('========== Convert to SGD ==========')
    else:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        optimizer = set_optimizer(model, args.optimizer, args.lr,
                                  args.weight_decay)

        # Wrap optimizer by learning rate scheduler
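        # NOTE: noam=True requests the inverse-square-root (Noam) schedule used
        # for Transformer LMs; model_size and factor parameterize that curve.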
        optimizer = LRScheduler(
            optimizer,
            args.lr,
            decay_type=args.lr_decay_type,
            decay_start_epoch=args.lr_decay_start_epoch,
            decay_rate=args.lr_decay_rate,
            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
            warmup_start_lr=args.warmup_start_lr,
            warmup_n_steps=args.warmup_n_steps,
            model_size=args.d_model,
            factor=args.lr_factor,
            noam=args.lm_type == 'transformer')

    # GPU setting
    if args.n_gpus >= 1:
        torch.backends.cudnn.benchmark = True
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
        model.cuda()

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    pbar_epoch = tqdm(total=len(train_set))
    accum_n_tokens = 0
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()
        accum_n_tokens += sum([len(y) for y in ys_train])
        optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        loss.backward()
        loss.detach()  # Truncate the graph
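        # token-based gradient accumulation: parameters are updated only after
        # at least accum_grad_n_tokens target tokens have contributed gradients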
        if args.accum_grad_n_tokens == 0 or accum_n_tokens >= args.accum_grad_n_tokens:
            if args.clip_grad_norm > 0:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.parameters(), args.clip_grad_norm)
                reporter.add_tensorboard_scalar('total_norm', total_norm)
            optimizer.step()
            optimizer.zero_grad()
            accum_n_tokens = 0
        loss_train = loss.item()
        del loss
        hidden = model.module.repackage_state(hidden)
        reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
        # NOTE: loss/acc/ppl are already added in the model
        reporter.step()

        if optimizer.n_steps % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (optimizer.n_steps,
                   optimizer.n_epochs + train_set.epoch_detail, loss_train,
                   loss_dev, np.exp(loss_train), np.exp(loss_dev),
                   optimizer.lr, ys_train.shape[0], duration_step / 60))
            start_time_step = time.time()
        pbar_epoch.update(ys_train.shape[0] * (ys_train.shape[1] - 1))

        # Save figures of loss and accuracy
        if optimizer.n_steps % (args.print_step * 10) == 0:
            reporter.snapshot()
            if args.lm_type == 'transformer':
                model.module.plot_attention()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (optimizer.n_epochs + 1, duration_epoch / 60))

            if optimizer.n_epochs + 1 < args.eval_start_epoch:
                optimizer.epoch()  # lr decay
                reporter.epoch()  # plot

                # Save the model
                save_checkpoint(
                    model,
                    save_path,
                    optimizer,
                    optimizer.n_epochs,
                    remove_old_checkpoints=args.lm_type != 'transformer')
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s, epoch:%d): %.2f' %
                            (dev_set.set, optimizer.n_epochs, ppl_dev))
                optimizer.epoch(ppl_dev)  # lr decay
                reporter.epoch(ppl_dev, name='perplexity')  # plot

                if optimizer.is_best:
                    # Save the model
                    save_checkpoint(
                        model,
                        save_path,
                        optimizer,
                        optimizer.n_epochs,
                        remove_old_checkpoints=args.lm_type != 'transformer')

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info(
                            'PPL (%s, epoch:%d): %.2f' %
                            (eval_set.set, optimizer.n_epochs, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg., epoch:%d): %.2f' %
                                    (optimizer.n_epochs,
                                     ppl_test_avg / len(eval_sets)))

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if optimizer.is_early_stop:
                    break

                # Convert to fine-tuning stage
                if optimizer.n_epochs == args.convert_to_sgd_epoch:
                    n_epochs = optimizer.n_epochs
                    n_steps = optimizer.n_steps
                    optimizer = set_optimizer(model, 'sgd', args.lr,
                                              args.weight_decay)
                    optimizer = LRScheduler(optimizer,
                                            args.lr,
                                            decay_type='always',
                                            decay_start_epoch=0,
                                            decay_rate=0.5)
                    optimizer._epoch = n_epochs
                    optimizer._step = n_steps
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if optimizer.n_epochs == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
Beispiel #11
0
def main():

    args = parse_args_train(sys.argv[1:])

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)

    # for multi-GPUs
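    # NOTE: the per-step batch is multiplied by n_gpus, so the number of
    # accumulation steps is reduced to keep the effective batch size constant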
    if args.n_gpus > 1:
        batch_size = args.batch_size * args.n_gpus
        accum_grad_n_steps = max(1, args.accum_grad_n_steps // args.n_gpus)
    else:
        batch_size = args.batch_size
        accum_grad_n_steps = args.accum_grad_n_steps

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_tokens=args.min_n_tokens,
                        bptt=args.bptt,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                dict_path=args.dict,
                nlsyms=args.nlsyms,
                unit=args.unit,
                wp_model=args.wp_model,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]

    args.vocab = train_set.vocab

    # Set save path
    if args.resume:
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_lm_name(args)
        args.save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), stdout=args.stdout)

    # Model setting
    model = build_lm(args, args.save_path)

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(args.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(args.save_path,
                                                    'wp.model'))

        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

    # Set optimizer
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    scheduler = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=args.get('transformer_d_model', 0),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=not (is_transformer or args.cudnn_benchmark),
            benchmark=not is_transformer and args.cudnn_benchmark)

        # Mixed precision training setting
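        # torch.cuda.amp (GradScaler) is used on PyTorch >= 1.6; older versions
        # fall back to NVIDIA apex with the requested opt_level (O0-O3)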
        if args.use_apex:
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)
        model.cuda()
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model)
    if args.resume:
        n_steps = scheduler.n_steps * accum_grad_n_steps
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if not args.resume:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        # NOTE: save after reporter for wandb ID

    hidden = None
    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        for ys_train, is_new_epoch in train_set:
            hidden = train(model, train_set, dev_set, scheduler, reporter,
                           logger, args, accum_grad_n_steps, amp, scaler,
                           hidden)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            scheduler.save_checkpoint(model,
                                      args.save_path,
                                      remove_old=not is_transformer
                                      and args.remove_old_checkpoints,
                                      amp=amp)
        else:
            start_time_eval = time.time()
            # dev
            model.module.reset_length(args.bptt)
            ppl_dev, _ = eval_ppl([model.module],
                                  dev_set,
                                  batch_size=1,
                                  bptt=args.bptt)
            model.module.reset_length(args.bptt)
            scheduler.epoch(ppl_dev)  # lr decay
            reporter.epoch(ppl_dev, name='perplexity')  # plot
            reporter.add_scalar('dev/perplexity', ppl_dev)
            logger.info('PPL (%s, ep:%d): %.2f' %
                        (dev_set.set, reporter.n_epochs, ppl_dev))

            if scheduler.is_topk or is_transformer:
                # Save model
                scheduler.save_checkpoint(model,
                                          args.save_path,
                                          remove_old=not is_transformer
                                          and args.remove_old_checkpoints,
                                          amp=amp)

                # test
                ppl_test_avg = 0.
                for eval_set in eval_sets:
                    model.module.reset_length(args.bptt)
                    ppl_test, _ = eval_ppl([model.module],
                                           eval_set,
                                           batch_size=1,
                                           bptt=args.bptt)
                    model.module.reset_length(args.bptt)
                    logger.info('PPL (%s, ep:%d): %.2f' %
                                (eval_set.set, reporter.n_epochs, ppl_test))
                    ppl_test_avg += ppl_test
                if len(eval_sets) > 0:
                    logger.info(
                        'PPL (avg., ep:%d): %.2f' %
                        (reporter.n_epochs, ppl_test_avg / len(eval_sets)))

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model,
                                         args.lr,
                                         args.weight_decay,
                                         decay_type='always',
                                         decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
Beispiel #12
0
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check differences between args and the yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        bptt=args.bptt,
                        eos=args.eos,
                        max_epoch=args.num_epochs,
                        shuffle=True)
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      bptt=args.bptt,
                      eos=args.eos,
                      shuffle=True)
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [Dataset(csv_path=set,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              bptt=args.bptt,
                              eos=args.eos,
                              is_test=True)]

    args.num_classes = train_set.num_classes

    # Model setting
    model = RNNLM(args)
    model.name = args.rnn_type
    model.name += str(args.num_units) + 'H'
    model.name += str(args.num_projs) + 'P'
    model.name += str(args.num_layers) + 'L'
    model.name += '_emb' + str(args.emb_dim)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    if args.tie_weights:
        model.name += '_tie'
    if args.residual:
        model.name += '_residual'
    if args.backward:
        model.name += '_bwd'

    if args.resume_model is None:
        # Set save path
        save_path = mkdir_join(args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    else:
        raise NotImplementedError()

    train_set.epoch = epoch - 1

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=True,
                                   benchmark=False)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=10)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        ys_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, ys_train, args.bptt)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(np.sum([len(y) for y in ys_train]))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, ys_dev, args.bptt, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
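            # scalar losses plus per-parameter weight/gradient histograms are
            # written every print_step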
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            for n, p in model.module.named_parameters():
                n = n.replace('.', '/')
                if p.grad is not None:
                    tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
                    tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            logger.info("...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/ppl:%.2f(%.2f)/lr:%.5f/bs:%d (%.2f min)" %
                        (step + 1, train_set.epoch_detail,
                         loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                         math.exp(loss_train_mean), math.exp(loss_dev),
                         learning_rate, len(ys_train), duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev = eval_ppl([model.module], dev_set, args.bptt)
                logger.info(' PPL (%s): %.3f' % (dev_set.set, ppl_dev))

                if ppl_dev < metric_dev_best:
                    metric_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    ppl_test_mean = 0.
                    for eval_set in eval_sets:
                        ppl_test = eval_ppl([model.module], eval_set, args.bptt)
                        logger.info(' PPL (%s): %.3f' % (eval_set.set, ppl_test))
                        ppl_test_mean += ppl_test
                    if len(eval_sets) > 0:
                        logger.info(' PPL (mean): %.3f' % (ppl_test_mean / len(eval_sets)))
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
Beispiel #13
0
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    set_logger(os.path.join(args.recog_dir, 'decode.log'),
               stdout=args.recog_stdout)

    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
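    # Decode and score every recognition set; per-set scores are accumulated for averaging at the end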
    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            dict_path_sub2=os.path.join(dir_name, 'dict_sub2.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub2.txt')) else False,
            nlsyms=os.path.join(dir_name, 'nlsyms.txt'),
            wp_model=os.path.join(dir_name, 'wp.model'),
            wp_model_sub1=os.path.join(dir_name, 'wp_sub1.model'),
            wp_model_sub2=os.path.join(dir_name, 'wp_sub2.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            unit_sub2=args.unit_sub2,
            batch_size=args.recog_batch_size,
            is_test=True)

        if i == 0:
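            # The ASR model, ensemble members, and fusion LMs are built only for the first set and reused afterwards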
            # Load the ASR model
            model = Speech2Text(args, dir_name)
            load_checkpoint(model, args.recog_model[0])
            epoch = int(args.recog_model[0].split('-')[-1])

            # Model averaging for Transformer
            if ('transformer' in conf['enc_type']
                    and conf['dec_type'] == 'transformer'):
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            epoch,
                                            n_average=args.recog_n_average)

            # Ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(model_e, recog_model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load the LM for shallow fusion
            if not args.lm_fusion:
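                # External LMs are attached here only when the model was not trained with LM fusion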
                # first pass
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = build_lm(args_lm,
                                  wordlm=args.recog_wordlm,
                                  lm_dict_path=os.path.join(
                                      os.path.dirname(args.recog_lm),
                                      'dict.txt'),
                                  asr_dict_path=os.path.join(
                                      dir_name, 'dict.txt'))
                    load_checkpoint(lm, args.recog_lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                # second pass (forward)
                if args.recog_lm_second is not None and args.recog_lm_second_weight > 0:
                    conf_lm_2nd = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_second),
                                     'conf.yml'))
                    args_lm_2nd = argparse.Namespace()
                    for k, v in conf_lm_2nd.items():
                        setattr(args_lm_2nd, k, v)
                    lm_2nd = build_lm(args_lm_2nd)
                    load_checkpoint(lm_2nd, args.recog_lm_second)
                    model.lm_2nd = lm_2nd

                # second pass (backward)
                if args.recog_lm_bwd is not None and args.recog_lm_rev_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_bwd),
                                     'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    lm_bwd = build_lm(args_lm_bwd)
                    load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
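                # Fall back to the unit the model was trained with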
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('first LM path: %s' % args.recog_lm)
            logger.info('second LM path: %s' % args.recog_lm_second)
            logger.info('backward LM path: %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' %
                        (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models,
                                        dataset,
                                        recog_params,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models,
                                          dataset,
                                          recog_params,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          streaming=args.recog_streaming,
                                          progressbar=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models,
                                     dataset,
                                     recog_params,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0)
                #  task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models,
                                 dataset,
                                 recog_params,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True)
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric == 'acc':
            raise NotImplementedError
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models, dataset, progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'bleu':
            raise NotImplementedError
        else:
            raise NotImplementedError(args.recog_metric)
        logger.info('Elapsed time: %.2f [sec]' % (time.time() - start_time))

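    # Report scores averaged over all recognition sets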
    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' %
                        (per_avg / len(args.recog_sets)))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / len(args.recog_sets),
                         cer_avg / len(args.recog_sets)))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
        print('PPL (avg.): %.2f' % (ppl_avg / len(args.recog_sets)))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets)))
        print('Loss (avg.): %.2f' % (loss_avg / len(args.recog_sets)))