def main():
    """Train a model, evaluating on the dev / eval1 sets once per epoch.

    Two modes are selected by the command-line arguments:

    * ``args.model_save_path`` is set: start a new run using the config
      file at ``args.config_path``.
    * otherwise ``args.saved_model_path`` is set: resume from the last
      checkpoint stored there (its ``config.yml`` is reused).

    On normal termination (early stopping or ``num_epoch`` reached) an empty
    ``COMPLETE`` file is written into the model's save directory.

    Raises:
        ValueError: if neither ``model_save_path`` nor ``saved_model_path``
            is given.
    """

    args = parser.parse_args()

    ##################################################
    # DATASET
    ##################################################
    if args.model_save_path is not None:
        # Load a config file (.yml)
        params = load_config(args.config_path)
    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        params = load_config(os.path.join(args.saved_model_path, 'config.yml'))
    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Load dataset
    train_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='train',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         min_frame_num=params['min_frame_num'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         tool=params['tool'],
                         num_enque=None,
                         dynamic_batching=params['dynamic_batching'])
    dev_data = Dataset(data_save_path=args.data_save_path,
                       backend=params['backend'],
                       input_freq=params['input_freq'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type='dev',
                       data_size=params['data_size'],
                       label_type=params['label_type'],
                       batch_size=params['batch_size'],
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       shuffle=True,
                       tool=params['tool'])
    eval1_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='eval1',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         batch_size=params['batch_size'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         tool=params['tool'])

    # Both entries are filled from the single (main) label set.
    params['num_classes'] = train_data.num_classes
    params['num_classes_sub'] = train_data.num_classes

    ##################################################
    # MODEL
    ##################################################
    # Model setting
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    if args.model_save_path is not None:

        # Set save path
        save_path = mkdir_join(args.model_save_path, params['backend'],
                               params['model_type'], params['label_type'],
                               params['data_size'], model.name)
        model.set_save_path(save_path)

        # Save config file
        save_config(config_path=args.config_path, save_path=model.save_path)

        # Setting for logging
        logger = set_logger(model.save_path)

        if os.path.isdir(params['char_init']):
            # NOTE: Start training from the pre-trained character model
            model.load_checkpoint(save_path=params['char_init'],
                                  epoch=-1,
                                  load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.3f M parameters" %
                    (model.total_parameters / 1000000))

        # Define optimizer
        model.set_optimizer(optimizer=params['optimizer'],
                            learning_rate_init=float(params['learning_rate']),
                            weight_decay=float(params['weight_decay']),
                            clip_grad_norm=params['clip_grad_norm'],
                            lr_schedule=False,
                            factor=params['decay_rate'],
                            patience_epoch=params['decay_patient_epoch'])

        epoch, step = 1, 0
        learning_rate = float(params['learning_rate'])
        # Initial "best" dev metric; the first evaluated epoch must score
        # below this value to be saved as the best model.
        metric_dev_best = 1

    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:

        # Set save path
        model.save_path = args.saved_model_path

        # Setting for logging
        logger = set_logger(model.save_path, restart=True)

        # Define optimizer
        model.set_optimizer(
            optimizer=params['optimizer'],
            learning_rate_init=float(params['learning_rate']),  # on-the-fly
            weight_decay=float(params['weight_decay']),
            clip_grad_norm=params['clip_grad_norm'],
            lr_schedule=False,
            factor=params['decay_rate'],
            patience_epoch=params['decay_patient_epoch'])

        # Restore the last saved model
        epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
            save_path=args.saved_model_path, epoch=-1, restart=True)

    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # `epoch` is 1-based here, while the dataset counter is set one lower.
    train_data.epoch = epoch - 1

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the node (host) name, despite the log label.
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    setproctitle('csj_' + params['backend'] + '_' + params['model_type'] +
                 '_' + params['label_type'] + '_' + params['data_size'])

    ##################################################
    # TRAINING LOOP
    ##################################################
    # Define learning rate controller
    lr_controller = Controller(
        learning_rate_init=learning_rate,
        backend=params['backend'],
        decay_start_epoch=params['decay_start_epoch'],
        decay_rate=params['decay_rate'],
        decay_patient_epoch=params['decay_patient_epoch'],
        lower_better=True)

    # Setting for tensorboard
    if params['backend'] == 'pytorch':
        tf_writer = SummaryWriter(model.save_path)

    # Train model
    csv_steps, csv_loss_train, csv_loss_dev = [], [], []
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    best_model = model
    loss_train_mean = 0.
    pbar_epoch = tqdm(total=len(train_data))
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_data.next()
        model, loss_train = train_step(model, batch_train,
                                       params['clip_grad_norm'],
                                       params['backend'])
        loss_train_mean += loss_train

        pbar_epoch.update(len(batch_train['xs']))

        if (step + 1) % params['print_step'] == 0:

            # Compute loss in the dev set
            batch_dev = dev_data.next()[0]
            loss_dev = model(batch_dev['xs'],
                             batch_dev['ys'],
                             batch_dev['x_lens'],
                             batch_dev['y_lens'],
                             is_eval=True)

            # Average the training loss accumulated since the last report
            loss_train_mean /= params['print_step']
            csv_steps.append(step)
            csv_loss_train.append(loss_train_mean)
            csv_loss_dev.append(loss_dev)

            # Logging by tensorboard
            if params['backend'] == 'pytorch':
                tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
                tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
                # for name, param in model.named_parameters():
                #     name = name.replace('.', '/')
                #     tf_writer.add_histogram(
                #         name, param.data.cpu().numpy(), step + 1)
                #     tf_writer.add_histogram(
                #         name + '/grad', param.grad.data.cpu().numpy(), step + 1)
                # TODO: fix this

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(epoch:%.3f) loss:%.3f(%.3f)/lr:%.5f/batch:%d/x_lens:%d (%.3f min)"
                % (step + 1, train_data.epoch_detail, loss_train_mean,
                   loss_dev, learning_rate, train_data.current_batch_size,
                   max(batch_train['x_lens']) * params['num_stack'],
                   duration_step / 60))
            start_time_step = time.time()
            loss_train_mean = 0.
        step += 1

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.3f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figure of loss
            plot_loss(csv_loss_train,
                      csv_loss_dev,
                      csv_steps,
                      save_path=model.save_path)

            if epoch < params['eval_start_epoch']:
                # Save the model
                model.save_checkpoint(model.save_path, epoch, step,
                                      learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if params['label_type'] == 'word':
                    metric_dev, _ = eval_word(
                        models=[model],
                        dataset=dev_data,
                        eval_batch_size=1,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_WORD)
                    logger.info('  WER (dev): %.3f %%' % (metric_dev * 100))
                else:
                    wer_dev, metric_dev, _ = eval_char(
                        models=[model],
                        dataset=dev_data,
                        eval_batch_size=1,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_CHAR)
                    logger.info('  WER / CER (dev): %.3f / %.3f %%' %
                                ((wer_dev * 100), (metric_dev * 100)))

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    best_model = copy.deepcopy(model)
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.save_checkpoint(model.save_path, epoch, step,
                                          learning_rate, metric_dev_best)

                    # test
                    if params['label_type'] == 'word':
                        # Word-level decoding uses the word decode length,
                        # consistent with the dev evaluation above.
                        wer_eval1, _ = eval_word(
                            models=[model],
                            dataset=eval1_data,
                            eval_batch_size=1,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_WORD)
                        logger.info('  WER (eval1): %.3f %%' %
                                    (wer_eval1 * 100))
                    else:
                        wer_eval1, cer_eval1, _ = eval_char(
                            models=[model],
                            dataset=eval1_data,
                            eval_batch_size=1,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_CHAR)
                        logger.info('  WER / CER (eval1): %.3f / %.3f %%' %
                                    ((wer_eval1 * 100), (cer_eval1 * 100)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.3f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == params['not_improved_patient_epoch']:
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=epoch,
                    value=metric_dev)

                if epoch == params['convert_to_sgd_epoch']:
                    # Convert to fine-tuning stage
                    model.set_optimizer(
                        'sgd',
                        learning_rate_init=learning_rate,
                        weight_decay=float(params['weight_decay']),
                        clip_grad_norm=params['clip_grad_norm'],
                        lr_schedule=False,
                        factor=params['decay_rate'],
                        patience_epoch=params['decay_patient_epoch'])
                    logger.info('========== Convert to SGD ==========')

                    # Inject Gaussian noise to all parameters
                    if float(params['weight_noise_std']) > 0:
                        model.weight_noise_injection = True

            # Release the finished epoch's progress bar before starting a
            # fresh one for the next epoch.
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_data))
            print('========== EPOCH:%d (%.3f min) ==========' %
                  (epoch, duration_epoch / 60))

            if epoch == params['num_epoch']:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    # Both exits from the loop leave a progress bar open; close it here.
    pbar_epoch.close()

    # TODO: evaluate the best model by beam search here

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.3f hour' % (duration_train / 3600))

    if params['backend'] == 'pytorch':
        tf_writer.close()

    # Training was finished correctly
    with open(os.path.join(model.save_path, 'COMPLETE'), 'w') as f:
        f.write('')
