def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.getint('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate,
                                     momentum=momentum,
                                     optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={
                                  'clip_gradient': clip_gradient,
                                  'wd': weight_decay
                              },
                              force_init=True)

    reset_optimizer()

    while True:

        if n_epoch >= num_epoch:
            break

        eval_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            if (nbatch + 1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)
        # commented out for the Libri_sample data set, to see only the train CER
        log.info('---------validation---------')
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()
        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
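Every example on this page constructs a SimpleLRScheduler that none of the excerpts define. Below is a minimal sketch of the interface the loops above exercise; the attribute names (dynamic_lr, effective_sample_count, momentum, seq_len) are inferred from the call sites, so treat it as an illustration rather than the real class, which ships alongside these training scripts.

class SimpleLRScheduler(object):
    """Sketch: a learning-rate scheduler steered by attribute assignment."""

    def __init__(self, learning_rate, momentum=None, optimizer=None):
        self.dynamic_lr = learning_rate      # decayed in place by Example #3
        self.momentum = momentum
        self.optimizer = optimizer
        self.effective_sample_count = None   # set per batch by the loops above
        self.seq_len = None                  # set once for TBPTT

    # later examples assign .learning_rate directly; alias it to dynamic_lr
    @property
    def learning_rate(self):
        return self.dynamic_lr

    @learning_rate.setter
    def learning_rate(self, value):
        self.dynamic_lr = value

    def __call__(self, num_update):
        # MXNet optimizers call their lr_scheduler with the update count;
        # scale by the effectively observed sample count when it is known.
        if self.effective_sample_count:
            return self.dynamic_lr / self.effective_sample_count
        return self.dynamic_lr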
Example #2
    decoding_method = args.config.get('train', 'method')
    contexts = parse_contexts(args)

    init_states, test_sets, label_mean_sets = prepare_data(args)
    state_names = [x[0] for x in init_states]

    batch_size = args.config.getint('train', 'batch_size')
    num_hidden = args.config.getint('arch', 'num_hidden')
    num_hidden_proj = args.config.getint('arch', 'num_hidden_proj')
    num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
    feat_dim = args.config.getint('data', 'xdim')
    label_dim = args.config.getint('data', 'ydim')
    out_file = args.config.get('data', 'out_file')
    out_dir = args.config.get('data', 'out_dir')
    num_epoch = args.config.getint('train', 'num_epoch')
    model_name = get_checkpoint_path(args)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)-15s %(message)s')

    # load the model
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_name, num_epoch)

    if decoding_method == METHOD_BUCKETING:
        buckets = args.config.get('train', 'buckets')
        buckets = list(map(int, re.split(r'\W+', buckets)))
        data_test = BucketSentenceIter(test_sets,
                                       buckets,
                                       batch_size,
                                       init_states,
                                       feat_dim=feat_dim,
Example #3
def do_training(training_method, args, module, data_train, data_val):
    from distutils.dir_util import mkpath
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    batch_size = data_train.batch_size
    batch_end_callbacks = [mx.callback.Speedometer(batch_size, 
                                                   args.config.getint('train', 'show_every'))]
    eval_allow_extra = (training_method == METHOD_TBPTT)
    eval_metric = [mx.metric.np(CrossEntropy, allow_extra_outputs=eval_allow_extra),
                   mx.metric.np(Acc_exclude_padding, allow_extra_outputs=eval_allow_extra)]
    eval_metric = mx.metric.create(eval_metric)
    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum, optimizer=optimizer)

    if training_method == METHOD_TBPTT:
        lr_scheduler.seq_len = data_train.truncate_len

    n_epoch = 0
    num_epoch = args.config.getint('train', 'num_epoch')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    decay_factor = args.config.getfloat('train', 'decay_factor')
    decay_bound = args.config.getfloat('train', 'decay_lower_bound')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    if clip_gradient == 0:
        clip_gradient = None

    last_acc = -float("Inf")
    last_params = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        if optimizer in ("sgd", "speechSGD"):
            module.init_optimizer(kvstore='device',
                                  optimizer=args.config.get('train', 'optimizer'),
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'rescale_grad': 1.0,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=True)
        else:
            module.init_optimizer(kvstore='device',
                                  optimizer=args.config.get('train', 'optimizer'),
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'rescale_grad': 1.0,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=True)
    reset_optimizer()

    while True:
        tic = time.time()
        eval_metric.reset()

        for nbatch, data_batch in enumerate(data_train):
            if training_method == METHOD_TBPTT:
                lr_scheduler.effective_sample_count = data_train.batch_size * data_train.truncate_len
                lr_scheduler.momentum = np.power(np.power(momentum, 1.0 / (data_train.batch_size * data_train.truncate_len)), data_batch.effective_sample_count)
            else:
                if data_batch.effective_sample_count is not None:
                    lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            module.update_metric(eval_metric, data_batch.label)

            batch_end_params = mx.model.BatchEndParam(epoch=n_epoch, nbatch=nbatch,
                                                      eval_metric=eval_metric,
                                                      locals=None)
            for callback in batch_end_callbacks:
                callback(batch_end_params)

            if training_method == METHOD_TBPTT:
                # copy over states
                outputs = module.get_outputs()
                # outputs[0] is softmax, 1:end are states
                for i in range(1, len(outputs)):
                    outputs[i].copyto(data_train.init_state_arrays[i-1])

        for name, val in eval_metric.get_name_value():
            logging.info('Epoch[%d] Train-%s=%f', n_epoch, name, val)
        toc = time.time()
        logging.info('Epoch[%d] Time cost=%.3f', n_epoch, toc-tic)

        data_train.reset()

        # test on eval data
        score_with_state_forwarding(module, data_val, eval_metric)

        # test whether we should decay learning rate
        curr_acc = None
        for name, val in eval_metric.get_name_value():
            logging.info("Epoch[%d] Dev-%s=%f", n_epoch, name, val)
            if name == 'CrossEntropy':
                curr_acc = val
        assert curr_acc is not None, 'cannot find CrossEntropy in eval metric'

        if n_epoch > 0 and lr_scheduler.dynamic_lr > decay_bound and curr_acc > last_acc:
            logging.info('Epoch[%d] !!! Dev set performance drops, reverting this epoch',
                         n_epoch)
            logging.info('Epoch[%d] !!! LR decay: %g => %g', n_epoch,
                         lr_scheduler.dynamic_lr, lr_scheduler.dynamic_lr / float(decay_factor))

            lr_scheduler.dynamic_lr /= decay_factor
            # reset the optimizer because its internal state (e.g. momentum)
            # may have blown up, so we want to start fresh
            reset_optimizer()
            module.set_params(*last_params)
        else:
            last_params = module.get_params()
            last_acc = curr_acc
            n_epoch += 1

            # save checkpoints
            mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
                                     module.symbol, *last_params)

        if n_epoch == num_epoch:
            break
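This example calls score_with_state_forwarding, which is defined elsewhere in its source file. A plausible reconstruction is shown below, mirroring the TBPTT state copy inside the training loop above; the exact evaluation details are an assumption.

def score_with_state_forwarding(module, eval_data, eval_metric):
    # hypothetical reconstruction: evaluate while carrying the RNN states
    # across batches, exactly as the TBPTT branch of the training loop does
    eval_data.reset()
    eval_metric.reset()
    for data_batch in eval_data:
        module.forward(data_batch, is_train=False)
        module.update_metric(eval_metric, data_batch.label)
        outputs = module.get_outputs()
        # outputs[0] is the softmax; the remaining outputs are states
        for i in range(1, len(outputs)):
            outputs[i].copyto(eval_data.init_state_arrays[i - 1])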
Example #5
                    kv=kv)
    # if mode is 'predict', predict labels for the input with the loaded model
    elif mode == 'predict':
        # predict through data
        if is_bucketing:
            max_t_count = args.config.getint('arch', 'max_t_count')
            load_optimizer_states = args.config.getboolean(
                'load', 'load_optimizer_states')
            model_file = args.config.get('common', 'model_file')
            model_name = os.path.splitext(model_file)[0]
            model_num_epoch = int(model_name[-4:])

            model_path = 'checkpoints/' + str(model_name[:-5])
            prefix = args.config.get('common', 'prefix')
            if os.path.isabs(prefix):
                model_path = config_util.get_checkpoint_path(args).rsplit(
                    "/", 1)[0] + "/" + str(model_name[:-5])

            model = STTBucketingModule(
                sym_gen=model_loaded,
                default_bucket_key=data_train.default_bucket_key,
                context=contexts)

            model.bind(data_shapes=data_train.provide_data,
                       label_shapes=data_train.provide_label,
                       for_training=True)
            _, arg_params, aux_params = mx.model.load_checkpoint(
                model_path, model_num_epoch)
            model.set_params(arg_params, aux_params, allow_missing=True)
            model_loaded = model
        else:
            model_loaded.bind(for_training=False,
def do_training(training_method, args, module, data_train, data_val):
    from distutils.dir_util import mkpath
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    batch_size = data_train.batch_size
    batch_end_callbacks = [
        mx.callback.Speedometer(batch_size,
                                args.config.getint('train', 'show_every'))
    ]
    eval_allow_extra = (training_method == METHOD_TBPTT)
    eval_metric = [
        mx.metric.np(CrossEntropy, allow_extra_outputs=eval_allow_extra),
        mx.metric.np(Acc_exclude_padding, allow_extra_outputs=eval_allow_extra)
    ]
    eval_metric = mx.metric.create(eval_metric)
    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate,
                                     momentum=momentum,
                                     optimizer=optimizer)

    if training_method == METHOD_TBPTT:
        lr_scheduler.seq_len = data_train.truncate_len

    n_epoch = 0
    num_epoch = args.config.getint('train', 'num_epoch')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    decay_factor = args.config.getfloat('train', 'decay_factor')
    decay_bound = args.config.getfloat('train', 'decay_lower_bound')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    if clip_gradient == 0:
        clip_gradient = None

    last_acc = -float("Inf")
    last_params = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        if optimizer == "sgd" or optimizer == "speechSGD":
            module.init_optimizer(kvstore='local',
                                  optimizer=args.config.get(
                                      'train', 'optimizer'),
                                  optimizer_params={
                                      'lr_scheduler': lr_scheduler,
                                      'momentum': momentum,
                                      'rescale_grad': 1.0,
                                      'clip_gradient': clip_gradient,
                                      'wd': weight_decay
                                  },
                                  force_init=True)
        else:
            module.init_optimizer(kvstore='local',
                                  optimizer=args.config.get(
                                      'train', 'optimizer'),
                                  optimizer_params={
                                      'lr_scheduler': lr_scheduler,
                                      'rescale_grad': 1.0,
                                      'clip_gradient': clip_gradient,
                                      'wd': weight_decay
                                  },
                                  force_init=True)

    reset_optimizer()

    while True:
        tic = time.time()
        eval_metric.reset()

        for nbatch, data_batch in enumerate(data_train):
            if training_method == METHOD_TBPTT:
                lr_scheduler.effective_sample_count = data_train.batch_size * data_train.truncate_len
                lr_scheduler.momentum = np.power(
                    np.power(momentum,
                             1.0 / (data_train.batch_size * data_train.truncate_len)),
                    data_batch.effective_sample_count)
            else:
                if data_batch.effective_sample_count is not None:
                    lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()
            module.update_metric(eval_metric, data_batch.label)

            batch_end_params = mx.model.BatchEndParam(epoch=n_epoch,
                                                      nbatch=nbatch,
                                                      eval_metric=eval_metric,
                                                      locals=None)
            for callback in batch_end_callbacks:
                callback(batch_end_params)

            if training_method == METHOD_TBPTT:
                # copy over states
                outputs = module.get_outputs()
                # outputs[0] is softmax, 1:end are states
                for i in range(1, len(outputs)):
                    outputs[i].copyto(data_train.init_state_arrays[i - 1])

        for name, val in eval_metric.get_name_value():
            logging.info('Epoch[%d] Train-%s=%f', n_epoch, name, val)
        toc = time.time()
        logging.info('Epoch[%d] Time cost=%.3f', n_epoch, toc - tic)

        data_train.reset()

        # test on eval data
        score_with_state_forwarding(module, data_val, eval_metric)

        # test whether we should decay learning rate
        curr_acc = None
        for name, val in eval_metric.get_name_value():
            logging.info("Epoch[%d] Dev-%s=%f", n_epoch, name, val)
            if name == 'CrossEntropy':
                curr_acc = val
        assert curr_acc is not None, 'cannot find CrossEntropy in eval metric'

        if n_epoch > 0 and lr_scheduler.dynamic_lr > decay_bound and curr_acc > last_acc:
            logging.info(
                'Epoch[%d] !!! Dev set performance drops, reverting this epoch',
                n_epoch)
            logging.info('Epoch[%d] !!! LR decay: %g => %g', n_epoch,
                         lr_scheduler.dynamic_lr,
                         lr_scheduler.dynamic_lr / float(decay_factor))

            lr_scheduler.dynamic_lr /= decay_factor
            # reset the optimizer because its internal state (e.g. momentum)
            # may have blown up, so we want to start fresh
            reset_optimizer()
            module.set_params(*last_params)
        else:
            last_params = module.get_params()
            last_acc = curr_acc
            n_epoch += 1

            # save checkpoints
            mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
                                     module.symbol, *last_params)

        if n_epoch == num_epoch:
            break
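Both TBPTT variants wrap two plain numpy functions, CrossEntropy and Acc_exclude_padding, with mx.metric.np. The originals live in the same source files as the excerpts; the stand-ins below are only consistent with the names and call sites, and treating label 0 as padding is an assumption.

import numpy as np

def CrossEntropy(labels, preds):
    # mean negative log-likelihood over non-padding positions;
    # preds: (N, num_classes) softmax rows, labels: N integer ids
    labels = labels.reshape(-1).astype('int32')
    mask = labels > 0                        # assumption: 0 marks padding
    picked = preds[np.arange(labels.size), labels]
    return float(-np.log(np.maximum(picked[mask], 1e-20)).sum()) / max(mask.sum(), 1)

def Acc_exclude_padding(labels, preds):
    # classification accuracy over non-padding positions
    labels = labels.reshape(-1).astype('int32')
    mask = labels > 0
    hits = (preds.argmax(axis=1) == labels) & mask
    return float(hits.sum()) / max(mask.sum(), 1)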
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented out for the Libri_sample data set, to see only the train CER
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forward with is_train=True: is_train=False leads to a high CER when batch_norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'validation CER is missing from eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate /= learning_rate_annealing  # anneal once per epoch

    log.info('FINISH')
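STTMetric, used throughout, is a custom metric from the host project: note that its get_name_value here returns a (cer, n_label, l_dist, ctc_loss) tuple rather than MXNet's usual name/value pairs. A compact standalone stand-in that reproduces the CER bookkeeping is sketched below; the greedy CTC decode, the assumed tensor shapes, and blank index 0 are all assumptions, and the CTC loss accumulation is elided.

import numpy as np

def levenshtein(a, b):
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1,
                           prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

class STTMetric(object):
    def __init__(self, batch_size, num_gpu, seq_length=None,
                 is_logging=True, is_epoch_end=False):
        self.batch_size, self.num_gpu = batch_size, num_gpu
        self.is_logging, self.is_epoch_end = is_logging, is_epoch_end
        self.reset()

    def reset(self):
        self.total_n_label = 0
        self.total_l_dist = 0
        self.total_ctc_loss = 0.0            # loss accumulation elided here

    def set_audio_paths(self, paths):
        self.audio_paths = paths             # used only for logging

    def update(self, labels, preds):
        # assumed shapes: label (batch, max_label_len) padded with 0,
        # pred (batch, seq_len, num_classes) softmax with blank at index 0
        for label, pred in zip(labels, preds):
            for y, p in zip(label.asnumpy(), pred.asnumpy()):
                ref = [int(c) for c in y if int(c) > 0]
                hyp, last = [], -1
                for c in p.argmax(axis=-1):   # greedy CTC decode
                    if c != last and c != 0:  # collapse repeats, drop blanks
                        hyp.append(int(c))
                    last = c
                self.total_l_dist += levenshtein(hyp, ref)
                self.total_n_label += len(ref)

    def get_name_value(self):
        cer = float(self.total_l_dist) / max(self.total_n_label, 1)
        return cer, self.total_n_label, self.total_l_dist, self.total_ctc_loss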
Example #8
    args.config.write(sys.stderr)

    decoding_method = args.config.get('train', 'method')
    contexts = parse_contexts(args)

    init_states, test_sets = prepare_data(args)
    state_names = [x[0] for x in init_states]

    batch_size = args.config.getint('train', 'batch_size')
    num_hidden = args.config.getint('arch', 'num_hidden')
    num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')
    feat_dim = args.config.getint('data', 'xdim')
    label_dim = args.config.getint('data', 'ydim')
    out_file = args.config.get('data', 'out_file')
    num_epoch = args.config.getint('train', 'num_epoch')
    model_name = get_checkpoint_path(args)
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

    # accumulate per-class label counts to compute the label mean
    label_mean = np.zeros((label_dim, 1), dtype='float32')
    data_test = TruncatedSentenceIter(test_sets, batch_size, init_states,
                                      20, feat_dim=feat_dim,
                                      do_shuffling=False, pad_zeros=True,
                                      has_label=True)

    for i, batch in enumerate(data_test.labels):
        hist, edges = np.histogram(batch.flat, bins=range(0, label_dim + 1))
        label_mean += hist.reshape(label_dim, 1)

    kaldiWriter = KaldiWriteOut(None, out_file)
    kaldiWriter.open_or_fd()
    kaldiWriter.write("label_mean", label_mean)
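For clarity, here is a self-contained rerun of the histogram accumulation above on toy labels. The excerpt stops before any normalization; dividing the counts by their sum (last line) would turn them into a label prior, which is presumably what the label mean is used for.

import numpy as np

label_dim = 5
label_mean = np.zeros((label_dim, 1), dtype='float32')
toy_batches = [np.array([[0, 1, 1], [2, 4, 4]]),
               np.array([[3, 3, 0], [1, 2, 4]])]
for batch in toy_batches:
    hist, _ = np.histogram(batch.flat, bins=range(0, label_dim + 1))
    label_mean += hist.reshape(label_dim, 1)
print(label_mean.ravel())                       # per-class counts: [2. 3. 2. 2. 3.]
print((label_mean / label_mean.sum()).ravel())  # normalized label prior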
Example #9
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil.getInstance().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                                       save_optimizer_states=save_optimizer_states)
        # commented out for the Libri_sample data set, to see only the train CER
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forward with is_train=True: is_train=False leads to a high CER when batch_norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'validation CER is missing from eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate /= learning_rate_annealing  # anneal once per epoch

    log.info('FINISH')
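The SummaryWriter used by the mxboard variants above comes from the mxboard package and writes TensorBoard event files. A minimal, self-contained wiring with the same call shape as in the examples; the log directory and the toy CER values are arbitrary.

from mxboard import SummaryWriter   # pip install mxboard

sw = SummaryWriter(logdir='./mxboard_logs')
for epoch, cer in enumerate([0.52, 0.41, 0.33]):   # toy CER curve
    sw.add_scalar('CER validation', cer, epoch)
sw.close()
# inspect with: tensorboard --logdir ./mxboard_logs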
Example #10
def do_training(args, module, data_train, data_val, begin_epoch=0, kv=None):
    from distutils.dir_util import mkpath

    host_name = socket.gethostname()
    log = LogUtil().getlogger()
    mkpath(os.path.dirname(config_util.get_checkpoint_path(args)))

    # seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    val_batch_size = args.config.getint('common', 'val_batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = config_util.parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=val_batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_start = args.config.getfloat('train', 'learning_rate_start')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')
    lr_factor = args.config.getfloat('train', 'lr_factor')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # kv = mx.kv.create(kvstore_option)
    # data = mx.io.ImageRecordIter(num_parts=kv.num_workers, part_index=kv.rank)
    # # a.set_optimizer(optimizer)
    # updater = mx.optimizer.get_updater(optimizer)
    # a._set_updater(updater=updater)

    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        prefix = args.config.get('common', 'prefix')
        if os.path.isabs(prefix):
            model_path = config_util.get_checkpoint_path(args).rsplit(
                "/", 1)[0] + "/" + str(model_name[:-5])
        # symbol, data_names, label_names = module(1600)
        model = mx.mod.BucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)

        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # model.set_params(arg_params2, aux_params, allow_missing=True, allow_extra=True)

        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))
        # model_file = args.config.get('common', 'model_file')
        # model_name = os.path.splitext(model_file)[0]
        # model_num_epoch = int(model_name[-4:])
        # model_path = 'checkpoints/' + str(model_name[:-5])
        # _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # module.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    # lr, lr_scheduler = _get_lr_scheduler(args, kv)

    def reset_optimizer(force_init=False):
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kv,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
        data_train.reset()
        data_train.is_first_epoch = True

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    learning_rate_pre = 0
    while True:

        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info(host_name + '---------train---------')

        step_epochs = [
            int(l)
            for l in args.config.get('train', 'lr_step_epochs').split(',')
        ]
        # warm up to step_epochs[0] if step_epochs[0] > 0
        if n_epoch < step_epochs[0]:
            learning_rate_cur = learning_rate_start + n_epoch * (
                learning_rate - learning_rate_start) / step_epochs[0]
        else:
            # scaling lr every epoch
            if len(step_epochs) == 1:
                learning_rate_cur = learning_rate
                for s in range(n_epoch):
                    learning_rate_cur /= learning_rate_annealing
            # scaling lr by step_epochs[1:]
            else:
                learning_rate_cur = learning_rate
                for s in step_epochs[1:]:
                    if n_epoch > s:
                        learning_rate_cur *= lr_factor

        if learning_rate_pre and args.config.getboolean(
                'train', 'momentum_correction'):
            lr_scheduler.learning_rate = learning_rate_cur * learning_rate_cur / learning_rate_pre
        else:
            lr_scheduler.learning_rate = learning_rate_cur
        learning_rate_pre = learning_rate_cur
        log.info("n_epoch %d's lr is %.7f" %
                 (n_epoch, lr_scheduler.learning_rate))
        summary_writer.add_scalar('lr', lr_scheduler.learning_rate, n_epoch)
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                # loss_metric.set_audio_paths(data_batch.index)
                module.update_metric(loss_metric, data_batch.label)
                # print("loss=========== %.2f" % loss_metric.get_batch_loss())
            # summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                save_checkpoint(
                    module,
                    prefix=config_util.get_checkpoint_path(args) + "n_epoch" +
                    str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)
        # commented out for the Libri_sample data set, to see only the train CER
        log.info(host_name + '---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forward with is_train=True: is_train=False leads to a high CER when batch_norm is used
            module.forward(data_batch, is_train=True)
            eval_metric.set_audio_paths(data_batch.index)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, val_ctc_loss = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d), ctc_loss=%f", n_epoch,
                 val_cer, int(val_n_label - val_l_dist), val_n_label,
                 val_ctc_loss)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        summary_writer.add_scalar('loss validation', val_ctc_loss, n_epoch)
        assert curr_acc is not None, 'validation CER is missing from eval metric'
        np.random.seed(n_epoch)
        data_train.reset()
        data_train.is_first_epoch = False

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            save_checkpoint(module,
                            prefix=config_util.get_checkpoint_path(args),
                            epoch=n_epoch,
                            save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
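Example #10 computes its learning rate inline each epoch: linear warm-up until step_epochs[0], then either per-epoch annealing (single step entry) or step decay at the remaining boundaries, with an optional momentum correction (lr_cur**2 / lr_pre) applied in the loop. The same schedule as a pure function, convenient for checking the logic in isolation; names mirror the config keys above.

def scheduled_lr(n_epoch, step_epochs, lr_start, lr, annealing, lr_factor):
    # linear warm-up from lr_start to lr over the first step_epochs[0] epochs
    if n_epoch < step_epochs[0]:
        return lr_start + n_epoch * (lr - lr_start) / step_epochs[0]
    # a single entry means: divide by the annealing factor every epoch
    if len(step_epochs) == 1:
        return lr / (annealing ** n_epoch)
    # otherwise scale by lr_factor once per boundary already passed
    cur = lr
    for s in step_epochs[1:]:
        if n_epoch > s:
            cur *= lr_factor
    return cur

# e.g. warm up over 2 epochs, then halve after epochs 5 and 8:
# [scheduled_lr(e, [2, 5, 8], 1e-4, 1e-3, 1.1, 0.5) for e in range(10)]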
Example #11
def do_training(args, module, data_train, data_val, begin_epoch=0):
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    seq_len = args.config.getint('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len, is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len, is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    n_epoch = begin_epoch

    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))


    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        if optimizer == "sgd":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        elif optimizer == "adam":
            module.init_optimizer(kvstore='device',
                                  optimizer=optimizer,
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    #'momentum': momentum,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=force_init)
        else:
            raise Exception('Supported optimizers are sgd and adam. To use others, define them in train.py.')
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)

    #tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)
    while True:

        if n_epoch >= num_epoch:
            break

        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):

            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                                       save_optimizer_states=save_optimizer_states)
        # commented out for the Libri_sample data set, to see only the train CER
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # forward with is_train=True: is_train=False leads to a high CER when batch_norm is used
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer, int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'validation CER is missing from eval metric'

        data_train.reset()

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch, save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        lr_scheduler.learning_rate /= learning_rate_annealing  # anneal once per epoch

    log.info('FINISH')