Example No. 1
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)
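    # NOTE: vars(args) returns args.__dict__ itself (not a copy), so config
    # entries merged into args below are also visible through decode_params.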

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Set up logging
    logger = set_logger(os.path.join(args.plot_dir, 'plot.log'), key='decoding')

    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(csv_path=set,
                           dict_path=os.path.join(args.model, 'dict.txt'),
                           dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if os.path.isfile(
                               os.path.join(args.model, 'dict_sub.txt')) else None,
                           wp_model=os.path.join(args.model, 'wp.model'),
                           unit=args.unit,
                           batch_size=args.batch_size,
                           max_num_frames=args.max_num_frames,
                           min_num_frames=args.min_num_frames,
                           is_test=True)

        if i == 0:
            args.vocab = eval_set.vocab
            args.vocab_sub = eval_set.vocab_sub
            args.input_dim = eval_set.input_dim

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cold_fusion = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)
            epoch, _, _, _ = model.load_checkpoint(args.model, epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cold_fusion is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load an RNNLM config file
                config_rnnlm = load_config(os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.unit == args_rnnlm.unit
                args_rnnlm.vocab = eval_set.vocab

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' % str(config_rnnlm['backward']))

            # GPU setting
            model.cuda()

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.plot_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aws, perm_idx = model.decode(batch['xs'], decode_params,
                                                    exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps = [hyp[::-1] for hyp in best_hyps]
                aws = [aw[::-1] for aw in aws]

            for b in range(len(batch['xs'])):
                if args.unit == 'word':
                    token_list = eval_set.idx2word(best_hyps[b], return_list=True)
                elif args.unit == 'wp':
                    token_list = eval_set.idx2wp(best_hyps[b], return_list=True)
                elif args.unit == 'char':
                    token_list = eval_set.idx2char(best_hyps[b], return_list=True)
                elif args.unit == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b], return_list=True)
                else:
                    raise NotImplementedError(args.unit)
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(aws[b][:len(token_list)],
                                       label_list=token_list,
                                       spectrogram=batch['xs'][b][:,
                                                                  :eval_set.input_dim] if args.input_type == 'speech' else None,
                                       save_path=mkdir_join(save_path, speaker, batch['utt_ids'][b] + '.png'),
                                       figsize=(20, 8))

                ref = ys[b]
                if model.bwd_weight > 0.5:
                    hyp = ' '.join(token_list[::-1])
                else:
                    hyp = ' '.join(token_list)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref.lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
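Each of these examples opens by merging the YAML config saved at training time into the parsed command-line arguments, so that flags given on the command line win and the stored config fills in everything else. A minimal sketch of that pattern, assuming a PyYAML-backed loader; load_config and merge_config_into_args here are illustrative helpers, not the toolkit's actual implementations:

import argparse

import yaml


def load_config(config_path):
    """Read a YAML config file into a plain dict (illustrative helper)."""
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)


def merge_config_into_args(config, args):
    """Copy config entries onto args, keeping any value already set on args."""
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)
    return args


# Usage sketch (hypothetical values)
args = argparse.Namespace(model='results/asr', beam_width=10)
config = {'beam_width': 4, 'unit': 'wp'}  # stands in for config.yml
args = merge_config_into_args(config, args)
assert args.beam_width == 10 and args.unit == 'wp'  # CLI value wins, config fills the rest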
Example No. 2
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Set up logging
    logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding')

    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(
            csv_path=set,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if
            os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
            label_type=args.label_type,
            batch_size=args.batch_size,
            max_epoch=args.num_epochs,
            max_num_frames=args.max_num_frames,
            min_num_frames=args.min_num_frames,
            is_test=False)

        if i == 0:
            args.num_classes = eval_set.num_classes
            args.input_dim = eval_set.input_dim
            args.num_classes_sub = eval_set.num_classes_sub

            # TODO(hirofumi): For cold fusion
            args.rnnlm_cf = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(args.model,
                                                   epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load an RNNLM config file
                config_rnnlm = load_config(
                    os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.label_type == args_rnnlm.label_type
                args_rnnlm.num_classes = eval_set.num_classes

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' %
                            str(config_rnnlm['backward']))

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        save_path = mkdir_join(args.model, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = eval_set.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'],
                                                   decode_params,
                                                   exclude_eos=False)
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                if args.label_type in ['word', 'wordpiece']:
                    token_list = eval_set.idx2word(best_hyps[b],
                                                   return_list=True)
                elif args.label_type == 'char':
                    token_list = eval_set.idx2char(best_hyps[b],
                                                   return_list=True)
                elif args.label_type == 'phone':
                    token_list = eval_set.idx2phone(best_hyps[b],
                                                    return_list=True)
                else:
                    raise NotImplementedError()
                token_list = [unicode(t, 'utf-8') for t in token_list]
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])

                # error check
                assert len(batch['xs'][b]) <= 2000

                plot_attention_weights(
                    aw[b][:len(token_list)],
                    label_list=token_list,
                    spectrogram=batch['xs'][b][:, :eval_set.input_dim]
                    if args.input_type == 'speech' else None,
                    save_path=mkdir_join(save_path, speaker,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                # Reference
                if eval_set.is_test:
                    text_ref = ys[b]
                else:
                    if args.label_type in ['word', 'wordpiece']:
                        text_ref = eval_set.idx2word(ys[b])
                    elif args.label_type == 'char':
                        text_ref = eval_set.idx2char(ys[b])
                    elif args.label_type == 'phone':
                        text_ref = eval_set.idx2phone(ys[b])

                # Hypothesis
                text_hyp = ' '.join(token_list)

                sys.stdout = open(
                    os.path.join(save_path, speaker,
                                 batch['utt_ids'][b] + '.txt'), 'w')
                ler = wer_align(
                    ref=text_ref.split(' '),
                    hyp=text_hyp.encode('utf-8').split(' '),
                    normalize=True,
                    double_byte=False)[0]  # TODO(hirofumi): add corpus to args
                print('\nLER: %.3f %%\n\n' % ler)

            if is_new_epoch:
                break
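wer_align above is the toolkit's own alignment helper and its full return signature is not shown here. For reference, the word error rate it reports is essentially the word-level Levenshtein distance between reference and hypothesis, normalized by the reference length. A minimal sketch of that computation (an illustration, not the toolkit's implementation):

def word_error_rate(ref_words, hyp_words):
    """WER in %: (substitutions + insertions + deletions) / len(ref)."""
    n, m = len(ref_words), len(hyp_words)
    # dp[i][j] = edit distance between ref_words[:i] and hyp_words[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,          # deletion
                           dp[i][j - 1] + 1,          # insertion
                           dp[i - 1][j - 1] + cost)   # substitution or match
    return 100.0 * dp[n][m] / max(n, 1)


# Usage sketch: one substitution + one insertion over a 3-word reference
print(word_error_rate('the cat sat'.split(), 'the cat sit down'.split()))  # ~66.7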
Example No. 3
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check for differences between args and the yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Automatically reduce batch size in multi-GPU setting
    if args.ngpus > 1:
        args.batch_size -= 10
        args.print_step //= args.ngpus

    subsample_factor = 1
    subsample_factor_sub = 1
    for p in args.conv_poolings:
        if len(p) > 0:
            subsample_factor *= p[0]
    if args.train_set_sub is not None:
        subsample_factor_sub = subsample_factor * (2**sum(
            args.subsample[:args.enc_num_layers_sub - 1]))
    subsample_factor *= 2**sum(args.subsample)
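    # NOTE (illustrative numbers, not the actual config): with
    # conv_poolings=[[2, 2], [2, 2]] the conv front-end alone gives
    # subsample_factor = 2 * 2 = 4, and if args.subsample were the flags
    # [1, 0, 1, 0] the encoder layers would add another 2**2 = 4x,
    # for a total frame-rate reduction of 16x.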

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        max_epoch=args.num_epochs,
                        max_num_frames=args.max_num_frames,
                        min_num_frames=args.min_num_frames,
                        sort_by_input_length=True,
                        short2long=True,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=True,
                        use_ctc=args.ctc_weight > 0,
                        subsample_factor=subsample_factor,
                        csv_path_sub=args.train_set_sub,
                        dict_path_sub=args.dict_sub,
                        label_type_sub=args.label_type_sub,
                        use_ctc_sub=args.ctc_weight_sub > 0,
                        subsample_factor_sub=subsample_factor_sub,
                        skip_speech=(args.input_type != 'speech'))
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      max_epoch=args.num_epochs,
                      max_num_frames=args.max_num_frames,
                      min_num_frames=args.min_num_frames,
                      shuffle=True,
                      use_ctc=args.ctc_weight > 0,
                      subsample_factor=subsample_factor,
                      csv_path_sub=args.dev_set_sub,
                      dict_path_sub=args.dict_sub,
                      label_type_sub=args.label_type_sub,
                      use_ctc_sub=args.ctc_weight_sub > 0,
                      subsample_factor_sub=subsample_factor_sub,
                      skip_speech=(args.input_type != 'speech'))
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [
            Dataset(csv_path=set,
                    dict_path=args.dict,
                    label_type=args.label_type,
                    batch_size=1,
                    max_epoch=args.num_epochs,
                    is_test=True,
                    skip_speech=(args.input_type != 'speech'))
        ]

    args.num_classes = train_set.num_classes
    args.input_dim = train_set.input_dim
    args.num_classes_sub = train_set.num_classes_sub

    # Load an RNNLM config file for cold fusion & RNNLM initialization
    # if config['rnnlm_cf']:
    #     if args.model is not None:
    #         config['rnnlm_config_cold_fusion'] = load_config(
    #             os.path.join(config['rnnlm_cf'], 'config.yml'), is_eval=True)
    #     elif args.resume_model is not None:
    #         config = load_config(os.path.join(
    #             args.resume_model, 'config_rnnlm_cf.yml'))
    #     assert args.label_type == config['rnnlm_config_cold_fusion']['label_type']
    #     config['rnnlm_config_cold_fusion']['num_classes'] = train_set.num_classes
    args.rnnlm_cf = None
    args.rnnlm_init = None

    # Model setting
    model = Seq2seq(args)
    model.name = args.enc_type
    if len(args.conv_channels) > 0:
        tmp = model.name
        model.name = 'conv' + str(len(args.conv_channels)) + 'L'
        if args.conv_batch_norm:
            model.name += 'bn'
        model.name += tmp
    model.name += str(args.enc_num_units) + 'H'
    model.name += str(args.enc_num_projs) + 'P'
    model.name += str(args.enc_num_layers) + 'L'
    model.name += '_subsample' + str(subsample_factor)
    model.name += '_' + args.dec_type
    model.name += str(args.dec_num_units) + 'H'
    # model.name += str(args.dec_num_projs) + 'P'
    model.name += str(args.dec_num_layers) + 'L'
    model.name += '_' + args.att_type
    if args.att_num_heads > 1:
        model.name += '_head' + str(args.att_num_heads)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    model.name += '_ss' + str(args.ss_prob)
    model.name += '_ls' + str(args.lsm_prob)
    if args.ctc_weight > 0:
        model.name += '_ctc' + str(args.ctc_weight)
    if args.bwd_weight > 0:
        model.name += '_bwd' + str(args.bwd_weight)
    if args.main_task_weight < 1:
        model.name += '_main' + str(args.main_task_weight)
        if args.ctc_weight_sub > 0:
            model.name += '_ctcsub' + str(args.ctc_weight_sub *
                                          (1 - args.main_task_weight))
        else:
            model.name += '_attsub' + str(1 - args.main_task_weight)

    if args.resume_model is None:
        # Load pre-trained RNNLM
        # if config['rnnlm_cf']:
        #     rnnlm = RNNLM(args)
        #     rnnlm.load_checkpoint(save_path=config['rnnlm_cf'], epoch=-1)
        #     rnnlm.flatten_parameters()
        #
        #     # Fix RNNLM parameters
        #     for param in rnnlm.parameters():
        #         param.requires_grad = False
        #
        #     # Set pre-trained parameters
        #     if config['rnnlm_config_cold_fusion']['backward']:
        #         model.dec_0_bwd.rnnlm = rnnlm
        #     else:
        #         model.dec_0_fwd.rnnlm = rnnlm
        # TODO(hirofumi): copy the RNNLM model first

        # Set save path
        save_path = mkdir_join(
            args.model,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.dict_sub is not None:
            shutil.copy(args.dict_sub, os.path.join(save_path, 'dict_sub.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Set up logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'),
                            key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # if os.path.isdir(args.pretrained_model):
        #     # NOTE: Start training from the pre-trained model
        #     # This is different from resuming training
        #     model.load_checkpoint(args.pretrained_model, epoch=-1,
        #                           load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    # NOTE: Restart from the last checkpoint
    # elif args.resume_model is not None:
    #     # Set save path
    #     model.save_path = args.resume_model
    #
    #     # Setting for logging
    #     logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')
    #
    #     # Set optimizer
    #     model.set_optimizer(
    #         optimizer=config['optimizer'],
    #         learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #         weight_decay=float(config['weight_decay']),
    #         clip_grad_norm=config['clip_grad_norm'],
    #         lr_schedule=False,
    #         factor=config['decay_rate'],
    #         patience_epoch=config['decay_patient_epoch'])
    #
    #     # Restore the last saved model
    #     epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
    #         save_path=args.resume_model, epoch=-1, restart=True)
    #
    #     if epoch >= config['convert_to_sgd_epoch']:
    #         model.set_optimizer(
    #             optimizer='sgd',
    #             learning_rate_init=float(config['learning_rate']),  # on-the-fly
    #             weight_decay=float(config['weight_decay']),
    #             clip_grad_norm=config['clip_grad_norm'],
    #             lr_schedule=False,
    #             factor=config['decay_rate'],
    #             patience_epoch=config['decay_patient_epoch'])
    #
    #     if config['rnnlm_cf']:
    #         if config['rnnlm_config_cold_fusion']['backward']:
    #             model.rnnlm_0_bwd.flatten_parameters()
    #         else:
    #             model.rnnlm_0_fwd.flatten_parameters()

    train_set.epoch = epoch - 1  # start from index:0

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=False,
                                   benchmark=True)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=300)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Set up tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0.
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, batch_train)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(len(batch_train['utt_ids']))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            batch_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, batch_dev, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean,
                          acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            # for n, p in model.module.named_parameters():
            #     n = n.replace('.', '/')
            #     if p.grad is not None:
            #         tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
            #         tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            if args.input_type == 'speech':
                x_len = max(len(x) for x in batch_train['xs'])
            elif args.input_type == 'text':
                x_len = max(len(x) for x in batch_train['ys_sub'])
            logger.info(
                "...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/lr:%.5f/bs:%d/x_len:%d (%.2f min)"
                % (step + 1, train_set.epoch_detail, loss_train_mean, loss_dev,
                   acc_train_mean, acc_dev, learning_rate,
                   train_set.current_batch_size, x_len, duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0, 0
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch,
                                             step, learning_rate,
                                             metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if args.metric == 'ler':
                    if args.label_type == 'word':
                        metric_dev = eval_word([model.module],
                                               dev_set,
                                               decode_params,
                                               epoch=epoch)[0]
                        logger.info('  WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif args.label_type == 'wordpiece':
                        metric_dev = eval_wordpiece([model.module],
                                                    dev_set,
                                                    decode_params,
                                                    args.wp_model,
                                                    epoch=epoch)[0]
                        logger.info('  WER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif 'char' in args.label_type:
                        metric_dev = eval_char([model.module],
                                               dev_set,
                                               decode_params,
                                               epoch=epoch)[1][0]
                        logger.info('  CER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                    elif 'phone' in args.label_type:
                        metric_dev = eval_phone([model.module],
                                                dev_set,
                                                decode_params,
                                                epoch=epoch)[0]
                        logger.info('  PER (%s): %.3f %%' %
                                    (dev_set.set, metric_dev))
                elif args.metric == 'loss':
                    metric_dev = eval_loss([model.module], dev_set,
                                           decode_params)
                    logger.info('  Loss (%s): %.3f' %
                                (dev_set.set, metric_dev))
                else:
                    raise NotImplementedError()

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch,
                                                 step, learning_rate,
                                                 metric_dev_best)

                    # test
                    for eval_set in eval_sets:
                        if args.metric == 'ler':
                            if args.label_type == 'word':
                                wer_test = eval_word([model.module],
                                                     eval_set,
                                                     decode_params,
                                                     epoch=epoch)[0]
                                logger.info('  WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif args.label_type == 'wordpiece':
                                wer_test = eval_wordpiece([model.module],
                                                          eval_set,
                                                          decode_params,
                                                          epoch=epoch)[0]
                                logger.info('  WER (%s): %.3f %%' %
                                            (eval_set.set, wer_test))
                            elif 'char' in args.label_type:
                                cer_test = eval_char([model.module],
                                                     eval_set,
                                                     decode_params,
                                                     epoch=epoch)[1][0]
                                logger.info('  CER (%s): %.3f %%' %
                                            (eval_set.set, cer_test))
                            elif 'phone' in args.label_type:
                                per_test = eval_phone([model.module],
                                                      eval_set,
                                                      decode_params,
                                                      epoch=epoch)[0]
                                logger.info('  PER (%s): %.3f %%' %
                                            (eval_set.set, per_test))
                        elif args.metric == 'loss':
                            loss_test = eval_loss([model.module], eval_set,
                                                  decode_params)
                            logger.info('  Loss (%s): %.3f' %
                                        (eval_set.set, loss_test))
                        else:
                            raise NotImplementedError()
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=metric_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(
                            args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
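The Controller above drives learning-rate decay from the dev-set metric; its constructor arguments (decay_start_epoch, decay_rate, decay_patient_epoch, lower_better) suggest an epoch-level decay with patience. A minimal sketch of that idea against a PyTorch-style optimizer with param_groups; this illustrates the pattern only and is not the toolkit's Controller:

class SimpleLRController(object):
    """Decay the learning rate when the dev metric stops improving (illustrative)."""

    def __init__(self, lr_init, decay_rate=0.5, decay_start_epoch=10,
                 decay_patient_epoch=1, best_value=10000.0):
        self.lr = lr_init
        self.decay_rate = decay_rate
        self.decay_start_epoch = decay_start_epoch
        self.decay_patient_epoch = decay_patient_epoch
        self.best_value = best_value
        self.not_improved = 0

    def decay_lr(self, optimizer, epoch, value):
        """Update the optimizer's learning rate given the latest dev metric."""
        if value < self.best_value:
            self.best_value = value
            self.not_improved = 0
        else:
            self.not_improved += 1
        if epoch >= self.decay_start_epoch and self.not_improved >= self.decay_patient_epoch:
            self.lr *= self.decay_rate
            self.not_improved = 0
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr
        return optimizer, self.lr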
Example No. 4
def main():

    # Load a config file
    if args.resume_model is None:
        config = load_config(args.config)
    else:
        # Restart from the last checkpoint
        config = load_config(os.path.join(args.resume_model, 'config.yml'))

    # Check for differences between args and the yaml configuration
    for k, v in vars(args).items():
        if k not in config.keys():
            warnings.warn("key %s is automatically set to %s" % (k, str(v)))

    # Merge config with args
    for k, v in config.items():
        setattr(args, k, v)

    # Load dataset
    train_set = Dataset(csv_path=args.train_set,
                        dict_path=args.dict,
                        label_type=args.label_type,
                        batch_size=args.batch_size * args.ngpus,
                        bptt=args.bptt,
                        eos=args.eos,
                        max_epoch=args.num_epochs,
                        shuffle=True)
    dev_set = Dataset(csv_path=args.dev_set,
                      dict_path=args.dict,
                      label_type=args.label_type,
                      batch_size=args.batch_size * args.ngpus,
                      bptt=args.bptt,
                      eos=args.eos,
                      shuffle=True)
    eval_sets = []
    for set in args.eval_sets:
        eval_sets += [Dataset(csv_path=set,
                              dict_path=args.dict,
                              label_type=args.label_type,
                              batch_size=1,
                              bptt=args.bptt,
                              eos=args.eos,
                              is_test=True)]

    args.num_classes = train_set.num_classes

    # Model setting
    model = RNNLM(args)
    model.name = args.rnn_type
    model.name += str(args.num_units) + 'H'
    model.name += str(args.num_projs) + 'P'
    model.name += str(args.num_layers) + 'L'
    model.name += '_emb' + str(args.emb_dim)
    model.name += '_' + args.optimizer
    model.name += '_lr' + str(args.learning_rate)
    model.name += '_bs' + str(args.batch_size)
    if args.tie_weights:
        model.name += '_tie'
    if args.residual:
        model.name += '_residual'
    if args.backward:
        model.name += '_bwd'

    if args.resume_model is None:
        # Set save path
        save_path = mkdir_join(args.model, '_'.join(os.path.basename(args.train_set).split('.')[:-1]), model.name)
        model.set_save_path(save_path)  # avoid overwriting

        # Save the config file as a yaml file
        save_config(vars(args), model.save_path)

        # Save the dictionary & wp_model
        shutil.copy(args.dict, os.path.join(save_path, 'dict.txt'))
        if args.label_type == 'wordpiece':
            shutil.copy(args.wp_model, os.path.join(save_path, 'wp.model'))

        # Set up logging
        logger = set_logger(os.path.join(model.save_path, 'train.log'), key='training')

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))

        # Set optimizer
        model.set_optimizer(optimizer=args.optimizer,
                            learning_rate_init=float(args.learning_rate),
                            weight_decay=float(args.weight_decay),
                            clip_grad_norm=args.clip_grad_norm,
                            lr_schedule=False,
                            factor=args.decay_rate,
                            patience_epoch=args.decay_patient_epoch)

        epoch, step = 1, 0
        learning_rate = float(args.learning_rate)
        metric_dev_best = 10000

    else:
        raise NotImplementedError()

    train_set.epoch = epoch - 1

    # GPU setting
    if args.ngpus >= 1:
        model = CustomDataParallel(model,
                                   device_ids=list(range(0, args.ngpus, 1)),
                                   deterministic=True,
                                   benchmark=False)
        model.cuda()

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    # setproctitle(args.job_name)

    # Set learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               decay_type=args.decay_type,
                               decay_start_epoch=args.decay_start_epoch,
                               decay_rate=args.decay_rate,
                               decay_patient_epoch=args.decay_patient_epoch,
                               lower_better=True,
                               best_value=metric_dev_best)

    # Set reporter
    reporter = Reporter(model.module.save_path, max_loss=10)

    # Set the updater
    updater = Updater(args.clip_grad_norm)

    # Set up tensorboard
    tf_writer = SummaryWriter(model.module.save_path)

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, acc_train_mean = 0., 0.
    pbar_epoch = tqdm(total=len(train_set))
    pbar_all = tqdm(total=len(train_set) * args.num_epochs)
    while True:
        # Compute loss in the training set (including parameter update)
        ys_train, is_new_epoch = train_set.next()
        model, loss_train, acc_train = updater(model, ys_train, args.bptt)
        loss_train_mean += loss_train
        acc_train_mean += acc_train
        pbar_epoch.update(np.sum([len(y) for y in ys_train]))

        if (step + 1) % args.print_step == 0:
            # Compute loss in the dev set
            ys_dev = dev_set.next()[0]
            model, loss_dev, acc_dev = updater(model, ys_dev, args.bptt, is_eval=True)

            loss_train_mean /= args.print_step
            acc_train_mean /= args.print_step
            reporter.step(step, loss_train_mean, loss_dev, acc_train_mean, acc_dev)

            # Logging by tensorboard
            tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
            tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
            for n, p in model.module.named_parameters():
                n = n.replace('.', '/')
                if p.grad is not None:
                    tf_writer.add_histogram(n, p.data.cpu().numpy(), step + 1)
                    tf_writer.add_histogram(n + '/grad', p.grad.data.cpu().numpy(), step + 1)

            duration_step = time.time() - start_time_step
            logger.info("...Step:%d(ep:%.2f) loss:%.2f(%.2f)/acc:%.2f(%.2f)/ppl:%.2f(%.2f)/lr:%.5f/bs:%d (%.2f min)" %
                        (step + 1, train_set.epoch_detail,
                         loss_train_mean, loss_dev, acc_train_mean, acc_dev,
                         math.exp(loss_train_mean), math.exp(loss_dev),
                         learning_rate, len(ys_train), duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, acc_train_mean = 0., 0.
        step += args.ngpus

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.2f min) =====' % (epoch, duration_epoch / 60))

            # Save figures of loss and accuracy
            reporter.epoch()

            if epoch < args.eval_start_epoch:
                # Save the model
                model.module.save_checkpoint(model.module.save_path, epoch, step,
                                             learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                ppl_dev = eval_ppl([model.module], dev_set, args.bptt)
                logger.info(' PPL (%s): %.3f' % (dev_set.set, ppl_dev))

                if ppl_dev < metric_dev_best:
                    metric_dev_best = ppl_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    # Save the model
                    model.module.save_checkpoint(model.module.save_path, epoch, step,
                                                 learning_rate, metric_dev_best)

                    # test
                    ppl_test_mean = 0.
                    for eval_set in eval_sets:
                        ppl_test = eval_ppl([model.module], eval_set, args.bptt)
                        logger.info(' PPL (%s): %.3f' % (eval_set.set, ppl_test))
                        ppl_test_mean += ppl_test
                    if len(eval_sets) > 0:
                        logger.info(' PPL (mean): %.3f' % (ppl_test_mean / len(eval_sets)))
                else:
                    # Update learning rate
                    model.module.optimizer, learning_rate = lr_controller.decay_lr(
                        optimizer=model.module.optimizer,
                        learning_rate=learning_rate,
                        epoch=epoch,
                        value=ppl_dev)

                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == args.not_improved_patient_epoch:
                    break

                if epoch == args.convert_to_sgd_epoch:
                    # Convert to fine-tuning stage
                    model.module.set_optimizer(
                        'sgd',
                        learning_rate_init=float(args.learning_rate),  # TODO: ?
                        weight_decay=float(args.weight_decay),
                        clip_grad_norm=args.clip_grad_norm,
                        lr_schedule=False,
                        factor=args.decay_rate,
                        patience_epoch=args.decay_patient_epoch)
                    logger.info('========== Convert to SGD ==========')

            pbar_epoch = tqdm(total=len(train_set))
            pbar_all.update(len(train_set))

            if epoch == args.num_epochs:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    tf_writer.close()
    pbar_epoch.close()
    pbar_all.close()

    return model.module.save_path
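The RNNLM training loop above reports perplexity as math.exp of the mean per-token cross-entropy loss, and eval_ppl presumably computes the same quantity on held-out data. A minimal sketch of that relationship (illustrative; eval_ppl itself is not shown in these examples):

import math


def perplexity_from_loss(total_nll, num_tokens):
    """Perplexity = exp(average negative log-likelihood per token)."""
    return math.exp(total_nll / max(num_tokens, 1))


# Usage sketch: 230.3 nats of total loss over 100 tokens -> ppl close to 10
print(perplexity_from_loss(230.3, 100))  # ~10.0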
Example No. 5
def main():

    # Load a config file
    config = load_config(os.path.join(args.model, 'config.yml'))

    decode_params = vars(args)

    # Merge config with args
    for k, v in config.items():
        if not hasattr(args, k):
            setattr(args, k, v)

    # Set up logging
    logger = set_logger(os.path.join(args.model, 'decode.log'), key='decoding')

    wer_mean, cer_mean, per_mean = 0, 0, 0
    for i, set in enumerate(args.eval_sets):
        # Load dataset
        eval_set = Dataset(
            csv_path=set,
            dict_path=os.path.join(args.model, 'dict.txt'),
            dict_path_sub=os.path.join(args.model, 'dict_sub.txt') if
            os.path.isfile(os.path.join(args.model, 'dict_sub.txt')) else None,
            label_type=args.label_type,
            batch_size=args.batch_size,
            max_epoch=args.num_epochs,
            is_test=True)

        if i == 0:
            args.num_classes = eval_set.num_classes
            args.input_dim = eval_set.input_dim
            args.num_classes_sub = eval_set.num_classes_sub

            # For cold fusion
            # if args.rnnlm_cf:
            #     # Load a RNNLM config file
            #     config['rnnlm_config'] = load_config(os.path.join(args.model, 'config_rnnlm.yml'))
            #
            #     assert args.label_type == config['rnnlm_config']['label_type']
            #     rnnlm_args.num_classes = eval_set.num_classes
            #     logger.info('RNNLM path: %s' % config['rnnlm'])
            #     logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
            # else:
            #     pass

            args.rnnlm_cf = None
            args.rnnlm_init = None

            # Load the ASR model
            model = Seq2seq(args)

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(args.model,
                                                   epoch=args.epoch)

            model.save_path = args.model

            # For shallow fusion
            if args.rnnlm_cf is None and args.rnnlm is not None and args.rnnlm_weight > 0:
                # Load an RNNLM config file
                config_rnnlm = load_config(
                    os.path.join(args.rnnlm, 'config.yml'))

                # Merge config with args
                args_rnnlm = argparse.Namespace()
                for k, v in config_rnnlm.items():
                    setattr(args_rnnlm, k, v)

                assert args.label_type == args_rnnlm.label_type
                args_rnnlm.num_classes = eval_set.num_classes

                # Load the pre-trained RNNLM
                rnnlm = RNNLM(args_rnnlm)
                rnnlm.load_checkpoint(args.rnnlm, epoch=-1)
                if args_rnnlm.backward:
                    model.rnnlm_bwd_0 = rnnlm
                else:
                    model.rnnlm_fwd_0 = rnnlm

                logger.info('RNNLM path: %s' % args.rnnlm)
                logger.info('RNNLM weight: %.3f' % args.rnnlm_weight)
                logger.info('RNNLM backward: %s' %
                            str(config_rnnlm['backward']))

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('length penalty: %.3f' % args.length_penalty)
            logger.info('coverage penalty: %.3f' % args.coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.coverage_threshold)
            logger.info('epoch: %d' % (epoch - 1))

        start_time = time.time()

        if args.label_type == 'word':
            wer, _, _, _, decode_dir = eval_word([model],
                                                 eval_set,
                                                 decode_params,
                                                 epoch=epoch - 1,
                                                 progressbar=True)
            wer_mean += wer
            logger.info('  WER (%s): %.3f %%' % (eval_set.set, wer))
        elif args.label_type == 'wordpiece':
            wer, _, _, _, decode_dir = eval_wordpiece([model],
                                                      eval_set,
                                                      decode_params,
                                                      os.path.join(
                                                          args.model,
                                                          'wp.model'),
                                                      epoch=epoch - 1,
                                                      progressbar=True)
            wer_mean += wer
            logger.info('  WER (%s): %.3f %%' % (eval_set.set, wer))

        elif 'char' in args.label_type:
            (wer, _, _, _), (cer, _, _,
                             _), decode_dir = eval_char([model],
                                                        eval_set,
                                                        decode_params,
                                                        epoch=epoch - 1,
                                                        progressbar=True)
            wer_mean += wer
            cer_mean += cer
            logger.info('  WER / CER (%s): %.3f / %.3f %%' %
                        (eval_set.set, wer, cer))

        elif 'phone' in args.label_type:
            per, _, _, _, decode_dir = eval_phone([model],
                                                  eval_set,
                                                  decode_params,
                                                  epoch=epoch - 1,
                                                  progressbar=True)
            per_mean += per
            logger.info('  PER (%s): %.3f %%' % (eval_set.set, per))
        else:
            raise ValueError(args.label_type)

        logger.info('Elapsed time: %.2f [sec.]' % (time.time() - start_time))

    if args.label_type == 'word':
        logger.info('  WER (mean): %.3f %%\n' %
                    (wer_mean / len(args.eval_sets)))
    elif args.label_type == 'wordpiece':
        logger.info('  WER (mean): %.3f %%\n' %
                    (wer_mean / len(args.eval_sets)))
    elif 'char' in args.label_type:
        logger.info(
            '  WER / CER (mean): %.3f / %.3f %%\n' %
            (wer_mean / len(args.eval_sets), cer_mean / len(args.eval_sets)))
    elif 'phone' in args.label_type:
        logger.info('  PER (mean): %.3f %%\n' %
                    (per_mean / len(args.eval_sets)))

    print(decode_dir)
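Every example here wires a pre-trained RNNLM into the ASR model for shallow fusion, weighted by args.rnnlm_weight. At decode time, shallow fusion simply adds the LM log-probability, scaled by that weight, to the ASR log-probability when ranking candidate tokens in the beam. A minimal NumPy sketch of that scoring step, assuming both models expose per-token log-probabilities (this is not the toolkit's decoder):

import numpy as np


def shallow_fusion_scores(asr_log_probs, lm_log_probs, lm_weight=0.3):
    """Score candidates as log p_asr(y|x) + lm_weight * log p_lm(y)."""
    return np.asarray(asr_log_probs) + lm_weight * np.asarray(lm_log_probs)


# Usage sketch: pick the best next token over a 4-word vocabulary (hypothetical values)
asr = np.log([0.5, 0.3, 0.1, 0.1])
lm = np.log([0.1, 0.6, 0.2, 0.1])
best = int(np.argmax(shallow_fusion_scores(asr, lm, lm_weight=0.3)))
print(best)  # 1: the LM changes the ranking relative to the ASR model alone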