Example #2
0
def main():
    """Evaluate a saved model on the eval sets and record the scores.

    Loads ``config.yml`` from ``args.model_path``, restores the checkpoint
    selected by ``args.epoch``, then decodes eval1, eval2 and eval3 with the
    beam width given on the command line.  Per-set scores and their mean are
    printed and also written to ``<model_path>/RESULTS``.
    """

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    data_types = ['eval1', 'eval2', 'eval3']
    is_word = params['label_type'] == 'word'

    # Running sums over the eval sets; divided by the set count at the end.
    wer_mean, cer_mean = 0, 0
    with open(join(args.model_path, 'RESULTS'), 'w') as f:
        for i, data_type in enumerate(data_types):
            # Load dataset
            eval_data = Dataset(
                data_save_path=args.data_save_path,
                backend=params['backend'],
                input_freq=params['input_freq'],
                use_delta=params['use_delta'],
                use_double_delta=params['use_double_delta'],
                data_type=data_type, data_size=params['data_size'],
                label_type=params['label_type'],
                batch_size=args.eval_batch_size, splice=params['splice'],
                num_stack=params['num_stack'], num_skip=params['num_skip'],
                shuffle=False, tool=params['tool'])

            if i == 0:
                # The model is built once, after the first dataset has
                # provided the number of output classes.
                params['num_classes'] = eval_data.num_classes

                # Load model
                model = load(model_type=params['model_type'],
                             params=params,
                             backend=params['backend'])

                # Restore the saved parameters
                model.load_checkpoint(
                    save_path=args.model_path, epoch=args.epoch)

                # GPU setting
                model.set_cuda(deterministic=False, benchmark=True)

            print('beam width: %d' % args.beam_width)
            f.write('beam width: %d\n' % args.beam_width)

            if is_word:
                wer, df = eval_word(
                    models=[model],
                    dataset=eval_data,
                    eval_batch_size=args.eval_batch_size,
                    beam_width=args.beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    length_penalty=args.length_penalty,
                    progressbar=True)
                wer_mean += wer
                print('  WER (%s): %.3f %%' % (data_type, (wer * 100)))
                # Newline-terminated so RESULTS entries do not run together.
                f.write('  WER (%s): %.3f %%\n' % (data_type, (wer * 100)))
                print(df)
            else:
                wer, cer, df = eval_char(
                    models=[model],
                    dataset=eval_data,
                    eval_batch_size=args.eval_batch_size,
                    beam_width=args.beam_width,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    length_penalty=args.length_penalty,
                    progressbar=True)
                wer_mean += wer
                cer_mean += cer
                print(' WER / CER (%s, sub): %.3f / %.3f %%' %
                      (data_type, (wer * 100), (cer * 100)))
                f.write(' WER / CER (%s, sub): %.3f / %.3f %%\n' %
                        (data_type, (wer * 100), (cer * 100)))
                print(df)

        # Report the mean from the accumulated sums, not from the last
        # eval set's individual score.
        num_sets = len(data_types)
        if is_word:
            print('  WER (mean): %.3f %%' % (wer_mean * 100 / num_sets))
            f.write('  WER (mean): %.3f %%\n' % (wer_mean * 100 / num_sets))
        else:
            print('  WER / CER (mean): %.3f / %.3f %%' %
                  ((wer_mean * 100 / num_sets), (cer_mean * 100 / num_sets)))
            f.write('  WER / CER (mean): %.3f / %.3f %%\n' %
                    ((wer_mean * 100 / num_sets), (cer_mean * 100 / num_sets)))
