Example #1
def eval_wordpiece(models,
                   dataset,
                   recog_params,
                   epoch,
                   recog_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): models to evaluate
        dataset: An instance of the `Dataset` class
        recog_params (dict): decoding hyperparameters
        epoch (int): the epoch at which evaluation is performed
        recog_dir (str): directory in which to save the decoding results
        progressbar (bool): if True, visualize the progress bar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _, perm_id, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [])
            ys = [batch['text'][i] for i in perm_id]

            for b in range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                n_sub_w += sub_b
                n_ins_w += ins_b
                n_del_w += del_b
                n_word += len(ref.split(' '))

                # Compute CER
                if dataset.corpus == 'csj':
                    ref = ref.replace(' ', '')
                    hyp = hyp.replace(' ', '')
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= n_word
    n_sub_w /= n_word
    n_ins_w /= n_word
    n_del_w /= n_word

    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
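
Note: every evaluator in these examples delegates the actual scoring to a `compute_wer` helper. Judging only from the call sites above, it takes a reference and a hypothesis token list and returns the total edit-distance error together with the substitution, insertion, and deletion counts. The sketch below is a plain Levenshtein implementation of that assumed contract; the project's real `compute_wer` (including whether it scales the numbers to percentages) may differ.

import numpy as np

def compute_wer_sketch(ref, hyp, normalize=False):
    """Return (n_err, n_sub, n_ins, n_del) between two token sequences."""
    # Standard edit-distance DP table
    d = np.zeros((len(ref) + 1, len(hyp) + 1), dtype=np.int64)
    d[:, 0] = np.arange(len(ref) + 1)
    d[0, :] = np.arange(len(hyp) + 1)
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i, j] = min(d[i - 1, j - 1] + cost,  # match / substitution
                          d[i - 1, j] + 1,         # deletion
                          d[i, j - 1] + 1)         # insertion
    # Backtrace an optimal path to split the error count by type
    i, j = len(ref), len(hyp)
    n_sub = n_ins = n_del = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and d[i, j] == d[i - 1, j - 1] + (ref[i - 1] != hyp[j - 1]):
            n_sub += int(ref[i - 1] != hyp[j - 1])
            i, j = i - 1, j - 1
        elif i > 0 and d[i, j] == d[i - 1, j] + 1:
            n_del += 1
            i -= 1
        else:
            n_ins += 1
            j -= 1
    n_err = n_sub + n_ins + n_del
    if normalize:
        n_err = n_err / max(len(ref), 1)
    return n_err, n_sub, n_ins, n_del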
Example #2
def eval_phone(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of the `Dataset` class
        decode_params (dict): decoding hyperparameters
        epoch (int): the epoch at which evaluation is performed
        progressbar (bool): if True, visualize the progress bar
    Returns:
        per (float): Phone error rate
        num_sub (float): substitution error rate
        num_ins (float): insertion error rate
        num_del (float): deletion error rate
        decode_dir (str): directory where the decoding results are saved

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    per = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_phones = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_idx = model.decode(batch['xs'], decode_params,
                                                  exclude_eos=True)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2phone(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2phone(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                # Compute PER
                per_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                         hyp=text_hyp.split(' '),
                                                         normalize=False)
                per += per_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_phones += len(text_ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    per /= num_phones
    num_sub /= num_phones
    num_ins /= num_phones
    num_del /= num_phones

    return per, num_sub, num_ins, num_del, os.path.join(model.save_path, decode_dir)
Example #3
def eval_char(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of the `Dataset` class
        decode_params (dict): decoding hyperparameters
        epoch (int): the epoch at which evaluation is performed
        progressbar (bool): if True, visualize the progress bar
    Returns:
        wer (float): Word error rate
        num_sub_w (float): word-level substitution error rate
        num_ins_w (float): word-level insertion error rate
        num_del_w (float): word-level deletion error rate
        cer (float): Character error rate
        num_sub_c (float): character-level substitution error rate
        num_ins_c (float): character-level insertion error rate
        num_del_c (float): character-level deletion error rate
        decode_dir (str): directory where the decoding results are saved

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer, cer = 0, 0
    num_sub_w, num_ins_w, num_del_w = 0, 0, 0
    num_sub_c, num_ins_c, num_del_c = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params,
                                                   exclude_eos=True)
            task_index = 0  # index of the main task; referenced in the checks below
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2char(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2char(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) or (task_index > 0 and dataset.label_type_sub == 'character'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                             hyp=text_hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    num_sub_w += sub_b
                    num_ins_w += ins_b
                    num_del_w += del_b
                    num_words += len(text_ref.split(' '))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(text_ref.replace(' ', '')),
                                                         hyp=list(text_hyp.replace(' ', '')),
                                                         normalize=False)
                cer += cer_b
                num_sub_c += sub_b
                num_ins_c += ins_b
                num_del_c += del_b
                num_chars += len(text_ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) or (task_index > 0 and dataset.label_type_sub == 'character'):
        wer /= num_words
        num_sub_w /= num_words
        num_ins_w /= num_words
        num_del_w /= num_words
    else:
        wer = num_sub_w = num_ins_w = num_del_w = 0

    cer /= num_chars
    num_sub_c /= num_chars
    num_ins_c /= num_chars
    num_del_c /= num_chars

    return (wer, num_sub_w, num_ins_w, num_del_w), (cer, num_sub_c, num_ins_c, num_del_c), os.path.join(model.save_path, decode_dir)
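
The three evaluators share the same calling pattern: a list of models (for optional ensembling), a Dataset instance, a dict of decoding options, and the epoch number. Below is a hedged usage sketch for eval_char; `model`, `dev_set`, and the option values are placeholders rather than names defined in these examples, and the real decoder will likely read additional keys from the dict.

# Hypothetical driver code; `model` and `dev_set` are assumed to exist.
decode_params = {'beam_width': 10,
                 'length_penalty': 0.0,
                 'coverage_penalty': 0.0,
                 'min_len_ratio': 0.0,
                 'max_len_ratio': 1.0,
                 'rnnlm_weight': 0.0,
                 'batch_size': 1}
(wer, n_sub_w, n_ins_w, n_del_w), (cer, n_sub_c, n_ins_c, n_del_c), decode_dir = eval_char(
    [model], dev_set, decode_params, epoch=10, progressbar=True)
print('WER: %.3f / CER: %.3f (trn files in %s)' % (wer, cer, decode_dir))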
Example #4
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding')

    for i, s in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(
            csv_path=s,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if
            os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
            label_type=args.label_type,
            batch_size=args.batch_size,
            max_epoch=args.num_epochs,
            max_num_frames=args.max_num_frames,
            min_num_frames=args.min_num_frames,
            is_test=False)

        if i == 0:
            args.num_classes = eval_set.num_classes
            args.input_dim = eval_set.input_dim
            args.num_classes_sub = eval_set.num_classes_sub

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cf = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(args.model,
                                                   epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(
                    os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.label_type == args_rnnlm.label_type
                args_rnnlm.num_classes = eval_set.num_classes

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' %
                            str(config_rnnlm['backward']))

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.model, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'],
                                                   decode_params,
                                                   exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                if args.label_type in ['word', 'wordpiece']:
                    token_list = eval_set.idx2word(best_hyps[b],
                                                   return_list=True)
                elif args.label_type == 'char':
                    token_list = eval_set.idx2char(best_hyps[b],
                                                   return_list=True)
                elif args.label_type == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b],
                                                    return_list=True)
                else:
                    raise NotImplementedError()
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(
                    aw[b][:len(token_list)],
                    label_list=token_list,
                    spectrogram=batch['xs'][b][:, :eval_set.input_dim]
                    if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, speaker,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                # Reference
                if eval_set.is_test:
                    text_ref = ys[b]
                else:
                    if args.label_type in ['word', 'wordpiece']:
                        text_ref = eval_set.idx2word(ys[b])
                    elif args.label_type == 'char':
                        text_ref = eval_set.idx2char(ys[b])
                    elif args.label_type == 'phone':
                        text_ref = eval_set.idx2phone(ys[b])

                # Hypothesis
                text_hyp = ' '.join(token_list)

                # Redirect stdout so the alignment printed by wer_align goes to a per-utterance file
                sys.stdout = open(
                    os.path.join(save_path, speaker,
                                 batch['utt_ids'][b] + '.txt'), 'w')
                ler = wer_align(
                    ref=text_ref.split(' '),
                    hyp=text_hyp.encode('utf-8').split(' '),
                    normalize=True,
                    double_byte=False)[0]  # TODO(hirofumi): add corpus to args
                print('\nLER: %.3f %%\n\n' % ler)

            if is_new_epoch:
                break
Example #5
def main():

    args = parse()
    args_pt = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    # Automatically reduce batch size in multi-GPU setting
    if args.n_gpus > 1:
        args.batch_size -= 10
        args.print_step //= args.n_gpus

    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample_factor_sub3 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings:
        for p in args.conv_poolings.split('_'):
            p = int(p.split(',')[1].replace(')', ''))
            if p > 1:
                subsample_factor *= p
    if args.train_set_sub1:
        subsample_factor_sub1 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub1 - 1])
    if args.train_set_sub2:
        subsample_factor_sub2 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub2 - 1])
    if args.train_set_sub3:
        subsample_factor_sub3 = subsample_factor * np.prod(
            subsample[:args.enc_n_layers_sub3 - 1])
    subsample_factor *= np.prod(subsample)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        tsv_path_sub3=args.train_set_sub3,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        dict_path_sub3=args.dict_sub3,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        unit_sub3=args.unit_sub3,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        wp_model_sub3=args.wp_model_sub3,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        ctc_sub3=args.ctc_weight_sub3 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        subsample_factor_sub3=subsample_factor_sub3,
                        concat_prev_n_utterances=args.concat_prev_n_utterances,
                        n_caches=args.n_caches)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      tsv_path_sub3=args.dev_set_sub3,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      dict_path_sub3=args.dict_sub3,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      unit_sub3=args.unit_sub3,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      wp_model_sub3=args.wp_model_sub3,
                      batch_size=args.batch_size * args.n_gpus,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      shuffle=True if args.n_caches == 0 else False,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      ctc_sub3=args.ctc_weight_sub3 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      subsample_factor_sub3=subsample_factor_sub3,
                      n_caches=args.n_caches)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    n_caches=args.n_caches,
                    is_test=True)
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.vocab_sub3 = train_set.vocab_sub3
    args.input_dim = train_set.input_dim

    # Load a LM conf file for cold fusion & LM initialization
    if args.lm_fusion:
        if args.model:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.lm_fusion), 'conf.yml'))
        elif args.resume:
            lm_conf = load_config(
                os.path.join(os.path.dirname(args.resume), 'conf_lm.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    if args.enc_type == 'transformer':
        args.decay_type = 'warmup'

    # Model setting
    model = Seq2seq(args)
    dir_name = make_model_name(args, subsample_factor)

    if args.resume:
        # Set save path
        model.save_path = os.path.dirname(args.resume)

        # Setting for logging
        logger = set_logger(os.path.join(os.path.dirname(args.resume),
                                         'train.log'),
                            key='training')

        # Set optimizer
        epoch = int(args.resume.split('-')[-1])
        model.set_optimizer(
            optimizer='sgd'
            if epoch > conf['convert_to_sgd_epoch'] + 1 else conf['optimizer'],
            learning_rate=float(conf['learning_rate']),  # on-the-fly
            weight_decay=float(conf['weight_decay']))

        # Restore the last saved model
        checkpoints = model.load_checkpoint(args.resume, resume=True)
        lr_controller = checkpoints['lr_controller']
        epoch = checkpoints['epoch']
        step = checkpoints['step']
        metric_dev_best = checkpoints['metric_dev_best']

        # Resume between convert_to_sgd_epoch and convert_to_sgd_epoch + 1
        if epoch == conf['convert_to_sgd_epoch'] + 1:
            model.set_optimizer(optimizer='sgd',
                                learning_rate=args.learning_rate,
                                weight_decay=float(conf['weight_decay']))
            logger.info('========== Convert to SGD ==========')
    else:
        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))
        if args.lm_fusion:
            save_config(args.lm_conf,
                        os.path.join(model.save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(model.save_path,
                                                  'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2', '_sub3']:
            if getattr(args, 'dict' + sub):
                shutil.copy(
                    getattr(args, 'dict' + sub),
                    os.path.join(model.save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(
                    getattr(args, 'wp_model' + sub),
                    os.path.join(model.save_path, 'wp' + sub + '.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.pretrained_model and os.path.isfile(args.pretrained_model):
            # Load a conf file
            conf_pt = load_config(
                os.path.join(os.path.dirname(args.pretrained_model),
                             'conf.yml'))

            # Merge conf with args
            for k, v in conf_pt.items():
                setattr(args_pt, k, v)

            # Load the ASR model
            model_pt = Seq2seq(args_pt)
            model_pt.load_checkpoint(args.pretrained_model)

            # Overwrite parameters
            only_enc = (args.enc_n_layers !=
                        args_pt.enc_n_layers) or (args.unit != args_pt.unit)
            param_dict = dict(model_pt.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if only_enc and 'enc' not in n:
                        continue
                    if args.lm_fusion_type == 'cache' and 'output' in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            transformer=True if args.enc_type == 'transformer'
                            or args.dec_type == 'transformer' else False)

        epoch, step = 1, 1
        metric_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=metric_dev_best,
            model_size=args.d_model,
            warmup_start_learning_rate=args.warmup_start_learning_rate,
            warmup_n_steps=args.warmup_n_steps,
            factor=1)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight - args.sub3_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.lmobj_weight > 0:
            tasks = ['ys.lmobj'] + tasks
        if args.lm_fusion is not None and 'mtl' in args.lm_fusion_type:
            tasks = ['ys.lm'] + tasks
        for sub in ['sub1', 'sub2', 'sub3']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(
                        args, 'bwd_weight_' + sub) - getattr(
                            args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'bwd_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.bwd'] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
                if getattr(args, 'lmobj_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.lmobj'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_n_epochs = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()

        # Compute loss and update parameters for each task
        for task in tasks:
            model.module.optimizer.zero_grad()
            loss, reporter = model(batch_train, reporter=reporter, task=task)
            if len(model.device_ids) > 1:
                loss.backward(torch.ones(len(model.device_ids)))
            else:
                loss.backward()
            loss.detach()  # Truncate the graph
            if args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                               args.clip_grad_norm)
            model.module.optimizer.step()
            loss_train = loss.item()
            del loss

        reporter.step(is_eval=False)

        # Update learning rate
        if args.decay_type == 'warmup' and step < args.warmup_n_steps:
            model.module.optimizer = lr_controller.warmup(
                model.module.optimizer, step=step)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Compute the dev loss for each task
            for task in tasks:
                loss, reporter = model(batch_dev,
                                       reporter=reporter,
                                       task=task,
                                       is_eval=True)
                loss_dev = loss.item()
                del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                xlen = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                xlen = max(len(x) for x in batch_train['ys'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/xlen:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   lr_controller.lr, len(
                       batch_train['utt_ids']), xlen, duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path,
                                             lr_controller, epoch, step - 1,
                                             metric_dev_best)
                reporter._epoch += 1
                # TODO(hirofumi): fix later
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'edit_distance':
                    if args.unit in ['word', 'word_char']:
                        metric_dev = eval_word([model.module],
                                               dev_set,
                                               recog_params,
                                               epoch=epoch)[0]
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                    elif args.unit == 'wp':
                        metric_dev, cer_dev = eval_wordpiece([model.module],
                                                             dev_set,
                                                             recog_params,
                                                             epoch=epoch)
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                        logger.info('CER (%s): %.2f %%' %
                                    (dev_set.set, cer_dev))
                    elif 'char' in args.unit:
                        metric_dev, cer_dev = eval_char([model.module],
                                                        dev_set,
                                                        recog_params,
                                                        epoch=epoch)
                        logger.info('WER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                        logger.info('CER (%s): %.2f %%' %
                                    (dev_set.set, cer_dev))
                    elif 'phone' in args.unit:
                        metric_dev = eval_phone([model.module],
                                                dev_set,
                                                recog_params,
                                                epoch=epoch)
                        logger.info('PER (%s): %.2f %%' %
                                    (dev_set.set, metric_dev))
                elif args.metric == 'ppl':
                    metric_dev = eval_ppl([model.module], dev_set,
                                          recog_params)[0]
                    logger.info('PPL (%s): %.2f %%' %
                                (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_ppl([model.module], dev_set,
                                          recog_params)[1]
                    logger.info('Loss (%s): %.2f %%' %
                                (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError(args.metric)
                reporter.epoch(metric_dev)

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=metric_dev)

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_n_epochs = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path,
                                                 lr_controller, epoch,
                                                 step - 1, metric_dev_best)

                    # test
                    for s in eval_sets:
                        if args.metric == 'edit_distance':
                            if args.unit in ['word', 'word_char']:
                                wer_test = eval_word([model.module],
                                                     s,
                                                     recog_params,
                                                     epoch=epoch)[0]
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                            elif args.unit == 'wp':
                                wer_test, cer_test = eval_wordpiece(
                                    [model.module],
                                    s,
                                    recog_params,
                                    epoch=epoch)
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                                logger.info('CER (%s): %.2f %%' %
                                            (s.set, cer_test))
                            elif 'char' in args.unit:
                                wer_test, cer_test = eval_char([model.module],
                                                               s,
                                                               recog_params,
                                                               epoch=epoch)
                                logger.info('WER (%s): %.2f %%' %
                                            (s.set, wer_test))
                                logger.info('CER (%s): %.2f %%' %
                                            (s.set, cer_test))
                            elif 'phone' in args.unit:
                                per_test = eval_phone([model.module],
                                                      s,
                                                      recog_params,
                                                      epoch=epoch)
                                logger.info('PER (%s): %.2f %%' %
                                            (s.set, per_test))
                        elif args.metric == 'ppl':
                            ppl_test = eval_ppl([model.module], s,
                                                recog_params)[0]
                            logger.info('PPL (%s): %.2f %%' %
                                        (s.set, ppl_test))
                        elif args.metric == 'loss':
                            loss_test = eval_ppl([model.module], s,
                                                 recog_params)[1]
                            logger.info('Loss (%s): %.2f %%' %
                                        (s.set, loss_test))
                        else:
                            raise NotImplementedError(args.metric)
                else:
                    not_improved_n_epochs += 1

                    # start scheduled sampling
                    if args.ss_prob > 0:
                        model.module.scheduled_sampling_trigger()

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_n_epochs == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
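
When enc_type is 'transformer', the script above forces decay_type to 'warmup' and calls lr_controller.warmup() every step until warmup_n_steps. The function below is only an illustration of the Noam-style schedule such a controller typically applies (using the d_model and warmup_n_steps values passed to Controller); the project's own Controller may compute it differently.

def noam_lr(step, d_model, warmup_n_steps, factor=1.0):
    """Noam/Transformer warmup schedule: linear warmup, then 1/sqrt(step) decay."""
    step = max(step, 1)
    return factor * (d_model ** -0.5) * min(step ** -0.5,
                                            step * (warmup_n_steps ** -1.5))

# e.g. with d_model=512 and warmup_n_steps=25000, the rate peaks at step 25000
# and then decays proportionally to 1 / sqrt(step).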
Example #6
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(
                              os.path.join(dir_name, 'dict_sub1.txt')) else False,
                          nlsyms=args.nlsyms,
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          unit_sub1=args.unit_sub1,
                          batch_size=args.recog_batch_size,
                          is_test=True)

        if i == 0:
            # Load the ASR model
            model = Seq2seq(args)
            model, checkpoint = load_checkpoint(model, args.recog_model[0])
            epoch = checkpoint['epoch']
            model.save_path = dir_name

            # ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    # Load a conf file
                    conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))

                    # Overwrite conf
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)

                    model_e = Seq2seq(args_e)
                    model_e, _ = load_checkpoint(model_e, recog_model_e)
                    model_e.cuda()
                    ensemble_models += [model_e]

            # For shallow fusion
            if not args.lm_fusion:
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    # Load a LM conf file
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))

                    # Merge conf with args
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)

                    # Load the pre-trained LM
                    if args_lm.lm_type == 'gated_cnn':
                        lm = GatedConvLM(args_lm)
                    else:
                        lm = RNNLM(args_lm)
                    lm, _ = load_checkpoint(lm, args.recog_lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                if args.recog_lm_bwd is not None and args.recog_lm_weight > 0 and (args.recog_fwd_bwd_attention or args.recog_reverse_lm_rescoring):
                    # Load a LM conf file
                    conf_lm = load_config(os.path.join(args.recog_lm_bwd, 'conf.yml'))

                    # Merge conf with args
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)

                    # Load the pre-trained LM
                    if args_lm_bwd.lm_type == 'gated_cnn':
                        lm_bwd = GatedConvLM(args_lm_bwd)
                    else:
                        lm_bwd = RNNLM(args_lm_bwd)
                    lm_bwd, _ = load_checkpoint(lm_bwd, args.recog_lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('LM path: %s' % args.recog_lm)
            logger.info('LM path (bwd): %s' % args.recog_lm_bwd)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('reverse LM rescoring: %s' % args.recog_reverse_lm_rescoring)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache type: %s' % (args.recog_cache_type))
            logger.info('cache word frequency threshold: %s' % (args.recog_cache_word_freq))
            logger.info('cache theta (speech): %.3f' % (args.recog_cache_theta_speech))
            logger.info('cache lambda (speech): %.3f' % (args.recog_cache_lambda_speech))
            logger.info('cache theta (lm): %.3f' % (args.recog_cache_theta_lm))
            logger.info('cache lambda (lm): %.3f' % (args.recog_cache_lambda_lm))

            # GPU setting
            model.cuda()
            # TODO(hirofumi): move this

        save_path = mkdir_join(args.recog_dir, 'att_weights')
        if args.recog_n_caches > 0:
            save_path_cache = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)
            if args.recog_n_caches > 0:
                shutil.rmtree(save_path_cache)
                os.mkdir(save_path_cache)

        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, aws, (cache_attn_hist, cache_id_hist) = model.decode(
                batch['xs'], recog_params, dataset.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [],
                speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'])

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][:len(tokens)],
                    tokens,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim] if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                if args.recog_n_caches > 0 and cache_id_hist is not None and cache_attn_hist is not None:
                    n_keys, n_queries = cache_attn_hist[0].shape
                    # mask = np.ones((n_keys, n_queries))
                    # for i in range(n_queries):
                    #     mask[:n_keys - i, -(i + 1)] = 0
                    mask = np.zeros((n_keys, n_queries))

                    plot_cache_weights(
                        cache_attn_hist[0],
                        keys=dataset.idx2token[0](cache_id_hist[-1], return_list=True),  # fifo
                        # keys=dataset.idx2token[0](cache_id_hist, return_list=True),  # dict
                        queries=tokens,
                        save_path=mkdir_join(save_path_cache, spk, batch['utt_ids'][b] + '.png'),
                        figsize=(40, 16),
                        mask=mask)

                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
Example #7
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check differences between args and the YAML configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        bptt=args.bptt,
                        eos=args.eos,
                        max_epoch=args.num_epochs,
                        shuffle=True)
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      bptt=args.bptt,
                      eos=args.eos,
                      shuffle=True)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [Dataset(csv_path=s,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              bptt=args.bptt,
                              eos=args.eos,
                              is_test=True)]

    args.num_classes = train_set.num_classes

    # Model setting
    model = RNNLM(args)
    model.name = args.rnn_type
    model.name += str(args.num_units) + 'H'
    model.name += str(args.num_projs) + 'P'
    model.name += str(args.num_layers) + 'L'
    model.name += '_emb' + str(args.emb_dim)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    if args.tie_weights:
        model.name += '_tie'
    if args.residual:
        model.name += '_residual'
    if args.backward:
        model.name += '_bwd'

    if args.resume_model is None:
        # Set save path
        save_path = mkdir_join(args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    else:
        raise NotImplementedError()

    train_set.epoch = epoch - 1

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=True,
                                   benchmark=False)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=10)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        ys_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, ys_train, args.bptt)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(np.sum([len(y) for y in ys_train]))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, ys_dev, args.bptt, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            for n, p in model.module.named_parameters():
                n = n.replace('.', '/')
                if p.grad is not None:
                    tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
                    tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            logger.info("...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/ppl:%.2f(%.2f)/lr:%.5f/bs:%d (%.2f min)" %
                        (step + 1, train_set.epoch_detail,
                         loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                         math.exp(loss_train_mean), math.exp(loss_dev),
                         learning_rate, len(ys_train), duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev = eval_ppl([model.module], dev_set, args.bptt)
                logger.info(' PPL (%s): %.3f' % (dev_set.set, ppl_dev))

                if ppl_dev < metric_dev_best:
                    metric_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    ppl_test_mean = 0.
                    for eval_set in eval_sets:
                        ppl_test = eval_ppl([model.module], eval_set, args.bptt)
                        logger.info(' PPL (%s): %.3f' % (eval_set.set, ppl_test))
                        ppl_test_mean += ppl_test
                    if len(eval_sets) > 0:
                        logger.info(' PPL (mean): %.3f' % (ppl_test_mean / len(eval_sets)))
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
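
The epoch loop above ties learning-rate decay to the dev-set perplexity and stops training once not_improved_epoch reaches the patience threshold. The sketch below shows the same patience-based decay and early-stopping pattern in isolation; it is a simplified stand-in, not the project's Controller API, and the evaluate callback and its arguments are placeholders.

def train_with_patience(evaluate, n_epochs, lr=1.0, decay_rate=0.5, patience=3):
    """Illustrative only: `evaluate(epoch, lr)` returns a dev metric
    where lower is better (e.g. perplexity)."""
    best, not_improved = float('inf'), 0
    for epoch in range(1, n_epochs + 1):
        metric = evaluate(epoch, lr)
        if metric < best:
            best, not_improved = metric, 0   # new best score: reset patience
        else:
            lr *= decay_rate                 # decay when the dev metric stalls
            not_improved += 1
        if not_improved >= patience:         # early stopping
            break
    return best, lr
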
Exemple #8
0
def eval_word(models,
              dataset,
              recog_params,
              epoch,
              recog_dir=None,
              word_list=[],
              progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): models to evaluate
        dataset: An instance of a `Dataset' class
        recog_params (dict):
        epoch (int):
        recog_dir (str):
        word_list (list):
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        n_oov_total (int): total number of OOVs

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir,
                                       'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, aws, _, _ = models[0].decode(
                batch['xs'],
                recog_params,
                dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions']
                if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [],
                word_list=word_list)

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                n_oov_total += hyp.count('<unk>')

                # Resolving UNK
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char, _, _ = models[0].decode(
                        batch['xs'][b:b + 1],
                        recog_params_char,
                        dataset.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=batch['sessions']
                        if dataset.corpus == 'swbd' else batch['speakers'],
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3

                    hyp = resolve_unk(
                        hyp,
                        best_hyps_id_char[0],
                        aws[b],
                        aw_char[0],
                        dataset.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 -
                                                1]))
                    logger.info('Hyp (after OOV resolution): %s' % hyp)
                    hyp = hyp.replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = hyp
                    if dataset.corpus == 'csj':
                        ref_char = ref.replace(' ', '')
                        hyp_char = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char),
                        hyp=list(hyp_char),
                        normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                n_sub_w += sub_b
                n_ins_w += ins_b
                n_del_w += del_b
                n_word += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= n_word
    n_sub_w /= n_word
    n_ins_w /= n_word
    n_del_w /= n_word

    if n_char > 0:
        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                (n_sub_c, n_ins_c, n_del_c))
    logger.info('OOV (total): %d' % (n_oov_total))

    return wer, cer, n_oov_total
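
compute_wer above is called with normalize=False, so each call is assumed to return raw error counts (total edits, substitutions, insertions, deletions) that are accumulated and only divided by n_word or n_char at the end. The following is a minimal self-contained sketch of that kind of Levenshtein-based error counting; it is an assumption about compute_wer's behaviour, not its actual implementation.

def count_word_errors(ref, hyp):
    """ref, hyp: lists of tokens. Returns (total_errors, n_sub, n_ins, n_del)."""
    n, m = len(ref), len(hyp)
    # d[i][j] = edit distance between ref[:i] and hyp[:j]
    d = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        d[i][0] = i
    for j in range(m + 1):
        d[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution / match
    # Backtrack to split the distance into substitutions/insertions/deletions
    n_sub = n_ins = n_del = 0
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] and ref[i - 1] == hyp[j - 1]:
            i, j = i - 1, j - 1  # match
        elif i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] + 1:
            n_sub += 1
            i, j = i - 1, j - 1
        elif j > 0 and d[i][j] == d[i][j - 1] + 1:
            n_ins += 1
            j -= 1
        else:
            n_del += 1
            i -= 1
    return d[n][m], n_sub, n_ins, n_del

# e.g. count_word_errors('a b c'.split(), 'a x c d'.split())
# -> (2, 1, 1, 0): one substitution and one insertion
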
Exemple #9
0
def main():

    args = parse()

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(args.resume, 'conf.yml'))
        for k, v in conf.items():
            setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        dict_path=args.dict,
                        unit=args.unit,
                        wp_model=args.wp_model,
                        batch_size=args.batch_size * args.n_gpus,
                        n_epochs=args.n_epochs,
                        bptt=args.bptt,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      dict_path=args.dict,
                      unit=args.unit,
                      wp_model=args.wp_model,
                      batch_size=args.batch_size * args.n_gpus,
                      bptt=args.bptt,
                      serialize=args.serialize)
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(corpus=args.corpus,
                    tsv_path=s,
                    dict_path=args.dict,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    bptt=args.bptt,
                    serialize=args.serialize)
        ]

    args.vocab = train_set.vocab

    # Model setting
    if args.lm_type == 'gated_cnn':
        model = GatedConvLM(args)
    else:
        model = RNNLM(args)
    dir_name = args.lm_type
    dir_name += str(args.n_units) + 'H'
    dir_name += str(args.n_projs) + 'P'
    dir_name += str(args.n_layers) + 'L'
    dir_name += '_emb' + str(args.emb_dim)
    dir_name += '_' + args.optimizer
    dir_name += '_lr' + str(args.learning_rate)
    dir_name += '_bs' + str(args.batch_size)
    dir_name += '_bptt' + str(args.bptt)
    if args.tie_embedding:
        dir_name += '_tie'
    if args.residual:
        dir_name += '_residual'
    if args.use_glu:
        dir_name += '_glu'
    if args.backward:
        dir_name += '_bwd'
    if args.serialize:
        dir_name += '_serialize'

    if args.resume:
        raise NotImplementedError
    else:
        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(model.save_path, 'conf.yml'))

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path,
                                                    'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate=float(args.learning_rate),
                            weight_decay=float(args.weight_decay))

        epoch, step = 1, 1
        ppl_dev_best = 10000

        # Set learning rate controller
        lr_controller = Controller(
            learning_rate=float(args.learning_rate),
            decay_type=args.decay_type,
            decay_start_epoch=args.decay_start_epoch,
            decay_rate=args.decay_rate,
            decay_patient_n_epochs=args.decay_patient_n_epochs,
            lower_better=True,
            best_value=ppl_dev_best)

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.n_gpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.n_gpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    if args.job_name:
        setproctitle(args.job_name)
    else:
        setproctitle(dir_name)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    hidden = None
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        ys_train, is_new_epoch = train_set.next()

        model.module.optimizer.zero_grad()
        loss, hidden, reporter = model(ys_train, hidden, reporter)
        if len(model.device_ids) > 1:
            loss.backward(torch.ones(len(model.device_ids)))
        else:
            loss.backward()
        loss.detach()  # Truncate the graph
        if args.clip_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                           args.clip_grad_norm)
        model.module.optimizer.step()
        loss_train = loss.item()
        del loss
        if args.lm_type != 'gated_cnn':
            hidden = model.module.repackage_hidden(hidden)
        reporter.step(is_eval=False)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            loss, _, reporter = model(ys_dev, None, reporter, is_eval=True)
            loss_dev = loss.item()
            del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/ppl:%.3f(%.3f)/lr:%.5f/bs:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   math.exp(loss_train), math.exp(loss_dev), lr_controller.lr,
                   len(ys_train), duration_step / 60))
            start_time_step = time.time()
        step += args.n_gpus
        pbar_epoch.update(np.prod(ys_train.shape))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path,
                                             lr_controller, epoch, step - 1,
                                             ppl_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev, _ = eval_ppl([model.module],
                                      dev_set,
                                      batch_size=1,
                                      bptt=args.bptt)
                logger.info('PPL (%s): %.2f' % (dev_set.set, ppl_dev))

                # Update learning rate
                model.module.optimizer = lr_controller.decay(
                    model.module.optimizer, epoch=epoch, value=ppl_dev)

                if ppl_dev < ppl_dev_best:
                    ppl_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path,
                                                 lr_controller, epoch,
                                                 step - 1, ppl_dev_best)

                    # test
                    ppl_test_avg = 0.
                    for eval_set in eval_sets:
                        ppl_test, _ = eval_ppl([model.module],
                                               eval_set,
                                               batch_size=1,
                                               bptt=args.bptt)
                        logger.info('PPL (%s): %.2f' %
                                    (eval_set.set, ppl_test))
                        ppl_test_avg += ppl_test
                    if len(eval_sets) > 0:
                        logger.info('PPL (avg.): %.2f' %
                                    (ppl_test_avg / len(eval_sets)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_n_epochs:
                    break

                # Convert to fine-tuning stage
                if epoch == args.convert_to_sgd_epoch:
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate=args.learning_rate,
                        weight_decay=float(args.weight_decay))
                    lr_controller = Controller(
                        learning_rate=args.learning_rate,
                        decay_type='epoch',
                        decay_start_epoch=epoch,
                        decay_rate=0.5,
                        lower_better=True)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.n_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
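
Two idioms recur in the LM training loop above: the hidden state is detached between BPTT chunks via repackage_hidden, and perplexity is reported as math.exp(loss) of the per-token cross-entropy. Below is a minimal sketch of both; these are illustrative helpers in the style of the standard truncated-BPTT pattern, not the project's own methods.

import math

import torch


def repackage_hidden(h):
    """Detach hidden states from the graph so truncated BPTT does not
    backpropagate across chunk boundaries (handles tensors or tuples)."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


def perplexity(cross_entropy):
    """Per-token cross-entropy in nats -> perplexity, as in math.exp(loss)."""
    return math.exp(cross_entropy)
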
Exemple #10
0
def eval_char(models,
              dataset,
              decode_params,
              epoch,
              decode_dir=None,
              progressbar=False,
              task_id=0):
    """Evaluate the character-level model by WER & CER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict):
        epoch (int):
        decode_dir (str):
        progressbar (bool): if True, visualize the progressbar
        task_id (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        nsub_w (int): the number of substitution errors for WER
        nins_w (int): the number of insertion errors for WER
        ndel_w (int): the number of deletion errors for WER
        cer (float): Character error rate
        nsub_c (int): the number of substitution errors for CER
        nins_c (int): the number of insertion errors for CER
        ndel_c (int): the number of deletion errors for CER

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    if decode_dir is None:
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    wer, cer = 0, 0
    nsub_w, nins_w, ndel_w = 0, 0, 0
    nsub_c, nins_c, ndel_c = 0, 0, 0
    nword, nchar = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    if task_id == 0:
        task = 'ys'
    elif task_id == 1:
        task = 'ys_sub1'
    elif task_id == 2:
        task = 'ys_sub2'

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_ids = model.decode(batch['xs'],
                                                  decode_params,
                                                  exclude_eos=True,
                                                  task=task)
            ys = [batch['text'][i] for i in perm_ids]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2char(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                # logger.info('Ref: %s' % ref.lower())
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
                        task_id > 0 and dataset.unit_sub1 == 'char'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    nsub_w += sub_b
                    nins_w += ins_b
                    ndel_w += del_b
                    nword += len(ref.split(' '))
                    # logger.info('WER: %d%%' % (wer_b / len(ref.split(' '))))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                nsub_c += sub_b
                nins_c += ins_b
                ndel_c += del_b
                nchar += len(ref)
                # logger.info('CER: %d%%' % (cer_b / len(ref)))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
            task_id > 0 and dataset.unit_sub1 == 'char'):
        wer /= nword
        nsub_w /= nword
        nins_w /= nword
        ndel_w /= nword
    else:
        wer = nsub_w = nins_w = ndel_w = 0

    cer /= nchar
    nsub_c /= nchar
    nins_c /= nchar
    ndel_c /= nchar

    return (wer, nsub_w, nins_w, ndel_w), (cer, nsub_c, nins_c, ndel_c)
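
The trn lines written above recover speaker/start/end fields by string-splitting the utterance id. The small helper below shows the same parsing in one place; the example id in the comment is only an assumption about how utterance ids are laid out.

def parse_utt_id(utt_id):
    """Split an utterance id into (speaker, start, end), mirroring the
    `.replace('-', '_').split('_')` handling used when writing trn files."""
    parts = utt_id.replace('-', '_').split('_')
    speaker = '_'.join(parts[:-2])
    start, end = parts[-2], parts[-1]
    return speaker, start, end

# e.g. (hypothetical id) parse_utt_id('sw02001_A_0001234_0005678')
# -> ('sw02001_A', '0001234', '0005678')
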
Exemple #11
0
def eval_wordpiece(models,
                   dataset,
                   decode_params,
                   epoch,
                   decode_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict):
        epoch (int):
        decode_dir (str):
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        nsub (int): the number of substitution errors
        nins (int): the number of insertion errors
        ndel (int): the number of deletion errors

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    if decode_dir is None:
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    wer = 0
    nsub, nins, ndel = 0, 0, 0
    nword = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_id = model.decode(batch['xs'],
                                                 decode_params,
                                                 exclude_eos=True,
                                                 id2token=dataset.id2wp,
                                                 refs=batch['ys'])
            ys = [batch['text'][i] for i in perm_id]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2wp(best_hyps[b])

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end +
                            ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                # logger.info('Ref: %s' % ref.lower())
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                nsub += sub_b
                nins += ins_b
                ndel += del_b
                nword += len(ref.split(' '))
                # logger.info('WER: %d%%' % (float(wer_b) / len(ref.split(' '))))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= nword
    nsub /= nword
    nins /= nword
    ndel /= nword

    return wer, nsub, nins, ndel
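
The decode_dir name above is assembled by repeated string concatenation. The same naming convention can be written more compactly with an f-string; this is only an illustrative rewrite using the decode_params keys that appear in the function above.

def build_decode_dir(set_name, epoch, p):
    """Compact equivalent of the decode_dir construction above;
    `p` is a decode_params-style dict."""
    return (f"decode_{set_name}_ep{epoch}_beam{p['beam_width']}"
            f"_lp{p['length_penalty']}_cp{p['coverage_penalty']}"
            f"_{p['min_len_ratio']}_{p['max_len_ratio']}"
            f"_rnnlm{p['rnnlm_weight']}")

# e.g. build_decode_dir('dev', 25, {'beam_width': 4, 'length_penalty': 0.1,
#                                   'coverage_penalty': 0.0, 'min_len_ratio': 0.0,
#                                   'max_len_ratio': 1.0, 'rnnlm_weight': 0.3})
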
Exemple #12
0
def main():

    # Load a config file
    if args.resume:
        config = load_config(os.path.join(args.resume, 'config.yml'))
        for k, v in config.items():
            setattr(args, k, v)

    # Automatically reduce batch size in multi-GPU setting
    if args.ngpus > 1:
        args.batch_size -= 10
        args.print_step //= args.ngpus

    subsample_factor = 1
    subsample_factor_sub1 = 1
    subsample_factor_sub2 = 1
    subsample = [int(s) for s in args.subsample.split('_')]
    if args.conv_poolings:
        for p in args.conv_poolings.split('_'):
            p = int(p.split(',')[0].replace('(', ''))
            if p > 1:
                subsample_factor *= p
    if args.train_set_sub1:
        subsample_factor_sub1 = subsample_factor * np.prod(
            subsample[:args.enc_nlayers_sub1 - 1])
    if args.train_set_sub2:
        subsample_factor_sub2 = subsample_factor * np.prod(
            subsample[:args.enc_nlayers_sub2 - 1])
    subsample_factor *= np.prod(subsample)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        csv_path_sub1=args.train_set_sub1,
                        csv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=args.batch_size * args.ngpus,
                        nepochs=args.nepochs,
                        min_nframes=args.min_nframes,
                        max_nframes=args.max_nframes,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=subsample_factor,
                        subsample_factor_sub1=subsample_factor_sub1,
                        subsample_factor_sub2=subsample_factor_sub2,
                        skip_speech=(args.input_type != 'speech'))
    dev_set = Dataset(csv_path=args.dev_set,
                      csv_path_sub1=args.dev_set_sub1,
                      csv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=args.batch_size * args.ngpus,
                      min_nframes=args.min_nframes,
                      max_nframes=args.max_nframes,
                      shuffle=True,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=subsample_factor,
                      subsample_factor_sub1=subsample_factor_sub1,
                      subsample_factor_sub2=subsample_factor_sub2,
                      skip_speech=(args.input_type != 'speech'))
    eval_sets = []
    for s in args.eval_sets:
        eval_sets += [
            Dataset(csv_path=s,
                    dict_path=args.dict,
                    unit=args.unit,
                    wp_model=args.wp_model,
                    batch_size=1,
                    is_test=True,
                    skip_speech=(args.input_type != 'speech'))
        ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Load a RNNLM config file for cold fusion & RNNLM initialization
    # if config['rnnlm_cold_fusion']:
    #     if args.model:
    #         config['rnnlm_config_cold_fusion'] = load_config(
    #             os.path.join(config['rnnlm_cold_fusion'], 'config.yml'), is_eval=True)
    #     elif args.resume:
    #         config = load_config(os.path.join(
    #             args.resume, 'config_rnnlm_cf.yml'))
    #     assert args.unit == config['rnnlm_config_cold_fusion']['unit']
    #     config['rnnlm_config_cold_fusion']['vocab'] = train_set.vocab
    args.rnnlm_cold_fusion = False

    # Model setting
    if args.transformer:
        model = Transformer(args)
        dir_name = 'transformer'
        if len(args.conv_channels) > 0:
            tmp = dir_name
            dir_name = 'conv' + str(len(args.conv_channels.split('_'))) + 'L'
            if args.conv_batch_norm:
                dir_name += 'bn'
            dir_name += tmp
        dir_name += str(args.d_model) + 'H'
        dir_name += str(args.enc_nlayers) + 'L'
        dir_name += str(args.dec_nlayers) + 'L'
        dir_name += '_head' + str(args.attn_nheads)
        dir_name += '_' + args.optimizer
        dir_name += '_lr' + str(args.learning_rate)
        dir_name += '_bs' + str(args.batch_size)
        dir_name += '_ls' + str(args.lsm_prob)
        dir_name += '_' + str(args.pre_process) + 't' + str(args.post_process)
        if args.nstacks > 1:
            dir_name += '_stack' + str(args.nstacks)
        if args.bwd_weight > 0:
            dir_name += '_bwd' + str(args.bwd_weight)
    else:
        model = Seq2seq(args)
        dir_name = args.enc_type
        if args.conv_channels and len(args.conv_channels.split('_')) > 0:
            tmp = dir_name
            dir_name = 'conv' + str(len(args.conv_channels.split('_'))) + 'L'
            if args.conv_batch_norm:
                dir_name += 'bn'
            dir_name += tmp
        dir_name += str(args.enc_nunits) + 'H'
        dir_name += str(args.enc_nprojs) + 'P'
        dir_name += str(args.enc_nlayers) + 'L'
        dir_name += '_' + args.subsample_type + str(subsample_factor)
        dir_name += '_' + args.dec_type
        if args.internal_lm > 0:
            dir_name += 'LM'
        dir_name += str(args.dec_nunits) + 'H'
        # dir_name += str(args.dec_nprojs) + 'P'
        dir_name += str(args.dec_nlayers) + 'L'
        if args.tie_embedding:
            dir_name += '_tie'
        dir_name += '_' + args.attn_type
        if args.attn_nheads > 1:
            dir_name += '_head' + str(args.attn_nheads)
        if args.attn_sigmoid:
            dir_name += '_sig'
        dir_name += '_' + args.optimizer
        dir_name += '_lr' + str(args.learning_rate)
        dir_name += '_bs' + str(args.batch_size)
        dir_name += '_ss' + str(args.ss_prob)
        dir_name += '_ls' + str(args.lsm_prob)
        if args.focal_loss_weight > 0:
            dir_name += '_fl' + str(args.focal_loss_weight)
        if args.layer_norm:
            dir_name += '_layernorm'
        # MTL
        if args.mtl_per_batch:
            dir_name += '_mtlperbatch'
            if args.ctc_weight > 0:
                dir_name += '_' + args.unit + 'ctc'
            if args.bwd_weight > 0:
                dir_name += '_' + args.unit + 'bwd'
            if args.lmobj_weight > 0:
                dir_name += '_' + args.unit + 'lmobj'
            if args.train_set_sub1:
                dir_name += '_' + args.unit_sub1
                if args.ctc_weight_sub1 == 0:
                    dir_name += 'att'
                elif args.ctc_weight_sub1 == args.sub1_weight:
                    dir_name += 'ctc'
                else:
                    dir_name += 'attctc'
            if args.train_set_sub2:
                dir_name += '_' + args.unit_sub2
                if args.ctc_weight_sub2 == 0:
                    dir_name += 'att'
                elif args.ctc_weight_sub2 == args.sub2_weight:
                    dir_name += 'ctc'
                else:
                    dir_name += 'attctc'
        else:
            if args.ctc_weight > 0:
                dir_name += '_ctc' + str(args.ctc_weight)
            if args.bwd_weight > 0:
                dir_name += '_bwd' + str(args.bwd_weight)
            if args.lmobj_weight > 0:
                dir_name += '_lmobj' + str(args.lmobj_weight)
            if args.sub1_weight > 0:
                if args.ctc_weight_sub1 == args.sub1_weight:
                    dir_name += '_ctcsub1' + str(args.ctc_weight_sub1)
                elif args.ctc_weight_sub1 == 0:
                    dir_name += '_attsub1' + str(args.sub1_weight)
                else:
                    dir_name += '_ctcsub1' + str(args.ctc_weight_sub1) + 'attsub1' + \
                        str(args.sub1_weight - args.ctc_weight_sub1)
                if args.sub2_weight > 0:
                    if args.ctc_weight_sub2 == args.sub2_weight:
                        dir_name += '_ctcsub2' + str(args.ctc_weight_sub2)
                    elif args.ctc_weight_sub2 == 0:
                        dir_name += '_attsub2' + str(args.sub2_weight)
                    else:
                        dir_name += '_ctcsub2' + str(args.ctc_weight_sub2) + 'attsub2' + \
                            str(args.sub2_weight - args.ctc_weight_sub2)
        if args.task_specific_layer:
            dir_name += '_tsl'
        # Pre-training
        if args.pretrained_model and os.path.isdir(args.pretrained_model):
            # Load a config file
            config_pre = load_config(
                os.path.join(args.pretrained_model, 'config.yml'))
            dir_name += '_' + config_pre['unit'] + 'pt'

    if not args.resume:
        # Load pre-trained RNNLM
        # if config['rnnlm_cold_fusion']:
        #     rnnlm = RNNLM(args)
        #     rnnlm.load_checkpoint(save_path=config['rnnlm_cold_fusion'], epoch=-1)
        #     rnnlm.flatten_parameters()
        #
        #     # Fix RNNLM parameters
        #     for param in rnnlm.parameters():
        #         param.requires_grad = False
        #
        #     # Set pre-trained parameters
        #     if config['rnnlm_config_cold_fusion']['backward']:
        #         model.dec_0_bwd.rnnlm = rnnlm
        #     else:
        #         model.dec_0_fwd.rnnlm = rnnlm
        # TODO(hirofumi): copy the RNNLM model first

        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(model.save_path, 'dict.txt'))
        if args.dict_sub1:
            shutil.copy(args.dict_sub1,
                        os.path.join(model.save_path, 'dict_sub1.txt'))
        if args.dict_sub2:
            shutil.copy(args.dict_sub2,
                        os.path.join(model.save_path, 'dict_sub2.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(model.save_path,
                                                    'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            nparams = model.num_params_dict[n]
            logger.info("%s %d" % (n, nparams))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.pretrained_model and os.path.isdir(args.pretrained_model):
            # Merge config with args
            for k, v in config_pre.items():
                setattr(args_pre, k, v)

            # Load the ASR model
            model_pre = Seq2seq(args_pre)
            model_pre.load_checkpoint(args.pretrained_model, epoch=-1)

            # Overwrite parameters
            param_dict = dict(model_pre.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 1
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    # NOTE: Restart from the last checkpoint
    # elif args.resume:
    #     # Set save path
    #     model.save_path = args.resume
    #
    #     # Setting for logging
    #     logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')
    #
    #     # Set optimizer
    #     model.set_optimizer(
    #         optimizer=config['optimizer'],
    #         learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #         weight_decay=float(config['weight_decay']),
    #         clip_grad_norm=config['clip_grad_norm'],
    #         lr_schedule=False,
    #         factor=config['decay_rate'],
    #         patience_epoch=config['decay_patient_epoch'])
    #
    #     # Restore the last saved model
    #     epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
    #         save_path=args.resume, epoch=-1, restart=True)
    #
    #     if epoch >= config['convert_to_sgd_epoch']:
    #         model.set_optimizer(
    #             optimizer='sgd',
    #             learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #             weight_decay=float(config['weight_decay']),
    #             clip_grad_norm=config['clip_grad_norm'],
    #             lr_schedule=False,
    #             factor=config['decay_rate'],
    #             patience_epoch=config['decay_patient_epoch'])
    #
    #     if config['rnnlm_cold_fusion']:
    #         if config['rnnlm_config_cold_fusion']['backward']:
    #             model.rnnlm_0_bwd.flatten_parameters()
    #         else:
    #             model.rnnlm_0_fwd.flatten_parameters()

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # if args.job_name:
    #     setproctitle(args.job_name)
    # else:
    #     setproctitle(dir_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best,
                               model_size=args.d_model,
                               warmup_step=args.warmup_step,
                               factor=1)

    # Set reporter
    reporter = Reporter(model.module.save_path, tensorboard=True)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = ['ys']
        if 0 < args.ctc_weight < 1:
            tasks = ['ys.ctc'] + tasks
        if 0 < args.bwd_weight < 1:
            tasks = ['ys.bwd'] + tasks
        if 0 < args.lmobj_weight < 1:
            tasks = ['ys.lmobj'] + tasks
        if args.train_set_sub1:
            if args.ctc_weight_sub1 > 0:
                tasks = ['ys_sub1.ctc'] + tasks
            else:
                tasks = ['ys_sub1'] + tasks
        if args.train_set_sub2:
            if args.ctc_weight_sub2 > 0:
                tasks = ['ys_sub2.ctc'] + tasks
            else:
                tasks = ['ys_sub2'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    pbar_epoch = tqdm(total=len(train_set))
    while True:
        # Compute loss in the training set
        batch_train, is_new_epoch = train_set.next()

        # Compute the loss and update parameters for each task in turn
        for task in tasks:
            model.module.optimizer.zero_grad()
            loss, reporter = model(batch_train, reporter=reporter, task=task)
            if len(model.device_ids) > 1:
                loss.backward(torch.ones(len(model.device_ids)))
            else:
                loss.backward()
            loss.detach()  # Truncate the graph
            if args.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.module.parameters(),
                                               args.clip_grad_norm)
            model.module.optimizer.step()
            loss_train = loss.item()
            del loss
        reporter.step(is_eval=False)

        # Update learning rate
        if args.decay_type == 'warmup':
            model.module.optimizer, learning_rate = lr_controller.warmup_lr(
                optimizer=model.module.optimizer,
                learning_rate=learning_rate,
                step=step)

        if step % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            # Compute the dev loss for each task in turn
            for task in tasks:
                loss, reporter = model(batch_dev,
                                       reporter=reporter,
                                       task=task,
                                       is_eval=True)
                loss_dev = loss.item()
                del loss
            reporter.step(is_eval=True)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                x_len = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                x_len = max(len(x) for x in batch_train['ys'])
            logger.info(
                "step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)"
                % (step, train_set.epoch_detail, loss_train, loss_dev,
                   learning_rate, len(
                       batch_train['utt_ids']), x_len, duration_step / 60))
            start_time_step = time.time()
        step += args.ngpus
        pbar_epoch.update(len(batch_train['utt_ids']))

        # Save figures of loss and accuracy
        if step % (args.print_step * 10) == 0:
            reporter.snapshot()

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('========== EPOCH:%d (%.2f min) ==========' %
                        (epoch, duration_epoch / 60))

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch,
                                             step - 1, learning_rate,
                                             metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'edit_distance':
                    if args.unit in ['word', 'word_char']:
                        metric_dev = eval_word([model.module],
                                               dev_set,
                                               decode_params,
                                               epoch=epoch)[0]
                        logger.info('WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif args.unit == 'wp':
                        metric_dev = eval_wordpiece([model.module],
                                                    dev_set,
                                                    decode_params,
                                                    epoch=epoch)[0]
                        logger.info('WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif 'char' in args.unit:
                        dev_results = eval_char([model.module],
                                                dev_set,
                                                decode_params,
                                                epoch=epoch)
                        metric_dev = dev_results[1][0]
                        wer_dev = dev_results[0][0]
                        logger.info('CER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                        logger.info('WER (%s): %.3f %%' %
                                    (dev_set.set, wer_dev))
                    elif 'phone' in args.unit:
                        metric_dev = eval_phone([model.module],
                                                dev_set,
                                                decode_params,
                                                epoch=epoch)[0]
                        logger.info('PER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_loss([model.module], dev_set,
                                           decode_params)
                    logger.info('Loss (%s): %.3f %%' %
                                (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError()

                # Update learning rate
                if args.decay_type != 'warmup':
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch,
                                                 step - 1, learning_rate,
                                                 metric_dev_best)

                    # test
                    for eval_set in eval_sets:
                        if args.metric == 'edit_distance':
                            if args.unit in ['word', 'word_char']:
                                wer_test = eval_word([model.module],
                                                     eval_set,
                                                     decode_params,
                                                     epoch=epoch)[0]
                                logger.info('WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif args.unit == 'wp':
                                wer_test = eval_wordpiece([model.module],
                                                          eval_set,
                                                          decode_params,
                                                          epoch=epoch)[0]
                                logger.info('WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif 'char' in args.unit:
                                test_results = eval_char([model.module],
                                                         eval_set,
                                                         decode_params,
                                                         epoch=epoch)
                                cer_test = test_results[1][0]
                                wer_test = test_results[0][0]
                                logger.info('CER (%s): %.3f %%' %
                                            (eval_set.set, cer_test))
                                logger.info('WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif 'phone' in args.unit:
                                per_test = eval_phone([model.module],
                                                      eval_set,
                                                      decode_params,
                                                      epoch=epoch)[0]
                                logger.info('PER (%s): %.3f %%' %
                                            (eval_set.set, per_test))
                        elif args.metric == 'loss':
                            loss_test = eval_loss([model.module], eval_set,
                                                  decode_params)
                            logger.info('Loss (%s): %.3f' %
                                        (eval_set.set, loss_test))
                        else:
                            raise NotImplementedError()
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(
                            args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))

            if epoch == args.nepochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    if reporter.tensorboard:
        reporter.tf_writer.close()
    pbar_epoch.close()

    return model.module.save_path
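
The epoch block above mixes three concerns: metric-driven learning-rate decay, checkpointing whenever the dev metric improves, and early stopping once the patience counter runs out. A minimal standalone sketch of that pattern, assuming a torch-style optimizer exposing param_groups; the names decay_lr, patience and save_checkpoint are illustrative, not the toolkit's API:

def decay_lr(optimizer, learning_rate, factor=0.8):
    """Scale the learning rate by `factor` and write it back into the optimizer."""
    learning_rate *= factor
    for param_group in optimizer.param_groups:  # any torch.optim optimizer works
        param_group['lr'] = learning_rate
    return optimizer, learning_rate

# End-of-epoch bookkeeping (sketch). The loop above delegates the decay schedule
# to lr_controller.decay_lr; this variant decays only when the dev metric stops
# improving:
#
# if metric_dev < metric_dev_best:
#     metric_dev_best = metric_dev
#     not_improved_epoch = 0
#     save_checkpoint(...)  # hypothetical helper
# else:
#     not_improved_epoch += 1
#     optimizer, learning_rate = decay_lr(optimizer, learning_rate)
# if not_improved_epoch >= patience:
#     pass  # stop training (break out of the while loop)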
Exemple #13
0
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.plot_dir, 'plot.log')):
        os.remove(os.path.join(args.plot_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'),
                        key='decoding')

    for i, set in enumerate(args.eval_sets):
        subsample_factor = 1
        subsample_factor_sub1 = 1
        subsample = [int(s) for s in args.subsample.split('_')]
        if args.conv_poolings:
            for p in args.conv_poolings.split('_'):
                p = int(p.split(',')[0].replace('(', ''))
                if p > 1:
                    subsample_factor *= p
        if args.train_set_sub1 is not None:
            subsample_factor_sub1 = subsample_factor * np.prod(
                subsample[:args.enc_nlayers_sub1 - 1])
        subsample_factor *= np.prod(subsample)

        # Load dataset
        dataset = Dataset(
            csv_path=set,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub1=os.path.join(args.model, 'dict_sub.txt') if
            os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
            wp_model=os.path.join(args.model, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.batch_size,
            is_test=True)

        if i == 0:
            args.vocab = dataset.vocab
            args.vocab_sub1 = dataset.vocab_sub1
            args.input_dim = dataset.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model,
                                                   epoch=args.epoch)

            model.save_path = args.model

            # GPU setting
            model.cuda()

            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'],
                                                    decode_params,
                                                    exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            # Get CTC probs
            ctc_probs, indices_topk, x_lens = model.get_ctc_posteriors(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                if args.unit == 'word':
                    token_list = dataset.idx2word(best_hyps[b],
                                                  return_list=True)
                elif args.unit == 'wp':
                    token_list = dataset.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = dataset.idx2char(best_hyps[b],
                                                  return_list=True)
                elif args.unit == 'phone':
                    token_list = dataset.idx2phone(best_hyps[b],
                                                   return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])

                plot_ctc_probs(
                    ctc_probs[b, :x_lens[b]],
                    indices_topk[b],
                    nframes=x_lens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, speaker,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                ref = ys[b]
                hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
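
plot_ctc_probs above comes from the toolkit's plotting utilities; only the `[B, T, topk]` shape of ctc_probs is given by the example. For reference, a minimal matplotlib sketch that renders one utterance's `[T, topk]` posterior matrix in a similar spirit (the function name and the plotting style are assumptions):

import matplotlib.pyplot as plt
import numpy as np

def plot_ctc_probs_sketch(ctc_probs, save_path, figsize=(20, 8)):
    """Plot one curve per top-k CTC label against the input frame index.

    ctc_probs: numpy array of shape `[T, topk]` for a single utterance.
    """
    plt.figure(figsize=figsize)
    times = np.arange(ctc_probs.shape[0])
    for k in range(ctc_probs.shape[1]):
        plt.plot(times, ctc_probs[:, k])
    plt.xlabel('Input frame')
    plt.ylabel('Posterior')
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()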
Exemple #14
0
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'),
                        key='decoding')

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          batch_size=args.recog_batch_size,
                          bptt=args.bptt,
                          serialize=args.serialize,
                          is_test=True)

        if i == 0:
            # Load the LM
            if args.lm_type == 'gated_cnn':
                model = GatedConvLM(args)
            else:
                model = RNNLM(args)
            epoch = model.load_checkpoint(args.recog_model[0])['epoch']
            model.save_path = dir_name

            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)
            # logger.info('recog unit: %s' % args.recog_unit)
            # logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('BPTT: %d' % (args.bptt))
            logger.info('cache size: %d' % (args.recog_n_caches))
            logger.info('cache theta: %.3f' % (args.recog_cache_theta))
            logger.info('cache lambda: %.3f' % (args.recog_cache_lambda))
            model.cache_theta = args.recog_cache_theta
            model.cache_lambda = args.recog_cache_lambda

            # GPU setting
            model.cuda()

        assert args.recog_n_caches > 0
        save_path = mkdir_join(args.recog_dir, 'cache')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        if args.unit == 'word':
            idx2token = dataset.idx2word
        elif args.unit == 'wp':
            idx2token = dataset.idx2wp
        elif args.unit == 'char':
            idx2token = dataset.idx2char
        elif args.unit == 'phone':
            idx2token = dataset.idx2phone
        else:
            raise NotImplementedError(args.unit)

        hidden = None
        fig_count = 0
        token_count = 0
        n_tokens = args.recog_n_caches
        while True:
            ys, is_new_epoch = dataset.next()

            for t in range(ys.shape[1] - 1):
                loss, hidden = model(ys[:, t:t + 2],
                                     hidden,
                                     is_eval=True,
                                     n_caches=args.recog_n_caches)[:2]

                if len(model.cache_attn) > 0:
                    if token_count == n_tokens:
                        tokens_keys = idx2token(
                            model.cache_ids[:args.recog_n_caches],
                            return_list=True)
                        tokens_query = idx2token(model.cache_ids[-n_tokens:],
                                                 return_list=True)

                        # Slide attention matrix
                        n_keys = len(tokens_keys)
                        n_queries = len(tokens_query)
                        cache_probs = np.zeros(
                            (n_keys, n_queries))  # `[n_keys, n_queries]`
                        mask = np.zeros((n_keys, n_queries))
                        for i, aw in enumerate(model.cache_attn[-n_tokens:]):
                            cache_probs[:(n_keys - n_queries + i + 1),
                                        i] = aw[0,
                                                -(n_keys - n_queries + i + 1):]
                            mask[(n_keys - n_queries + i + 1):, i] = 1

                        plot_cache_weights(cache_probs,
                                           keys=tokens_keys,
                                           queries=tokens_query,
                                           save_path=mkdir_join(
                                               save_path,
                                               str(fig_count) + '.png'),
                                           figsize=(40, 16),
                                           mask=mask)
                        token_count = 0
                        fig_count += 1
                    else:
                        token_count += 1

            if is_new_epoch:
                break
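
The inner loop above packs the per-query cache-attention vectors (each covering only the keys cached so far) into a single `[n_keys, n_queries]` matrix, and builds a mask marking the keys that were not yet in the cache. A small self-contained sketch of that packing with random weights; the sizes are illustrative only:

import numpy as np

n_keys, n_queries = 6, 3
# One attention row per query token; query i attends over the first
# n_keys - n_queries + i + 1 cached keys, mirroring model.cache_attn above.
aw_list = [np.random.rand(1, n_keys - n_queries + i + 1) for i in range(n_queries)]

cache_probs = np.zeros((n_keys, n_queries))
mask = np.zeros((n_keys, n_queries))
for i, aw in enumerate(aw_list):
    n_valid = n_keys - n_queries + i + 1
    cache_probs[:n_valid, i] = aw[0, -n_valid:]
    mask[n_valid:, i] = 1  # 1 marks keys that were not yet cached for this query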
Exemple #15
0
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Setting for logging
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding')

    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(csv_path=set,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(
                               os.path.join(args.model, 'dict_sub.txt')) else None,
                           wp_model=os.path.join(args.model, 'wp.model'),
                           unit=args.unit,
                           batch_size=args.batch_size,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=True)

        if i == 0:
            args.vocab = eval_set.vocab
            args.vocab_sub = eval_set.vocab_sub
            args.input_dim = eval_set.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load a RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.unit == args_rnnlm.unit
                args_rnnlm.vocab = eval_set.vocab

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.cuda()

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params,
                                                    exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps = [hyp[::-1] for hyp in best_hyps]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                if args.unit == 'word':
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = eval_set.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(aws[b][:len(token_list)],
                                       label_list=token_list,
                                       spectrogram=batch['xs'][b][:,
                                                                  :eval_set.input_dim] if args.input_type == 'speech' else None,
                                       save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'),
                                       figsize=(20, 8))

                ref = ys[b]
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
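
The shallow-fusion block above rebuilds the RNNLM's training arguments by copying its YAML config, key by key, onto a fresh argparse.Namespace before instantiating the LM. A minimal sketch of that merge pattern; the config dict and its values are illustrative only:

import argparse

config_rnnlm = {'unit': 'wp', 'vocab': 10000, 'backward': False}  # illustrative

args_rnnlm = argparse.Namespace()
for k, v in config_rnnlm.items():
    setattr(args_rnnlm, k, v)

print(args_rnnlm.unit, args_rnnlm.backward)  # -> wp False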
Exemple #16
0
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check for differences between args and the YAML configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Automatically reduce batch size in multi-GPU setting
    if args.ngpus > 1:
        args.batch_size -= 10
        args.print_step //= args.ngpus

    subsample_factor = 1
    subsample_factor_sub = 1
    for p in args.conv_poolings:
        if len(p) > 0:
            subsample_factor *= p[0]
    if args.train_set_sub is not None:
        subsample_factor_sub = subsample_factor * (2**sum(
            args.subsample[:args.enc_num_layers_sub - 1]))
    subsample_factor *= 2**sum(args.subsample)
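    # NOTE: e.g. with conv_poolings = [(2, 2), (2, 2)] and subsample = [1, 0, 1, 0, 0]
    # (illustrative values only), the conv layers contribute 2 * 2 = 4 and the RNN
    # layers 2 ** (1 + 0 + 1 + 0 + 0) = 4, so subsample_factor becomes 16.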

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        max_epoch=args.num_epochs,
                        max_num_frames=args.max_num_frames,
                        min_num_frames=args.min_num_frames,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=True,
                        use_ctc=args.ctc_weight > 0,
                        subsample_factor=subsample_factor,
                        csv_path_sub=args.train_set_sub,
                        dict_path_sub=args.dict_sub,
                        label_type_sub=args.label_type_sub,
                        use_ctc_sub=args.ctc_weight_sub > 0,
                        subsample_factor_sub=subsample_factor_sub,
                        skip_speech=(args.input_type != 'speech'))
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      max_epoch=args.num_epochs,
                      max_num_frames=args.max_num_frames,
                      min_num_frames=args.min_num_frames,
                      shuffle=True,
                      use_ctc=args.ctc_weight > 0,
                      subsample_factor=subsample_factor,
                      csv_path_sub=args.dev_set_sub,
                      dict_path_sub=args.dict_sub,
                      label_type_sub=args.label_type_sub,
                      use_ctc_sub=args.ctc_weight_sub > 0,
                      subsample_factor_sub=subsample_factor_sub,
                      skip_speech=(args.input_type != 'speech'))
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [
            Dataset(csv_path=set,
                    dict_path=args.dict,
                    label_type=args.label_type,
                    batch_size=1,
                    max_epoch=args.num_epochs,
                    is_test=True,
                    skip_speech=(args.input_type != 'speech'))
        ]

    args.num_classes = train_set.num_classes
    args.input_dim = train_set.input_dim
    args.num_classes_sub = train_set.num_classes_sub

    # Load a RNNLM config file for cold fusion & RNNLM initialization
    # if config['rnnlm_cf']:
    #     if args.model is not None:
    #         config['rnnlm_config_cold_fusion'] = load_config(
    #             os.path.join(config['rnnlm_cf'], 'config.yml'), is_eval=True)
    #     elif args.resume_model is not None:
    #         config = load_config(os.path.join(
    #             args.resume_model, 'config_rnnlm_cf.yml'))
    #     assert args.label_type == config['rnnlm_config_cold_fusion']['label_type']
    #     config['rnnlm_config_cold_fusion']['num_classes'] = train_set.num_classes
    args.rnnlm_cf = None
    args.rnnlm_init = None

    # Model setting
    model = Seq2seq(args)
    model.name = args.enc_type
    if len(args.conv_channels) > 0:
        tmp = model.name
        model.name = 'conv' + str(len(args.conv_channels)) + 'L'
        if args.conv_batch_norm:
            model.name += 'bn'
        model.name += tmp
    model.name += str(args.enc_num_units) + 'H'
    model.name += str(args.enc_num_projs) + 'P'
    model.name += str(args.enc_num_layers) + 'L'
    model.name += '_subsample' + str(subsample_factor)
    model.name += '_' + args.dec_type
    model.name += str(args.dec_num_units) + 'H'
    # model.name += str(args.dec_num_projs) + 'P'
    model.name += str(args.dec_num_layers) + 'L'
    model.name += '_' + args.att_type
    if args.att_num_heads > 1:
        model.name += '_head' + str(args.att_num_heads)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    model.name += '_ss' + str(args.ss_prob)
    model.name += '_ls' + str(args.lsm_prob)
    if args.ctc_weight > 0:
        model.name += '_ctc' + str(args.ctc_weight)
    if args.bwd_weight > 0:
        model.name += '_bwd' + str(args.bwd_weight)
    if args.main_task_weight < 1:
        model.name += '_main' + str(args.main_task_weight)
        if args.ctc_weight_sub > 0:
            model.name += '_ctcsub' + str(args.ctc_weight_sub *
                                          (1 - args.main_task_weight))
        else:
            model.name += '_attsub' + str(1 - args.main_task_weight)
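    # NOTE: with illustrative settings (blstm encoder with 2 conv blocks, 320 units,
    # 0 projections, 5 layers, subsample factor 8, lstm decoder with 320 units and
    # 1 layer, location attention, adam, lr 0.001, batch size 50, ss 0.2, ls 0.1,
    # ctc weight 0.2), the generated name reads:
    # 'conv2Lblstm320H0P5L_subsample8_lstm320H1L_location_adam_lr0.001_bs50_ss0.2_ls0.1_ctc0.2'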

    if args.resume_model is None:
        # Load pre-trained RNNLM
        # if config['rnnlm_cf']:
        #     rnnlm = RNNLM(args)
        #     rnnlm.load_checkpoint(save_path=config['rnnlm_cf'], epoch=-1)
        #     rnnlm.flatten_parameters()
        #
        #     # Fix RNNLM parameters
        #     for param in rnnlm.parameters():
        #         param.requires_grad = False
        #
        #     # Set pre-trained parameters
        #     if config['rnnlm_config_cold_fusion']['backward']:
        #         model.dec_0_bwd.rnnlm = rnnlm
        #     else:
        #         model.dec_0_fwd.rnnlm = rnnlm
        # TODO(hirofumi): copy the RNNLM model first

        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.dict_sub is not None:
            shutil.copy(args.dict_sub, os.path.join(save_path, 'dict_sub.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Setting for logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # if os.path.isdir(args.pretrained_model):
        #     # NOTE: Start training from the pre-trained model
        #     # This is different from resuming training
        #     model.load_checkpoint(args.pretrained_model, epoch=-1,
        #                           load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    # NOTE: Restart from the last checkpoint
    # elif args.resume_model is not None:
    #     # Set save path
    #     model.save_path = args.resume_model
    #
    #     # Setting for logging
    #     logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')
    #
    #     # Set optimizer
    #     model.set_optimizer(
    #         optimizer=config['optimizer'],
    #         learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #         weight_decay=float(config['weight_decay']),
    #         clip_grad_norm=config['clip_grad_norm'],
    #         lr_schedule=False,
    #         factor=config['decay_rate'],
    #         patience_epoch=config['decay_patient_epoch'])
    #
    #     # Restore the last saved model
    #     epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
    #         save_path=args.resume_model, epoch=-1, restart=True)
    #
    #     if epoch >= config['convert_to_sgd_epoch']:
    #         model.set_optimizer(
    #             optimizer='sgd',
    #             learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #             weight_decay=float(config['weight_decay']),
    #             clip_grad_norm=config['clip_grad_norm'],
    #             lr_schedule=False,
    #             factor=config['decay_rate'],
    #             patience_epoch=config['decay_patient_epoch'])
    #
    #     if config['rnnlm_cf']:
    #         if config['rnnlm_config_cold_fusion']['backward']:
    #             model.rnnlm_0_bwd.flatten_parameters()
    #         else:
    #             model.rnnlm_0_fwd.flatten_parameters()

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=300)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Setting for tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0.
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, batch_train)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(len(batch_train['utt_ids']))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, batch_dev, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean,
                          acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            # for n, p in model.module.named_parameters():
            #     n = n.replace('.', '/')
            #     if p.grad is not None:
            #         tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
            #         tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                x_len = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                x_len = max(len(x) for x in batch_train['ys_sub'])
            logger.info(
                "...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)"
                % (step + 1, train_set.epoch_detail, loss_train_mean, loss_dev,
                   acc_train_mean, acc_dev, learning_rate,
                   train_set.current_batch_size, x_len, duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0, 0
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch,
                                             step, learning_rate,
                                             metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'ler':
                    if args.label_type == 'word':
                        metric_dev = eval_word([model.module],
                                               dev_set,
                                               decode_params,
                                               epoch=epoch)[0]
                        logger.info('  WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif args.label_type == 'wordpiece':
                        metric_dev = eval_wordpiece([model.module],
                                                    dev_set,
                                                    decode_params,
                                                    args.wp_model,
                                                    epoch=epoch)[0]
                        logger.info('  WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif 'char' in args.label_type:
                        metric_dev = eval_char([model.module],
                                               dev_set,
                                               decode_params,
                                               epoch=epoch)[1][0]
                        logger.info('  CER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif 'phone' in args.label_type:
                        metric_dev = eval_phone([model.module],
                                                dev_set,
                                                decode_params,
                                                epoch=epoch)[0]
                        logger.info('  PER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_loss([model.module], dev_set,
                                           decode_params)
                    logger.info('  Loss (%s): %.3f' %
                                (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError()

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch,
                                                 step, learning_rate,
                                                 metric_dev_best)

                    # test
                    for eval_set in eval_sets:
                        if args.metric == 'ler':
                            if args.label_type == 'word':
                                wer_test = eval_word([model.module],
                                                     eval_set,
                                                     decode_params,
                                                     epoch=epoch)[0]
                                logger.info('  WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif args.label_type == 'wordpiece':
                                wer_test = eval_wordpiece([model.module],
                                                          eval_set,
                                                          decode_params,
                                                          epoch=epoch)[0]
                                logger.info('  WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif 'char' in args.label_type:
                                cer_test = eval_char([model.module],
                                                     eval_set,
                                                     decode_params,
                                                     epoch=epoch)[1][0]
                                logger.info('  CER (%s): %.3f %%' %
                                            (eval_set.set, cer_test))
                            elif 'phone' in args.label_type:
                                per_test = eval_phone([model.module],
                                                      eval_set,
                                                      decode_params,
                                                      epoch=epoch)[0]
                                logger.info('  PER (%s): %.3f %%' %
                                            (eval_set.set, per_test))
                        elif args.metric == 'loss':
                            loss_test = eval_loss([model.module], eval_set,
                                                  decode_params)
                            logger.info('  Loss (%s): %.3f' %
                                        (eval_set.set, loss_test))
                        else:
                            raise NotImplementedError()
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(
                            args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
Exemple #17
0
def eval_word(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        decode_params (dict):
        epoch (int):
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub (float): substitution error rate (normalized by the number of reference words)
        num_ins (float): insertion error rate (normalized by the number of reference words)
        num_del (float): deletion error rate (normalized by the number of reference words)
        decode_dir (str):

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(
        decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
        decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_words = 0
    num_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path,
                                                     'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'],
                                                   decode_params,
                                                   exclude_eos=True)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2word(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2word(best_hyps[b])
                num_oov_total += text_hyp.count('<unk>')

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in text_hyp:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][b:b + 1],
                        decode_params,
                        exclude_eos=True)
                    # task_index=1

                    text_hyp = resolve_unk(
                        text_hyp,
                        best_hyps_sub[0],
                        aw[b],
                        aw_sub[0],
                        dataset.idx2char,
                        diff_time_resolution=2**sum(model.subsample_list) //
                        2**sum(model.subsample_list[:model.encoder_num_layers_sub - 1]))
                    text_hyp = text_hyp.replace('*', '')

                # Write to trn
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
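                # NOTE: e.g. an utt_id like 'sw02001-A_0001-0015' (illustrative) yields
                # speaker 'sw02001_A', start '0001' and end '0015'.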
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' +
                            end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' +
                            end + ')\n')

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=text_ref.split(' '),
                    hyp=text_hyp.split(' '),
                    normalize=False)
                wer += wer_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_words += len(text_ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    num_sub /= num_words
    num_ins /= num_words
    num_del /= num_words

    return wer, num_sub, num_ins, num_del, os.path.join(
        model.save_path, decode_dir)
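
compute_wer above returns the per-utterance edit distance together with the substitution, insertion and deletion counts that are accumulated into the corpus-level rates. A minimal sketch of that computation by dynamic programming over word lists; this is an illustrative re-implementation of the assumed interface, not the toolkit's code:

def compute_wer_sketch(ref, hyp):
    """Return (n_err, n_sub, n_ins, n_del) aligning two word lists."""
    # dp[i][j] = (errors, subs, inss, dels) for ref[:i] vs hyp[:j]
    dp = [[None] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    dp[0][0] = (0, 0, 0, 0)
    for i in range(1, len(ref) + 1):
        e, s, n_i, n_d = dp[i - 1][0]
        dp[i][0] = (e + 1, s, n_i, n_d + 1)          # deletions only
    for j in range(1, len(hyp) + 1):
        e, s, n_i, n_d = dp[0][j - 1]
        dp[0][j] = (e + 1, s, n_i + 1, n_d)          # insertions only
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            e, s, n_i, n_d = dp[i - 1][j - 1]
            if ref[i - 1] == hyp[j - 1]:
                cand = [(e, s, n_i, n_d)]            # match
            else:
                cand = [(e + 1, s + 1, n_i, n_d)]    # substitution
            e, s, n_i, n_d = dp[i][j - 1]
            cand.append((e + 1, s, n_i + 1, n_d))    # insertion
            e, s, n_i, n_d = dp[i - 1][j]
            cand.append((e + 1, s, n_i, n_d + 1))    # deletion
            dp[i][j] = min(cand)                     # fewest total errors wins
    return dp[len(ref)][len(hyp)]

# e.g. compute_wer_sketch('a b c'.split(), 'a x c d'.split()) -> (2, 1, 1, 0)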
Exemple #18
0
def main():

    args = parse()

    # Load a conf file
    dir_name = os.path.dirname(args.recog_model[0])
    conf = load_config(os.path.join(dir_name, 'conf.yml'))

    # Overwrite conf
    for k, v in conf.items():
        if 'recog' not in k:
            setattr(args, k, v)
    recog_params = vars(args)

    # Setting for logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    logger = set_logger(os.path.join(args.recog_dir, 'plot.log'), key='decoding')

    for i, s in enumerate(args.recog_sets):
        subsample_factor = 1
        subsample = [int(x) for x in args.subsample.split('_')]
        if args.conv_poolings:
            for p in args.conv_poolings.split('_'):
                p = int(p.split(',')[0].replace('(', ''))
                if p > 1:
                    subsample_factor *= p
        subsample_factor *= np.prod(subsample)

        # Load dataset
        dataset = Dataset(corpus=args.corpus,
                          tsv_path=s,
                          dict_path=os.path.join(dir_name, 'dict.txt'),
                          dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if os.path.isfile(
                              os.path.join(dir_name, 'dict_sub1.txt')) else False,
                          nlsyms=args.nlsyms,
                          wp_model=os.path.join(dir_name, 'wp.model'),
                          unit=args.unit,
                          unit_sub1=args.unit_sub1,
                          batch_size=args.recog_batch_size,
                          concat_prev_n_utterances=args.recog_concat_prev_n_utterances,
                          is_test=True)

        if i == 0:
            # Load the ASR model
            model = Seq2seq(args)
            epoch = model.load_checkpoint(args.recog_model[0])['epoch']
            model.save_path = dir_name

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            model.cuda()
            # TODO(hirofumi): move this

        save_path = mkdir_join(args.plot_dir, 'ctc_probs')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, aws, perm_ids, _ = model.decode(batch['xs'], recog_params,
                                                          exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_ids]

            # Get CTC probs
            ctc_probs, indices_topk, xlens = model.get_ctc_posteriors(
                batch['xs'], temperature=1, topk=min(100, model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b], return_list=True)
                tokens = [unicode(t, 'utf-8') for t in tokens]
                spk = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    indices_topk[b],
                    nframes=xlens[b],
                    subsample_factor=subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                ref = ys[b]
                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
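
The subsample_factor computed at the top of this loop mirrors the encoder's frame-rate reduction: the time-pooling size of each conv block multiplied by the product of the per-layer RNN subsampling factors. A standalone sketch of that parsing; the '(2,2)_(2,2)' and '1_2_2_1' strings are illustrative values, not defaults from the toolkit:

import numpy as np

conv_poolings = '(2,2)_(2,2)'   # two conv blocks, each pooling time by 2
subsample = '1_2_2_1'           # per-RNN-layer subsampling factors

subsample_factor = 1
for p in conv_poolings.split('_'):
    p = int(p.split(',')[0].replace('(', ''))
    if p > 1:
        subsample_factor *= p
subsample_factor *= np.prod([int(x) for x in subsample.split('_')])
print(subsample_factor)  # 2 * 2 * (1 * 2 * 2 * 1) = 16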