def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Run the epoch loop for a speech-to-text module, checkpointing periodically.

    Parameters
    ----------
    args : parsed arguments exposing a ConfigParser as ``args.config``
    module : MXNet module to train; bound (and, for a fresh run, initialized) here
    data_train : training data iterator (supplies data/label shapes)
    data_val : validation data iterator
    begin_epoch : int, epoch to resume from; parameters are only freshly
        initialized when this is 0.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    # Ensure the checkpoint directory exists before any save.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    # NOTE(review): 'max_t_count' is read with get() and therefore passed to
    # STTMetric as a string; if seq_length must be an int, this should be
    # getint — confirm against STTMetric.
    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            seq_length=seq_len)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum,
                                     optimizer=optimizer)

    n_epoch = begin_epoch
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')

    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        # force_init so stale optimizer state never survives a reset.
        module.init_optimizer(kvstore='device',
                              optimizer=args.config.get('train', 'optimizer'),
                              optimizer_params={'clip_gradient': clip_gradient,
                                                'wd': weight_decay},
                              force_init=True)

    reset_optimizer()

    while True:
        if n_epoch >= num_epoch:
            break
        eval_metric.reset()

        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            if data_batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = \
                    data_batch.effective_sample_count
            module.forward_backward(data_batch)
            module.update()
            # Only evaluate the metric every `show_every` batches (it is costly).
            if (nbatch + 1) % show_every == 0:
                module.update_metric(eval_metric, data_batch.label)

        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        # BUGFIX: rewind the validation iterator every epoch. Without this a
        # standard MXNet iterator stays exhausted after the first epoch and
        # the loop below silently does nothing (later versions of this
        # function in the same project call data_val.reset() here too).
        data_val.reset()
        for nbatch, data_batch in enumerate(data_val):
            module.update_metric(eval_metric, data_batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

    log.info('FINISH')
decoding_method = args.config.get('train', 'method') contexts = parse_contexts(args) init_states, test_sets, label_mean_sets = prepare_data(args) state_names = [x[0] for x in init_states] batch_size = args.config.getint('train', 'batch_size') num_hidden = args.config.getint('arch', 'num_hidden') num_hidden_proj = args.config.getint('arch', 'num_hidden_proj') num_lstm_layer = args.config.getint('arch', 'num_lstm_layer') feat_dim = args.config.getint('data', 'xdim') label_dim = args.config.getint('data', 'ydim') out_file = args.config.get('data', 'out_file') out_dir = args.config.get('data', 'out_dir') num_epoch = args.config.getint('train', 'num_epoch') model_name = get_checkpoint_path(args) logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s') # load the model sym, arg_params, aux_params = mx.model.load_checkpoint( model_name, num_epoch) if decoding_method == METHOD_BUCKETING: buckets = args.config.get('train', 'buckets') buckets = list(map(int, re.split(r'\W+', buckets))) data_test = BucketSentenceIter(test_sets, buckets, batch_size, init_states, feat_dim=feat_dim,
def do_training(training_method, args, module, data_train, data_val):
    """Train `module`, decaying the learning rate when dev cross-entropy worsens.

    Parameters
    ----------
    training_method : one of the METHOD_* constants; METHOD_TBPTT enables
        truncated-BPTT state forwarding and per-batch momentum adjustment
    args : parsed arguments exposing a ConfigParser as ``args.config``
    module : MXNet module to train; bound and initialized here
    data_train : training iterator (exposes batch_size, truncate_len,
        init_state_arrays for TBPTT)
    data_val : validation iterator, scored via score_with_state_forwarding
    """
    from distutils.dir_util import mkpath
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    batch_size = data_train.batch_size
    batch_end_callbacks = [mx.callback.Speedometer(
        batch_size, args.config.getint('train', 'show_every'))]

    # TBPTT emits extra state outputs alongside the softmax; allow them.
    eval_allow_extra = True if training_method == METHOD_TBPTT else False
    eval_metric = [mx.metric.np(CrossEntropy,
                                allow_extra_outputs=eval_allow_extra),
                   mx.metric.np(Acc_exclude_padding,
                                allow_extra_outputs=eval_allow_extra)]
    eval_metric = mx.metric.create(eval_metric)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum,
                                     optimizer=optimizer)
    if training_method == METHOD_TBPTT:
        lr_scheduler.seq_len = data_train.truncate_len

    n_epoch = 0
    num_epoch = args.config.getint('train', 'num_epoch')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    decay_factor = args.config.getfloat('train', 'decay_factor')
    decay_bound = args.config.getfloat('train', 'decay_lower_bound')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    last_acc = -float("Inf")
    last_params = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        # Momentum-style optimizers take the momentum/rescale params; other
        # optimizers would reject them, hence the branch.
        if optimizer == "sgd" or optimizer == "speechSGD":
            module.init_optimizer(kvstore='device',
                                  optimizer=args.config.get('train', 'optimizer'),
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'momentum': momentum,
                                                    'rescale_grad': 1.0,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=True)
        else:
            module.init_optimizer(kvstore='device',
                                  optimizer=args.config.get('train', 'optimizer'),
                                  optimizer_params={'lr_scheduler': lr_scheduler,
                                                    'rescale_grad': 1.0,
                                                    'clip_gradient': clip_gradient,
                                                    'wd': weight_decay},
                                  force_init=True)

    reset_optimizer()

    while True:
        tic = time.time()
        eval_metric.reset()

        for nbatch, data_batch in enumerate(data_train):
            if training_method == METHOD_TBPTT:
                # BUGFIX: the original referenced a bare `truncate_len`,
                # which is not defined in this scope (NameError); the
                # iterator's own truncate_len is the value already used to
                # seed lr_scheduler.seq_len above.
                lr_scheduler.effective_sample_count = \
                    data_train.batch_size * data_train.truncate_len
                lr_scheduler.momentum = np.power(
                    np.power(momentum,
                             1.0 / (data_train.batch_size * data_train.truncate_len)),
                    data_batch.effective_sample_count)
            else:
                if data_batch.effective_sample_count is not None:
                    # NOTE(review): hard-coded 1 overrides the batch's real
                    # effective sample count (the original kept the real
                    # value commented out) — looks like a debugging hack;
                    # confirm before removing.
                    lr_scheduler.effective_sample_count = 1  # data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()

            module.update_metric(eval_metric, data_batch.label)
            batch_end_params = mx.model.BatchEndParam(epoch=n_epoch,
                                                      nbatch=nbatch,
                                                      eval_metric=eval_metric,
                                                      locals=None)
            for callback in batch_end_callbacks:
                callback(batch_end_params)

            if training_method == METHOD_TBPTT:
                # copy over states
                outputs = module.get_outputs()
                # outputs[0] is softmax, 1:end are states
                for i in range(1, len(outputs)):
                    outputs[i].copyto(data_train.init_state_arrays[i - 1])

        for name, val in eval_metric.get_name_value():
            logging.info('Epoch[%d] Train-%s=%f', n_epoch, name, val)
        toc = time.time()
        logging.info('Epoch[%d] Time cost=%.3f', n_epoch, toc - tic)

        data_train.reset()

        # test on eval data
        score_with_state_forwarding(module, data_val, eval_metric)

        # test whether we should decay learning rate
        curr_acc = None
        for name, val in eval_metric.get_name_value():
            logging.info("Epoch[%d] Dev-%s=%f", n_epoch, name, val)
            if name == 'CrossEntropy':
                curr_acc = val
        # BUGFIX: the message previously named Acc_exclude_padding although
        # the value looked up above is CrossEntropy.
        assert curr_acc is not None, 'cannot find CrossEntropy in eval metric'

        # curr_acc is a cross-entropy, so an increase means the dev score
        # got worse: revert parameters and decay the LR.
        if n_epoch > 0 and lr_scheduler.dynamic_lr > decay_bound and curr_acc > last_acc:
            logging.info('Epoch[%d] !!! Dev set performance drops, reverting this epoch',
                         n_epoch)
            logging.info('Epoch[%d] !!! LR decay: %g => %g', n_epoch,
                         lr_scheduler.dynamic_lr,
                         lr_scheduler.dynamic_lr / float(decay_factor))
            lr_scheduler.dynamic_lr /= decay_factor
            # we reset the optimizer because the internal states (e.g. momentum)
            # might already be exploded, so we want to start from fresh
            reset_optimizer()
            module.set_params(*last_params)
        else:
            last_params = module.get_params()
            last_acc = curr_acc
        n_epoch += 1

        # save checkpoints
        mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
                                 module.symbol, *last_params)

        if n_epoch == num_epoch:
            break
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Bind, initialize, and train `module`, saving checkpoints periodically.

    Training resumes at `begin_epoch`; parameters are only (re)initialized
    when starting from epoch 0. All hyper-parameters are read from
    ``args.config``.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    logger = LogUtil().getlogger()
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    cfg = args.config
    max_seq_len = cfg.get('arch', 'max_t_count')
    batch_size = cfg.getint('common', 'batch_size')
    checkpoint_period = cfg.getint('common', 'save_checkpoint_every_n_epoch')

    ctx_list = parse_contexts(args)
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=len(ctx_list),
                            seq_length=max_seq_len)

    lr_scheduler = SimpleLRScheduler(cfg.getfloat('train', 'learning_rate'),
                                     momentum=cfg.getfloat('train', 'momentum'),
                                     optimizer=cfg.get('train', 'optimizer'))

    num_epoch = cfg.getint('train', 'num_epoch')
    clip_gradient = cfg.getfloat('train', 'clip_gradient')
    weight_decay = cfg.getfloat('train', 'weight_decay')
    save_optimizer_states = cfg.getboolean('train', 'save_optimizer_states')
    show_every = cfg.getint('train', 'show_every')
    if clip_gradient == 0:
        clip_gradient = None  # 0 in the config means "do not clip"

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    if begin_epoch == 0:
        module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        # Re-read the optimizer name from the config, matching how the
        # scheduler above was configured; force_init discards stale state.
        module.init_optimizer(kvstore='device',
                              optimizer=cfg.get('train', 'optimizer'),
                              optimizer_params={'clip_gradient': clip_gradient,
                                                'wd': weight_decay},
                              force_init=True)

    reset_optimizer()

    epoch = begin_epoch
    while epoch < num_epoch:
        eval_metric.reset()

        logger.info('---------train---------')
        for batch_idx, batch in enumerate(data_train):
            if batch.effective_sample_count is not None:
                lr_scheduler.effective_sample_count = batch.effective_sample_count
            module.forward_backward(batch)
            module.update()
            # Metric update is throttled to every `show_every` batches.
            if (batch_idx + 1) % show_every == 0:
                module.update_metric(eval_metric, batch.label)

        # commented for Libri_sample data set to see only train cer
        logger.info('---------validation---------')
        for batch in data_val:
            module.update_metric(eval_metric, batch.label)
        #module.score(eval_data=data_val, num_batch=None, eval_metric=eval_metric, reset=True)

        data_train.reset()

        # save checkpoints
        if epoch % checkpoint_period == 0:
            logger.info('Epoch[%d] SAVE CHECKPOINT', epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=epoch,
                                   save_optimizer_states=save_optimizer_states)
        epoch += 1

    logger.info('FINISH')