Example #3
0
def main():
    """Evaluate a saved model with main (word) and sub (char) label sets.

    Reads ``config.yml`` from ``args.model_path``, restores the checkpoint
    selected by ``args.epoch`` and, for each of eval1 / eval2 / eval3, runs
    ``eval_word`` followed by ``eval_char``.  Per-set scores, the result
    tables and the means over the three sets are sent to the logger.
    """
    args = parser.parse_args()

    # Configuration (.yml) stored next to the model
    config_file = join(args.model_path, 'config.yml')
    params = load_config(config_file, is_eval=True)

    # Logging goes to the model directory
    logger = set_logger(args.model_path)

    eval_sets = ('eval1', 'eval2', 'eval3')

    # Running sums of the scores over the eval sets
    total_wer = 0
    total_wer_sub = 0
    total_cer_sub = 0

    for idx, eval_set in enumerate(eval_sets):
        # Build the dataset for this eval set
        dataset_kwargs = dict(
            data_save_path=args.data_save_path,
            backend=params['backend'],
            input_freq=params['input_freq'],
            use_delta=params['use_delta'],
            use_double_delta=params['use_double_delta'],
            data_type=eval_set,
            data_size=params['data_size'],
            label_type=params['label_type'],
            label_type_sub=params['label_type_sub'],
            batch_size=args.eval_batch_size,
            splice=params['splice'],
            num_stack=params['num_stack'],
            num_skip=params['num_skip'],
            shuffle=False,
            tool=params['tool'],
        )
        dataset = Dataset(**dataset_kwargs)

        first_pass = idx == 0
        if first_pass:
            # Class counts come from the first dataset; the model is built
            # and restored only once.
            params['num_classes'] = dataset.num_classes
            params['num_classes_sub'] = dataset.num_classes_sub

            model = load(model_type=params['model_type'],
                         params=params,
                         backend=params['backend'])

            # Restore the saved parameters
            restored_epoch, _, _, _ = model.load_checkpoint(
                save_path=args.model_path,
                epoch=args.epoch)

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            # Record the decoding configuration
            logger.info('beam width (main): %d\n' % args.beam_width)
            logger.info('beam width (sub) : %d\n' % args.beam_width_sub)
            logger.info('epoch: %d' % (restored_epoch - 1))
            logger.info('a2c oracle: %s\n' % str(args.a2c_oracle))
            logger.info('resolving_unk: %s\n' % str(args.resolving_unk))
            logger.info('joint_decoding: %s\n' % str(args.joint_decoding))
            logger.info('score_sub_weight : %f' % args.score_sub_weight)

        # Main (word-level) evaluation
        wer, df = eval_word(
            models=[model],
            dataset=dataset,
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD,
            min_decode_len=MIN_DECODE_LEN_WORD,
            beam_width_sub=args.beam_width_sub,
            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
            min_decode_len_sub=MIN_DECODE_LEN_CHAR,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True,
            resolving_unk=args.resolving_unk,
            a2c_oracle=args.a2c_oracle,
            joint_decoding=args.joint_decoding,
            score_sub_weight=args.score_sub_weight)
        total_wer += wer
        main_msg = '  WER (%s, main): %.3f %%' % (eval_set, (wer * 100))
        logger.info(main_msg)
        logger.info(df)

        # Sub (character-level) evaluation
        wer_sub, cer_sub, df_sub = eval_char(
            models=[model],
            dataset=dataset,
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            min_decode_len=MIN_DECODE_LEN_CHAR,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True)
        total_wer_sub += wer_sub
        total_cer_sub += cer_sub
        sub_msg = ' WER / CER (%s, sub): %.3f / %.3f %%' % (
            eval_set, (wer_sub * 100), (cer_sub * 100))
        logger.info(sub_msg)
        logger.info(df_sub)

    # Means over the eval sets
    num_sets = len(eval_sets)
    logger.info('  WER (mean, main): %.3f %%' % (total_wer * 100 / num_sets))
    logger.info('  WER / CER (mean, sub): %.3f / %.3f %%' %
                ((total_wer_sub * 100 / num_sets),
                 (total_cer_sub * 100 / num_sets)))