Example #1
def main():

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    test_data = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        # data_type='test_dev93',
        data_type='test_eval92',
        data_size=params['data_size'],
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        tool=params['tool'])

    params['num_classes'] = test_data.num_classes

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    print('beam width: %d' % args.beam_width)

    if params['label_type'] == 'word':
        wer_eval92, df_eval92 = eval_word(models=[model],
                                          dataset=test_data,
                                          beam_width=args.beam_width,
                                          max_decode_len=MAX_DECODE_LEN_WORD,
                                          eval_batch_size=args.eval_batch_size,
                                          length_penalty=args.length_penalty,
                                          progressbar=True)
        print('  WER (eval92): %.3f %%' % (wer_eval92 * 100))
        print(df_eval92)
    else:
        wer_eval92, cer_eval92, df_eval92 = eval_char(
            models=[model],
            dataset=test_data,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            eval_batch_size=args.eval_batch_size,
            length_penalty=args.length_penalty,
            progressbar=True)
        print('  WER / CER (eval92): %.3f / %.3f %%' % ((wer_eval92 * 100),
                                                        (cer_eval92 * 100)))
        print(df_eval92)
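These example scripts reference a module-level `parser` that is defined elsewhere in the repository. A minimal sketch of what it might declare, assuming argparse, with argument names inferred from how `args` is used above (not the repository's actual definition):

import argparse

# Hypothetical parser for the evaluation scripts; argument names are inferred from usage.
parser = argparse.ArgumentParser(description='Evaluate a trained ASR model.')
parser.add_argument('--data_save_path', type=str, help='path to the saved dataset')
parser.add_argument('--model_path', type=str, help='path to the saved model directory')
parser.add_argument('--epoch', type=int, default=-1, help='epoch of the checkpoint to restore (-1: last)')
parser.add_argument('--eval_batch_size', type=int, default=1, help='mini-batch size for evaluation')
parser.add_argument('--beam_width', type=int, default=1, help='beam width for decoding')
parser.add_argument('--length_penalty', type=float, default=0.0, help='length penalty for decoding')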
Example #2
def main():

    args = parser.parse_args()

    ##################################################
    # DATASET
    ##################################################
    if args.model_save_path is not None:
        # Load a config file (.yml)
        params = load_config(args.config_path)
    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        params = load_config(os.path.join(args.saved_model_path, 'config.yml'))
    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Load dataset
    train_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type=params['data_size'],
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         min_frame_num=params['min_frame_num'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         tool=params['tool'],
                         num_enque=None,
                         dynamic_batching=params['dynamic_batching'])
    dev93_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='test_dev93',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=params['batch_size'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         shuffle=True,
                         tool=params['tool'])
    eval92_data = Dataset(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_type='test_eval92',
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          label_type_sub=params['label_type_sub'],
                          batch_size=params['batch_size'],
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          tool=params['tool'])

    params['num_classes'] = train_data.num_classes
    params['num_classes_sub'] = train_data.num_classes_sub

    ##################################################
    # MODEL
    ##################################################
    # Model setting
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    if args.model_save_path is not None:

        # Set save path
        save_path = mkdir_join(
            args.model_save_path, params['backend'], params['model_type'],
            params['label_type'] + '_' + params['label_type_sub'],
            params['data_size'], model.name)
        model.set_save_path(save_path)

        # Save config file
        save_config(config_path=args.config_path, save_path=model.save_path)

        # Setting for logging
        logger = set_logger(model.save_path)

        if os.path.isdir(params['char_init']):
            # NOTE: Start training from the pre-trained character model
            model.load_checkpoint(save_path=params['char_init'],
                                  epoch=-1,
                                  load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.3f M parameters" %
                    (model.total_parameters / 1000000))

        # Define optimizer
        model.set_optimizer(optimizer=params['optimizer'],
                            learning_rate_init=float(params['learning_rate']),
                            weight_decay=float(params['weight_decay']),
                            clip_grad_norm=params['clip_grad_norm'],
                            lr_schedule=False,
                            factor=params['decay_rate'],
                            patience_epoch=params['decay_patient_epoch'])

        epoch, step = 1, 0
        learning_rate = float(params['learning_rate'])
        metric_dev_best = 1

    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:

        # Set save path
        model.save_path = args.saved_model_path

        # Setting for logging
        logger = set_logger(model.save_path, restart=True)

        # Define optimizer
        model.set_optimizer(
            optimizer=params['optimizer'],
            learning_rate_init=float(params['learning_rate']),  # on-the-fly
            weight_decay=float(params['weight_decay']),
            clip_grad_norm=params['clip_grad_norm'],
            lr_schedule=False,
            factor=params['decay_rate'],
            patience_epoch=params['decay_patient_epoch'])

        # Restore the last saved model
        epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
            save_path=args.saved_model_path, epoch=-1, restart=True)

    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    train_data.epoch = epoch - 1

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    setproctitle('wsj_' + params['backend'] + '_' + params['model_type'] +
                 '_' + params['label_type'] + '_' + params['label_type_sub'] +
                 '_' + params['data_size'])

    ##################################################
    # TRAINING LOOP
    ##################################################
    # Define learning rate controller
    lr_controller = Controller(
        learning_rate_init=learning_rate,
        backend=params['backend'],
        decay_start_epoch=params['decay_start_epoch'],
        decay_rate=params['decay_rate'],
        decay_patient_epoch=params['decay_patient_epoch'],
        lower_better=True)

    # Setting for tensorboard
    if params['backend'] == 'pytorch':
        tf_writer = SummaryWriter(model.save_path)

    # Train model
    csv_steps, csv_loss_train, csv_loss_dev = [], [], []
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    best_model = model
    loss_train_mean, loss_main_train_mean, loss_sub_train_mean = 0., 0., 0.
    pbar_epoch = tqdm(total=len(train_data))
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_data.next()
        model, loss_train_val, loss_main_train_val, loss_sub_train_val = train_hierarchical_step(
            model,
            batch_train,
            params['clip_grad_norm'],
            backend=params['backend'])
        loss_train_mean += loss_train_val
        loss_main_train_mean += loss_main_train_val
        loss_sub_train_mean += loss_sub_train_val

        pbar_epoch.update(len(batch_train['xs']))

        if (step + 1) % params['print_step'] == 0:

            # Compute loss in the dev set
            batch_dev = dev93_data.next()[0]
            loss_dev, loss_main_dev, loss_sub_dev = model(
                batch_dev['xs'],
                batch_dev['ys'],
                batch_dev['x_lens'],
                batch_dev['y_lens'],
                batch_dev['ys_sub'],
                batch_dev['y_lens_sub'],
                is_eval=True)

            loss_train_mean /= params['print_step']
            loss_main_train_mean /= params['print_step']
            loss_sub_train_mean /= params['print_step']
            csv_steps.append(step)
            csv_loss_train.append(loss_train_mean)
            csv_loss_dev.append(loss_dev)

            # Logging by tensorboard
            if params['backend'] == 'pytorch':
                tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
                tf_writer.add_scalar('train/loss_main', loss_main_train_mean,
                                     step + 1)
                tf_writer.add_scalar('train/loss_sub', loss_sub_train_mean,
                                     step + 1)
                tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
                tf_writer.add_scalar('dev/loss_main', loss_main_dev, step + 1)
                tf_writer.add_scalar('dev/loss_sub', loss_sub_dev, step + 1)
                # for name, param in model.named_parameters():
                #     name = name.replace('.', '/')
                #     tf_writer.add_histogram(
                #         name, param.data.cpu().numpy(), step + 1)
                #     if param.grad is not None:
                #         tf_writer.add_histogram(
                #             name + '/grad', param.grad.data.cpu().numpy(), step + 1)
                #     # TODO: fix this

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(epoch:%.3f) loss:%.3f/%.3f/%.3f(%.3f/%.3f/%.3f)/lr:%.5f/batch:%d/x_lens:%d (%.3f min)"
                % (step + 1, train_data.epoch_detail, loss_train_mean,
                   loss_main_train_mean, loss_sub_train_mean, loss_dev,
                   loss_main_dev, loss_sub_dev, learning_rate,
                   train_data.current_batch_size, max(batch_train['x_lens']) *
                   params['num_stack'], duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, loss_main_train_mean, loss_sub_train_mean = 0., 0., 0.
        step += 1

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.3f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figure of the loss curve
            plot_loss(csv_loss_train,
                      csv_loss_dev,
                      csv_steps,
                      save_path=model.save_path)

            if epoch < params['eval_start_epoch']:
                # Save the model
                model.save_checkpoint(model.save_path, epoch, step,
                                      learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if model.main_loss_weight > 0:
                    metric_dev, _ = eval_word(
                        models=[model],
                        dataset=dev93_data,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_WORD,
                        eval_batch_size=1)
                    logger.info('  WER (dev93, main): %.3f %%' %
                                (metric_dev * 100))
                else:
                    wer_dev_sub, metric_dev, _ = eval_char(
                        models=[model],
                        dataset=dev93_data,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_CHAR,
                        eval_batch_size=1)
                    logger.info('  WER / CER (dev93, sub): %.3f / %.3f %%' %
                                ((wer_dev_sub * 100), (metric_dev * 100)))

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    best_model = copy.deepcopy(model)
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.save_checkpoint(model.save_path, epoch, step,
                                          learning_rate, metric_dev_best)

                    # test
                    if model.main_loss_weight > 0:
                        wer_eval92, _ = eval_word(
                            models=[model],
                            dataset=eval92_data,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_WORD,
                            eval_batch_size=1)
                        logger.info('  WER (eval92, main): %.3f %%' %
                                    (wer_eval92 * 100))
                    else:
                        wer_eval92_sub, cer_eval92_sub, _ = eval_char(
                            models=[model],
                            dataset=eval92_data,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_CHAR,
                            eval_batch_size=1)
                        logger.info(
                            ' WER / CER (eval92, sub): %.3f / %.3f %%' %
                            ((wer_eval92_sub * 100), (cer_eval92_sub * 100)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.3f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == params['not_improved_patient_epoch']:
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=epoch,
                    value=metric_dev)

                if epoch == params['convert_to_sgd_epoch']:
                    # Convert to fine-tuning stage
                    model.set_optimizer(
                        'sgd',
                        learning_rate_init=learning_rate,
                        weight_decay=float(params['weight_decay']),
                        clip_grad_norm=params['clip_grad_norm'],
                        lr_schedule=False,
                        factor=params['decay_rate'],
                        patience_epoch=params['decay_patient_epoch'])
                    logger.info('========== Convert to SGD ==========')

                    # Inject Gaussian noise to all parameters
                    if float(params['weight_noise_std']) > 0:
                        model.weight_noise_injection = True

            pbar_epoch = tqdm(total=len(train_data))
            print('========== EPOCH:%d (%.3f min) ==========' %
                  (epoch, duration_epoch / 60))

            if epoch == params['num_epoch']:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.3f hour' % (duration_train / 3600))

    if params['backend'] == 'pytorch':
        tf_writer.close()

    # TODO: evaluate the best model by beam search here

    # Training was finished correctly
    with open(os.path.join(model.save_path, 'COMPLETE'), 'w') as f:
        f.write('')
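The epoch loop above saves a loss curve with plot_loss(csv_loss_train, csv_loss_dev, csv_steps, save_path=...). A minimal sketch of such a helper, assuming matplotlib (not the repository's actual implementation):

import os
import matplotlib
matplotlib.use('Agg')  # render to file without a display
import matplotlib.pyplot as plt

def plot_loss(loss_train, loss_dev, steps, save_path):
    """Save train/dev loss curves as loss.png under save_path."""
    plt.figure()
    plt.plot(steps, loss_train, label='train')
    plt.plot(steps, loss_dev, label='dev')
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(os.path.join(save_path, 'loss.png'))
    plt.close()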
Example #3
def main():

    a2c_oracle = False
    resolving_unk = False

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Setting for logging
    logger = set_logger(args.model_path)

    for i, data_type in enumerate(
        ['dev_clean', 'dev_other', 'test_clean', 'test_other']):
        # Load dataset
        dataset = Dataset(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_type=data_type,
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          label_type_sub=params['label_type_sub'],
                          batch_size=args.eval_batch_size,
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          sort_utt=False,
                          tool=params['tool'])

        if i == 0:
            params['num_classes'] = dataset.num_classes
            params['num_classes_sub'] = dataset.num_classes_sub

            # Load model
            model = load(model_type=params['model_type'],
                         params=params,
                         backend=params['backend'])

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(save_path=args.model_path,
                                                   epoch=args.epoch)

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width (main): %d' % args.beam_width)
            logger.info('beam width (sub) : %d' % args.beam_width_sub)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('a2c oracle: %s' % str(a2c_oracle))
            logger.info('resolving_unk: %s' % str(resolving_unk))

        wer, df = eval_word(models=[model],
                            dataset=dataset,
                            eval_batch_size=args.eval_batch_size,
                            beam_width=args.beam_width,
                            max_decode_len=MAX_DECODE_LEN_WORD,
                            min_decode_len=MIN_DECODE_LEN_WORD,
                            beam_width_sub=args.beam_width_sub,
                            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                            min_decode_len_sub=MIN_DECODE_LEN_CHAR,
                            length_penalty=args.length_penalty,
                            coverage_penalty=args.coverage_penalty,
                            progressbar=True,
                            resolving_unk=resolving_unk,
                            a2c_oracle=a2c_oracle)
        logger.info('  WER (%s, main): %.3f %%' % (dataset.data_type,
                                                   (wer * 100)))
        logger.info(df)

        wer_sub, cer_sub, df_sub = eval_char(
            models=[model],
            dataset=dataset,
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            min_decode_len=MIN_DECODE_LEN_CHAR,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True)
        logger.info(' WER / CER (%s, sub): %.3f / %.3f %%' %
                    (dataset.data_type, (wer_sub * 100), (cer_sub * 100)))
        logger.info(df_sub)
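Both the training and evaluation scripts obtain their logger from set_logger(...). A minimal sketch of a compatible helper, assuming Python's standard logging module (the repository's real helper may differ):

import logging
import os

def set_logger(save_path, restart=False):
    """Log to the console and to train.log in save_path; append when restarting."""
    logger = logging.getLogger('training')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler = logging.FileHandler(
        os.path.join(save_path, 'train.log'), mode='a' if restart else 'w')
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    return logger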
Example #4
def main():

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    test_data = Dataset(data_save_path=args.data_save_path,
                        backend=params['backend'],
                        input_freq=params['input_freq'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test_eval92',
                        data_size=params['data_size'],
                        label_type=params['label_type'],
                        label_type_sub=params['label_type_sub'],
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=False,
                        tool=params['tool'])

    params['num_classes'] = test_data.num_classes
    params['num_classes_sub'] = test_data.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    a2c_oracle = False
    resolving_unk = True

    print('beam width (main): %d' % args.beam_width)
    print('beam width (sub) : %d' % args.beam_width_sub)
    print('a2c oracle: %s' % str(a2c_oracle))
    print('resolving_unk: %s' % str(resolving_unk))

    wer_eval92, df_eval92 = eval_word(models=[model],
                                      dataset=test_data,
                                      beam_width=args.beam_width,
                                      beam_width_sub=args.beam_width_sub,
                                      max_decode_len=MAX_DECODE_LEN_WORD,
                                      max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                                      eval_batch_size=args.eval_batch_size,
                                      progressbar=True,
                                      resolving_unk=resolving_unk,
                                      a2c_oracle=a2c_oracle)
    print('  WER (eval92, main): %.3f %%' % (wer_eval92 * 100))
    print(df_eval92)
    wer_eval92_sub, cer_eval92_sub, df_eval92_sub = eval_char(
        models=[model],
        dataset=test_data,
        beam_width=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True)
    print(' WER / CER (eval92, sub): %.3f / %.3f %%' %
          ((wer_eval92_sub * 100), (cer_eval92_sub * 100)))
    print(df_eval92_sub)
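Every example loads its hyperparameters with load_config(join(args.model_path, 'config.yml'), is_eval=True). A minimal sketch of such a helper, assuming the config is plain YAML read with PyYAML (the repository's real helper may post-process the values):

import yaml

def load_config(config_path, is_eval=False):
    """Load hyperparameters from a .yml file into a dict."""
    with open(config_path, 'r') as f:
        params = yaml.safe_load(f)
    # is_eval is kept only for API parity with the calls above; the real helper
    # presumably uses it to adjust training-only settings.
    return params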