kv=kv) # if mode is 'predict', it predict label from the input by the input model elif mode == 'predict': # predict through data if is_bucketing: max_t_count = args.config.getint('arch', 'max_t_count') load_optimizer_states = args.config.getboolean( 'load', 'load_optimizer_states') model_file = args.config.get('common', 'model_file') model_name = os.path.splitext(model_file)[0] model_num_epoch = int(model_name[-4:]) model_path = 'checkpoints/' + str(model_name[:-5]) prefix = args.config.get('common', 'prefix') if os.path.isabs(prefix): model_path = config_util.get_checkpoint_path(args).rsplit( "/", 1)[0] + "/" + str(model_name[:-5]) model = STTBucketingModule( sym_gen=model_loaded, default_bucket_key=data_train.default_bucket_key, context=contexts) model.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label, for_training=True) _, arg_params, aux_params = mx.model.load_checkpoint( model_path, model_num_epoch) model.set_params(arg_params, aux_params, allow_missing=True) model_loaded = model else: model_loaded.bind(for_training=False,
def do_training(training_method, args, module, data_train, data_val):
    """Train `module`, reverting the epoch and decaying the LR when the dev
    cross-entropy gets worse.

    Parameters
    ----------
    training_method : one of the METHOD_* constants; METHOD_TBPTT enables
        truncated-BPTT state forwarding between batches
    args : parsed arguments exposing a ConfigParser as ``args.config``
    module : MXNet module to train; bound and initialized here
    data_train : training iterator (exposes batch_size, truncate_len,
        init_state_arrays for TBPTT)
    data_val : validation iterator, scored via score_with_state_forwarding
    """
    from distutils.dir_util import mkpath
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    batch_size = data_train.batch_size
    batch_end_callbacks = [
        mx.callback.Speedometer(batch_size,
                                args.config.getint('train', 'show_every'))
    ]

    # TBPTT emits extra state outputs alongside the softmax; allow them.
    eval_allow_extra = True if training_method == METHOD_TBPTT else False
    eval_metric = [
        mx.metric.np(CrossEntropy, allow_extra_outputs=eval_allow_extra),
        mx.metric.np(Acc_exclude_padding, allow_extra_outputs=eval_allow_extra)
    ]
    eval_metric = mx.metric.create(eval_metric)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    lr_scheduler = SimpleLRScheduler(learning_rate, momentum=momentum,
                                     optimizer=optimizer)
    if training_method == METHOD_TBPTT:
        lr_scheduler.seq_len = data_train.truncate_len

    n_epoch = 0
    num_epoch = args.config.getint('train', 'num_epoch')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    decay_factor = args.config.getfloat('train', 'decay_factor')
    decay_bound = args.config.getfloat('train', 'decay_lower_bound')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    last_acc = -float("Inf")
    last_params = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)
    module.init_params(initializer=get_initializer(args))

    def reset_optimizer():
        # Momentum-style optimizers take the momentum/rescale params; other
        # optimizers would reject them, hence the branch.
        if optimizer == "sgd" or optimizer == "speechSGD":
            module.init_optimizer(kvstore='local',
                                  optimizer=args.config.get(
                                      'train', 'optimizer'),
                                  optimizer_params={
                                      'lr_scheduler': lr_scheduler,
                                      'momentum': momentum,
                                      'rescale_grad': 1.0,
                                      'clip_gradient': clip_gradient,
                                      'wd': weight_decay
                                  },
                                  force_init=True)
        else:
            module.init_optimizer(kvstore='local',
                                  optimizer=args.config.get(
                                      'train', 'optimizer'),
                                  optimizer_params={
                                      'lr_scheduler': lr_scheduler,
                                      'rescale_grad': 1.0,
                                      'clip_gradient': clip_gradient,
                                      'wd': weight_decay
                                  },
                                  force_init=True)

    reset_optimizer()

    while True:
        tic = time.time()
        eval_metric.reset()

        for nbatch, data_batch in enumerate(data_train):
            if training_method == METHOD_TBPTT:
                # NOTE(review): `truncate_len` is a bare name not defined in
                # this function; unless it exists as a module-level global,
                # both uses below raise NameError. The iterator's own
                # data_train.truncate_len (used above) is the likely intent
                # — confirm.
                lr_scheduler.effective_sample_count = data_train.batch_size * truncate_len
                lr_scheduler.momentum = np.power(
                    np.power(momentum,
                             1.0 / (data_train.batch_size * truncate_len)),
                    data_batch.effective_sample_count)
            else:
                if data_batch.effective_sample_count is not None:
                    lr_scheduler.effective_sample_count = data_batch.effective_sample_count

            module.forward_backward(data_batch)
            module.update()

            module.update_metric(eval_metric, data_batch.label)
            batch_end_params = mx.model.BatchEndParam(epoch=n_epoch,
                                                      nbatch=nbatch,
                                                      eval_metric=eval_metric,
                                                      locals=None)
            for callback in batch_end_callbacks:
                callback(batch_end_params)

            if training_method == METHOD_TBPTT:
                # copy over states
                outputs = module.get_outputs()
                # outputs[0] is softmax, 1:end are states
                for i in range(1, len(outputs)):
                    outputs[i].copyto(data_train.init_state_arrays[i - 1])

        for name, val in eval_metric.get_name_value():
            logging.info('Epoch[%d] Train-%s=%f', n_epoch, name, val)
        toc = time.time()
        logging.info('Epoch[%d] Time cost=%.3f', n_epoch, toc - tic)

        data_train.reset()

        # test on eval data
        score_with_state_forwarding(module, data_val, eval_metric)

        # test whether we should decay learning rate
        curr_acc = None
        for name, val in eval_metric.get_name_value():
            logging.info("Epoch[%d] Dev-%s=%f", n_epoch, name, val)
            if name == 'CrossEntropy':
                curr_acc = val
        # NOTE(review): message names Acc_exclude_padding but the lookup
        # above is for 'CrossEntropy'.
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        # curr_acc is a cross-entropy, so an increase means the dev score
        # got worse: revert parameters and decay the LR.
        if n_epoch > 0 and lr_scheduler.dynamic_lr > decay_bound and curr_acc > last_acc:
            logging.info(
                'Epoch[%d] !!! Dev set performance drops, reverting this epoch',
                n_epoch)
            logging.info('Epoch[%d] !!! LR decay: %g => %g', n_epoch,
                         lr_scheduler.dynamic_lr,
                         lr_scheduler.dynamic_lr / float(decay_factor))
            lr_scheduler.dynamic_lr /= decay_factor
            # we reset the optimizer because the internal states (e.g. momentum)
            # might already be exploded, so we want to start from fresh
            reset_optimizer()
            module.set_params(*last_params)
        else:
            last_params = module.get_params()
            last_acc = curr_acc
        n_epoch += 1

        # save checkpoints
        mx.model.save_checkpoint(get_checkpoint_path(args), n_epoch,
                                 module.symbol, *last_params)

        if n_epoch == num_epoch:
            break
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train a CTC speech model with mxboard logging, supporting either a
    fresh 'train' run or resuming a bucketing model in 'load' mode.

    Parameters
    ----------
    args : parsed arguments exposing a ConfigParser as ``args.config``
    module : either a bound-able MXNet module ('train' mode) or, in
        bucketing 'load' mode, a sym_gen callable used to rebuild an
        STTBucketingModule from a saved checkpoint
    data_train : training iterator (also carries is_first_epoch state)
    data_val : validation iterator
    begin_epoch : int, epoch to resume from; parameters are only
        initialized when this is 0 and mode is 'train'.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil().getlogger()
    # Ensure the checkpoint directory exists before any save.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # Separate metrics: eval_metric accumulates per-epoch validation CER,
    # loss_metric tracks the training loss/CER within an epoch.
    eval_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size,
                            num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    # Extra optimizer kwargs come in as a JSON dict from the config.
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    if is_bucketing and mode == 'load':
        # Rebuild a bucketing module and restore weights from the checkpoint
        # encoded in the model file name (last 4 chars = epoch number).
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        # NOTE(review): the returned symbol/data_names/label_names are never
        # used; the call presumably exercises sym_gen at bucket 1600 — confirm
        # whether it can be dropped.
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        # Base params plus any extra kwargs supplied via the config JSON.
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    # In 'load' mode keep any restored optimizer state instead of forcing
    # a re-initialization.
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
    data_train.reset()
    data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:
        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                    epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)

        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        # NOTE(review): curr_acc is assigned just above, so this assert can
        # never fire — looks vestigial from an older metric lookup.
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value(
        )
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # NOTE(review): `learning_rate` is a constant local, so this assigns
        # the same value (base LR / annealing) every epoch — the LR anneals
        # only once, not per epoch. Confirm whether /= was intended.
        lr_scheduler.learning_rate = learning_rate / learning_rate_annealing

    log.info('FINISH')
# Compute a per-class label histogram over a test set and write it out in
# Kaldi format under the key "label_mean".
args.config.write(sys.stderr)  # echo the effective config for the log
decoding_method = args.config.get('train', 'method')
contexts = parse_contexts(args)
init_states, test_sets = prepare_data(args)
state_names = [x[0] for x in init_states]

batch_size = args.config.getint('train', 'batch_size')
num_hidden = args.config.getint('arch', 'num_hidden')
num_lstm_layer = args.config.getint('arch', 'num_lstm_layer')

feat_dim = args.config.getint('data', 'xdim')
label_dim = args.config.getint('data', 'ydim')
out_file = args.config.get('data', 'out_file')
num_epoch = args.config.getint('train', 'num_epoch')
model_name = get_checkpoint_path(args)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

# load the model
label_mean = np.zeros((label_dim,1), dtype='float32')
data_test = TruncatedSentenceIter(test_sets, batch_size, init_states, 20,
                                  feat_dim=feat_dim,
                                  do_shuffling=False, pad_zeros=True,
                                  has_label=True)

# Accumulate, per label id, how many frames carry that label across the
# whole test set.
for i, batch in enumerate(data_test.labels):
    hist, edges = np.histogram(batch.flat, bins=range(0,label_dim+1))
    label_mean += hist.reshape(label_dim,1)

# NOTE(review): despite the name, label_mean holds raw counts at this point
# — no normalization is visible here; confirm the consumer expects counts.
kaldiWriter = KaldiWriteOut(None, out_file)
kaldiWriter.open_or_fd()
kaldiWriter.write("label_mean", label_mean)
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train a CTC speech model with mxboard logging; supports a fresh
    'train' run or resuming a bucketing model in 'load' mode.

    Parameters
    ----------
    args : parsed arguments exposing a ConfigParser as ``args.config``
    module : either a bound-able MXNet module ('train' mode) or, in
        bucketing 'load' mode, a sym_gen callable used to rebuild an
        STTBucketingModule from a saved checkpoint
    data_train : training iterator (also carries is_first_epoch state)
    data_val : validation iterator
    begin_epoch : int, epoch to resume from; parameters are only
        initialized when this is 0 and mode is 'train'.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil

    log = LogUtil.getInstance().getlogger()
    # Ensure the checkpoint directory exists before any save.
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    #seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # Separate metrics: eval_metric for per-epoch validation CER,
    # loss_metric for in-epoch training loss/CER.
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    # Extra optimizer kwargs arrive as a JSON dict in the config.
    optimizer_params_dictionary = json.loads(args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    if is_bucketing and mode == 'load':
        # Rebuild a bucketing module and restore weights from the checkpoint
        # encoded in the model file name (last 4 chars = epoch number).
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])

        model_path = 'checkpoints/' + str(model_name[:-5])
        # NOTE(review): symbol/data_names/label_names are never used below;
        # the call presumably exercises sym_gen at bucket 1600 — confirm.
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        # Base params plus any extra kwargs supplied via the config JSON.
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    # In 'load' mode keep any restored optimizer state instead of forcing
    # a re-initialization.
    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        reset_optimizer(force_init=False)
    data_train.reset()
    data_train.is_first_epoch = True

    #mxboard setting
    mxlog_dir = args.config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    while True:
        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # mxboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            #summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch+1) % save_checkpoint_every_n_batch == 0:
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args)+"n_epoch"+str(n_epoch)+"n_batch",
                                       epoch=(int((nbatch+1)/save_checkpoint_every_n_batch)-1),
                                       save_optimizer_states=save_optimizer_states)

        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # mxboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        # NOTE(review): curr_acc is assigned just above, so this assert can
        # never fire — looks vestigial from an older metric lookup.
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()
        data_train.is_first_epoch = False

        # mxboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args),
                                   epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1

        # NOTE(review): `learning_rate` is a constant local, so this assigns
        # the same value (base LR / annealing) every epoch — the LR anneals
        # only once overall. Confirm whether /= was intended.
        lr_scheduler.learning_rate = learning_rate/learning_rate_annealing

    log.info('FINISH')
def do_training(args, module, data_train, data_val, begin_epoch=0, kv=None):
    """Full train/validation loop for the (distributed-capable) STT model.

    Binds and initializes `module` (or rebuilds a BucketingModule from a
    checkpoint when bucketing is enabled and mode == 'load'), then trains
    until `num_epoch`, recomputing a warm-up / annealed / stepped learning
    rate every epoch, logging CER and CTC loss to tensorboard, and saving
    checkpoints every N batches and every N epochs.

    Parameters:
        args: parsed arguments object exposing a ConfigParser-style `.config`.
        module: an MXNet module; when `is_bucketing` and mode == 'load' it is
            used as the `sym_gen` callable for BucketingModule instead —
            TODO(review) confirm with callers.
        data_train, data_val: data iterators exposing `provide_data`,
            `provide_label` and (for bucketing) `default_bucket_key`.
        begin_epoch: epoch index to resume counting from.
        kv: kvstore handle forwarded to `init_optimizer` (may be None).
    """
    from distutils.dir_util import mkpath
    host_name = socket.gethostname()
    log = LogUtil().getlogger()
    # make sure the checkpoint directory exists before anything is saved
    mkpath(os.path.dirname(config_util.get_checkpoint_path(args)))

    # seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    val_batch_size = args.config.getint('common', 'val_batch_size')
    save_checkpoint_every_n_epoch = args.config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = config_util.parse_contexts(args)
    num_gpu = len(contexts)
    # validation metric: evaluated at epoch end with the validation batch size
    eval_metric = STTMetric(batch_size=val_batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # tensorboard setting: training metric, sampled every `show_every` batches
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = args.config.get('optimizer', 'optimizer')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_start = args.config.getfloat('train', 'learning_rate_start')
    learning_rate_annealing = args.config.getfloat('train',
                                                   'learning_rate_annealing')
    lr_factor = args.config.getfloat('train', 'lr_factor')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('optimizer', 'clip_gradient')
    weight_decay = args.config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train',
                                                   'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    # extra optimizer kwargs come from a JSON dictionary in the config
    optimizer_params_dictionary = json.loads(
        args.config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = args.config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')

    # kv = mx.kv.create(kvstore_option)
    # data = mx.io.ImageRecordIter(num_parts=kv.num_workers, part_index=kv.rank)
    # # a.set_optimizer(optimizer)
    # updater = mx.optimizer.get_updater(optimizer)
    # a._set_updater(updater=updater)

    # clip_gradient == 0 in the config means "no clipping"
    if clip_gradient == 0:
        clip_gradient = None
    if is_bucketing and mode == 'load':
        # Resume from checkpoint: the trailing 4 digits of the model file name
        # encode the epoch number (e.g. "prefix-0042" -> epoch 42).
        model_file = args.config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])
        model_path = 'checkpoints/' + str(model_name[:-5])
        prefix = args.config.get('common', 'prefix')
        if os.path.isabs(prefix):
            # absolute prefix: look next to the configured checkpoint path
            model_path = config_util.get_checkpoint_path(args).rsplit(
                "/", 1)[0] + "/" + str(model_name[:-5])
        # symbol, data_names, label_names = module(1600)
        # here `module` acts as the bucketing symbol generator
        model = mx.mod.BucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        # arg_params2 = {}
        # for item in arg_params.keys():
        #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
        #         arg_params2[item] = arg_params[item]
        # model.set_params(arg_params2, aux_params, allow_missing=True, allow_extra=True)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        # fresh run: initialize parameters from the configured initializer
        module.init_params(initializer=get_initializer(args))
    # model_file = args.config.get('common', 'model_file')
    # model_name = os.path.splitext(model_file)[0]
    # model_num_epoch = int(model_name[-4:])
    # model_path = 'checkpoints/' + str(model_name[:-5])
    # _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
    # arg_params2 = {}
    # for item in arg_params.keys():
    #     if not item.startswith("forward") and not item.startswith("backward") and not item.startswith("rear"):
    #         arg_params2[item] = arg_params[item]
    # module.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
    # lr, lr_scheduler = _get_lr_scheduler(args, kv)

    def reset_optimizer(force_init=False):
        # (re)create the optimizer; extra kwargs come from the config JSON
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kv,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    if mode == "train":
        reset_optimizer(force_init=True)
    else:
        # loaded checkpoint: keep existing optimizer state where possible
        reset_optimizer(force_init=False)
    data_train.reset()
    data_train.is_first_epoch = True

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)

    learning_rate_pre = 0  # lr used in the previous epoch (0 = none yet)
    while True:
        if n_epoch >= num_epoch:
            break
        loss_metric.reset()
        log.info(host_name + '---------train---------')

        # learning-rate schedule, recomputed from scratch each epoch
        step_epochs = [
            int(l)
            for l in args.config.get('train', 'lr_step_epochs').split(',')
        ]
        # warm up to step_epochs[0] if step_epochs[0] > 0
        if n_epoch < step_epochs[0]:
            # linear warm-up from learning_rate_start up to learning_rate
            learning_rate_cur = learning_rate_start + n_epoch * (
                learning_rate - learning_rate_start) / step_epochs[0]
        else:
            # scaling lr every epoch
            if len(step_epochs) == 1:
                learning_rate_cur = learning_rate
                for s in range(n_epoch):
                    learning_rate_cur /= learning_rate_annealing
            # scaling lr by step_epochs[1:]
            else:
                learning_rate_cur = learning_rate
                for s in step_epochs[1:]:
                    if n_epoch > s:
                        learning_rate_cur *= lr_factor
        if learning_rate_pre and args.config.getboolean(
                'train', 'momentum_correction'):
            # momentum correction: scale by the current/previous lr ratio
            lr_scheduler.learning_rate = \
                learning_rate_cur * learning_rate_cur / learning_rate_pre
        else:
            lr_scheduler.learning_rate = learning_rate_cur
        learning_rate_pre = learning_rate_cur
        log.info("n_epoch %d's lr is %.7f" %
                 (n_epoch, lr_scheduler.learning_rate))
        summary_writer.add_scalar('lr', lr_scheduler.learning_rate, n_epoch)

        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                # loss_metric.set_audio_paths(data_batch.index)
                module.update_metric(loss_metric, data_batch.label)
                # print("loss=========== %.2f" % loss_metric.get_batch_loss())
            # summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                # intra-epoch checkpoint, numbered by save interval
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch,
                         nbatch)
                save_checkpoint(
                    module,
                    prefix=config_util.get_checkpoint_path(args) +
                    "n_epoch" + str(n_epoch) + "n_batch",
                    epoch=(int(
                        (nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                    save_optimizer_states=save_optimizer_states)

        # commented for Libri_sample data set to see only train cer
        log.info(host_name + '---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            eval_metric.set_audio_paths(data_batch.index)
            module.update_metric(eval_metric, data_batch.label)
        # tensorboard setting
        val_cer, val_n_label, val_l_dist, val_ctc_loss = \
            eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d), ctc_loss=%f", n_epoch,
                 val_cer, int(val_n_label - val_l_dist), val_n_label,
                 val_ctc_loss)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        summary_writer.add_scalar('loss validation', val_ctc_loss, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        # NOTE(review): presumably seeds the iterator's per-epoch shuffle so
        # every worker draws the same order — confirm against the iterator.
        np.random.seed(n_epoch)
        data_train.reset()
        data_train.is_first_epoch = False

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = \
            loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            save_checkpoint(module,
                            prefix=config_util.get_checkpoint_path(args),
                            epoch=n_epoch,
                            save_optimizer_states=save_optimizer_states)
        n_epoch += 1
    log.info('FINISH')
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train `module` on `data_train`, validating on `data_val` each epoch.

    All hyper-parameters are read from `args.config`. CER and CTC loss are
    logged to tensorboard, checkpoints are written every
    `save_checkpoint_every_n_batch` batches and every
    `save_checkpoint_every_n_epoch` epochs, and the learning rate is annealed
    by `learning_rate_annealing` once per epoch.

    Parameters:
        args: parsed arguments object exposing a ConfigParser-style `.config`.
        module: MXNet module to bind, initialize and train in place.
        data_train, data_val: iterators exposing `provide_data`/`provide_label`.
        begin_epoch: epoch index to resume from; parameters are initialized
            only when it is 0 and mode == 'train'.

    Raises:
        Exception: from `reset_optimizer` when the configured optimizer is
            neither 'sgd' nor 'adam'.
    """
    from distutils.dir_util import mkpath
    from log_util import LogUtil
    log = LogUtil().getlogger()
    # ensure the checkpoint directory exists before any save
    mkpath(os.path.dirname(get_checkpoint_path(args)))

    # --- configuration ---------------------------------------------------
    # NOTE(review): `get` returns a string here — STTMetric presumably
    # accepts it as-is; confirm whether `getint` was intended.
    seq_len = args.config.get('arch', 'max_t_count')
    batch_size = args.config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = args.config.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = args.config.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = args.config.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = args.config.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    # validation metric (epoch end) and training metric (sampled per batch)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_validation_metric, is_epoch_end=True)
    # tensorboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu, seq_length=seq_len,
                            is_logging=enable_logging_train_metric, is_epoch_end=False)

    optimizer = args.config.get('train', 'optimizer')
    momentum = args.config.getfloat('train', 'momentum')
    learning_rate = args.config.getfloat('train', 'learning_rate')
    learning_rate_annealing = args.config.getfloat('train', 'learning_rate_annealing')

    mode = args.config.get('common', 'mode')
    num_epoch = args.config.getint('train', 'num_epoch')
    clip_gradient = args.config.getfloat('train', 'clip_gradient')
    weight_decay = args.config.getfloat('train', 'weight_decay')
    save_optimizer_states = args.config.getboolean('train', 'save_optimizer_states')
    show_every = args.config.getint('train', 'show_every')
    n_epoch = begin_epoch

    # clip_gradient == 0 in the config means "no clipping"
    if clip_gradient == 0:
        clip_gradient = None

    module.bind(data_shapes=data_train.provide_data,
                label_shapes=data_train.provide_label,
                for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)

    def reset_optimizer(force_init=False):
        """(Re)initialize the module's optimizer from the config values."""
        if optimizer not in ("sgd", "adam"):
            raise Exception('Supported optimizers are sgd and adam. If you want to implement others define them in train.py')
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        if optimizer == "sgd":
            # adam keeps its own moment estimates; momentum is sgd-only
            optimizer_params['momentum'] = momentum
        module.init_optimizer(kvstore='device',
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    # force re-init when training from scratch; reuse state when loading
    reset_optimizer(force_init=(mode == "train"))

    # tensorboard setting
    tblog_dir = args.config.get('common', 'tensorboard_log_dir')
    summary_writer = SummaryWriter(tblog_dir)

    while n_epoch < num_epoch:
        loss_metric.reset()
        log.info('---------train---------')
        for nbatch, data_batch in enumerate(data_train):
            module.forward_backward(data_batch)
            module.update()
            # tensorboard setting
            if (nbatch + 1) % show_every == 0:
                module.update_metric(loss_metric, data_batch.label)
            # summary_writer.add_scalar('loss batch', loss_metric.get_batch_loss(), nbatch)
            if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                # intra-epoch checkpoint, numbered by save interval
                log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                module.save_checkpoint(prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                                       epoch=(int((nbatch + 1) / save_checkpoint_every_n_batch) - 1),
                                       save_optimizer_states=save_optimizer_states)

        # commented for Libri_sample data set to see only train cer
        log.info('---------validation---------')
        data_val.reset()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(data_val):
            # when is_train = False it leads to high cer when batch_norm
            module.forward(data_batch, is_train=True)
            module.update_metric(eval_metric, data_batch.label)

        # tensorboard setting
        val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
        log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                 int(val_n_label - val_l_dist), val_n_label)
        curr_acc = val_cer
        summary_writer.add_scalar('CER validation', val_cer, n_epoch)
        assert curr_acc is not None, 'cannot find Acc_exclude_padding in eval metric'

        data_train.reset()

        # tensorboard setting
        train_cer, train_n_label, train_l_dist, train_ctc_loss = loss_metric.get_name_value()
        summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
        summary_writer.add_scalar('CER train', train_cer, n_epoch)

        # save checkpoints
        if n_epoch % save_checkpoint_every_n_epoch == 0:
            log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
            module.save_checkpoint(prefix=get_checkpoint_path(args), epoch=n_epoch,
                                   save_optimizer_states=save_optimizer_states)

        n_epoch += 1
        # BUGFIX: divide the scheduler's *current* rate so annealing compounds
        # each epoch. The original assigned the constant
        # `learning_rate / learning_rate_annealing` every epoch, so the rate
        # never decayed past the first annealing step.
        lr_scheduler.learning_rate /= learning_rate_annealing

    log.info('FINISH')