elif mode == "predict": data_train, args = load_data(args) is_batchnorm = args.config.getboolean('arch', 'is_batchnorm') is_bucketing = args.config.getboolean('arch', 'is_bucketing') # log current config config_logger = ConfigLogger(log) config_logger(args.config) # load model model_loaded, model_num_epoch = load_model(args, contexts, data_train) # if mode is 'train', it trains the model if mode == 'train': if is_bucketing: module = STTBucketingModule( sym_gen=model_loaded, default_bucket_key=data_train.default_bucket_key, context=contexts) else: data_names = [x[0] for x in data_train.provide_data] label_names = [x[0] for x in data_train.provide_label] module = mx.mod.Module(model_loaded, context=contexts, data_names=data_names, label_names=label_names) do_training(args=args, module=module, data_train=data_train, data_val=data_val) # if mode is 'load', it loads model from the checkpoint and continues the training. elif mode == 'load': do_training(args=args,
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Train ``module`` epoch by epoch, validating and checkpointing as it goes.

    All hyper-parameters are read from ``args.config`` (sections ``common``,
    ``train``, ``optimizer`` and ``arch``). After every epoch the validation
    CER and the training CER / CTC loss are written to the summary writer
    opened on ``common.mxboard_log_dir``.

    Args:
        args: Parsed arguments; ``args.config`` must offer ``get``, ``getint``,
            ``getfloat`` and ``getboolean``.
        module: Module to train. When ``arch.is_bucketing`` is true and
            ``common.mode`` is ``'load'`` it is instead used as the ``sym_gen``
            callable of a freshly built ``STTBucketingModule`` whose parameters
            are loaded from the checkpoint named by ``common.model_file``.
        data_train: Training data iterator (``reset``, ``provide_data``,
            ``provide_label``; ``default_bucket_key`` in the bucketing/load path).
        data_val: Validation data iterator.
        begin_epoch: Epoch index at which counting starts; training stops once
            the counter reaches ``train.num_epoch``.
    """
    from log_util import LogUtil

    log = LogUtil().getlogger()

    # Create the checkpoint directory without distutils (removed in Python
    # 3.12). An empty dirname means "current directory": nothing to create.
    checkpoint_dir = os.path.dirname(get_checkpoint_path(args))
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    config = args.config
    batch_size = config.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = config.getint(
        'common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = config.getint(
        'common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = config.getboolean(
        'train', 'enable_logging_train_metric')
    enable_logging_validation_metric = config.getboolean(
        'train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = config.get('optimizer', 'optimizer')
    learning_rate = config.getfloat('train', 'learning_rate')
    learning_rate_annealing = config.getfloat('train', 'learning_rate_annealing')

    mode = config.get('common', 'mode')
    num_epoch = config.getint('train', 'num_epoch')
    clip_gradient = config.getfloat('optimizer', 'clip_gradient')
    weight_decay = config.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = config.getboolean('train', 'save_optimizer_states')
    show_every = config.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(
        config.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = config.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = config.getboolean('arch', 'is_bucketing')

    # A configured value of 0 means "no gradient clipping".
    if clip_gradient == 0:
        clip_gradient = None

    if is_bucketing and mode == 'load':
        # The file name is expected to end in a 4-digit epoch number preceded
        # by one separator character, e.g. "<prefix>-0012.params".
        model_file = config.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])
        model_path = 'checkpoints/' + str(model_name[:-5])

        # NOTE(review): the generated symbol and names are not used below;
        # the call is kept in case sym_gen has side effects - confirm.
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
    # Learning rate for the epoch about to run; annealed after each epoch.
    current_lr = learning_rate

    def reset_optimizer(force_init=False):
        # Entries of optimizer_params_dictionary override the defaults below.
        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'clip_gradient': clip_gradient,
            'wd': weight_decay,
        }
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    # Only a fresh training run forces re-initialisation of the optimizer.
    reset_optimizer(force_init=(mode == "train"))
    data_train.reset()
    data_train.is_first_epoch = True

    # mxboard setting
    mxlog_dir = config.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    # try/finally so buffered events are flushed and the writer is released
    # even if training is interrupted by an exception.
    try:
        while n_epoch < num_epoch:
            loss_metric.reset()
            log.info('---------train---------')
            for nbatch, data_batch in enumerate(data_train):
                module.forward_backward(data_batch)
                module.update()
                # mxboard setting
                if (nbatch + 1) % show_every == 0:
                    module.update_metric(loss_metric, data_batch.label)
                if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                    log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT',
                             n_epoch, nbatch)
                    # The "epoch" slot of the file name carries the 0-based
                    # index of this intra-epoch checkpoint.
                    module.save_checkpoint(
                        prefix=(get_checkpoint_path(args) + "n_epoch"
                                + str(n_epoch) + "n_batch"),
                        epoch=(nbatch + 1) // save_checkpoint_every_n_batch - 1,
                        save_optimizer_states=save_optimizer_states)

            log.info('---------validation---------')
            data_val.reset()
            eval_metric.reset()
            for data_batch in data_val:
                # when is_train = False it leads to high cer when batch_norm
                module.forward(data_batch, is_train=True)
                module.update_metric(eval_metric, data_batch.label)

            # mxboard setting
            val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
            log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                     int(val_n_label - val_l_dist), val_n_label)
            summary_writer.add_scalar('CER validation', val_cer, n_epoch)
            assert val_cer is not None, \
                'cannot find Acc_exclude_padding in eval metric'

            data_train.reset()
            data_train.is_first_epoch = False

            # mxboard setting
            train_cer, _, _, train_ctc_loss = loss_metric.get_name_value()
            summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
            summary_writer.add_scalar('CER train', train_cer, n_epoch)

            # save checkpoints
            if n_epoch % save_checkpoint_every_n_epoch == 0:
                log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
                module.save_checkpoint(
                    prefix=get_checkpoint_path(args),
                    epoch=n_epoch,
                    save_optimizer_states=save_optimizer_states)

            n_epoch += 1

            # Anneal cumulatively: divide the *current* rate each epoch.
            # Dividing the base rate instead would decay it only once.
            current_lr /= learning_rate_annealing
            lr_scheduler.learning_rate = current_lr
    finally:
        summary_writer.close()

    log.info('FINISH')
def do_training(args, module, data_train, data_val, begin_epoch=0):
    """Run the training loop with per-epoch validation and checkpointing.

    Settings come from ``args.config`` (sections ``common``, ``train``,
    ``optimizer`` and ``arch``). Each epoch trains over ``data_train``,
    evaluates on ``data_val``, writes validation CER, training CER and
    training CTC loss to the summary writer opened on
    ``common.mxboard_log_dir``, and saves a checkpoint every
    ``common.save_checkpoint_every_n_epoch`` epochs.

    Args:
        args: Parsed arguments whose ``config`` exposes ``get``, ``getint``,
            ``getfloat`` and ``getboolean``.
        module: Module to train. With ``arch.is_bucketing`` enabled and
            ``common.mode == 'load'`` it is treated as a ``sym_gen`` callable
            and wrapped in a new ``STTBucketingModule`` restored from the
            checkpoint named by ``common.model_file``.
        data_train: Training data iterator.
        data_val: Validation data iterator.
        begin_epoch: First epoch index; the loop ends when the counter reaches
            ``train.num_epoch``.
    """
    from log_util import LogUtil

    log = LogUtil.getInstance().getlogger()

    # distutils.dir_util.mkpath is gone in Python 3.12; create the directory
    # with os instead. An empty dirname refers to the current directory.
    ckpt_dir = os.path.dirname(get_checkpoint_path(args))
    if ckpt_dir and not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    cfg = args.config
    batch_size = cfg.getint('common', 'batch_size')
    save_checkpoint_every_n_epoch = cfg.getint('common', 'save_checkpoint_every_n_epoch')
    save_checkpoint_every_n_batch = cfg.getint('common', 'save_checkpoint_every_n_batch')
    enable_logging_train_metric = cfg.getboolean('train', 'enable_logging_train_metric')
    enable_logging_validation_metric = cfg.getboolean('train', 'enable_logging_validation_metric')

    contexts = parse_contexts(args)
    num_gpu = len(contexts)
    eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_validation_metric,
                            is_epoch_end=True)
    # mxboard setting
    loss_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu,
                            is_logging=enable_logging_train_metric,
                            is_epoch_end=False)

    optimizer = cfg.get('optimizer', 'optimizer')
    learning_rate = cfg.getfloat('train', 'learning_rate')
    learning_rate_annealing = cfg.getfloat('train', 'learning_rate_annealing')

    mode = cfg.get('common', 'mode')
    num_epoch = cfg.getint('train', 'num_epoch')
    clip_gradient = cfg.getfloat('optimizer', 'clip_gradient')
    weight_decay = cfg.getfloat('optimizer', 'weight_decay')
    save_optimizer_states = cfg.getboolean('train', 'save_optimizer_states')
    show_every = cfg.getint('train', 'show_every')
    optimizer_params_dictionary = json.loads(cfg.get('optimizer', 'optimizer_params_dictionary'))
    kvstore_option = cfg.get('common', 'kvstore_option')
    n_epoch = begin_epoch
    is_bucketing = cfg.getboolean('arch', 'is_bucketing')

    # 0 in the config disables gradient clipping.
    if clip_gradient == 0:
        clip_gradient = None

    if is_bucketing and mode == 'load':
        # model_file is parsed as "<prefix><sep><4-digit epoch><ext>".
        model_file = cfg.get('common', 'model_file')
        model_name = os.path.splitext(model_file)[0]
        model_num_epoch = int(model_name[-4:])
        model_path = 'checkpoints/' + str(model_name[:-5])

        # NOTE(review): return values are unused here; call retained in case
        # the symbol generator relies on being invoked once - confirm.
        symbol, data_names, label_names = module(1600)
        model = STTBucketingModule(
            sym_gen=module,
            default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        data_train.reset()

        model.bind(data_shapes=data_train.provide_data,
                   label_shapes=data_train.provide_label,
                   for_training=True)
        _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch)
        model.set_params(arg_params, aux_params)
        module = model
    else:
        module.bind(data_shapes=data_train.provide_data,
                    label_shapes=data_train.provide_label,
                    for_training=True)

    if begin_epoch == 0 and mode == 'train':
        module.init_params(initializer=get_initializer(args))

    lr_scheduler = SimpleLRScheduler(learning_rate=learning_rate)
    # Tracks the rate currently in effect so annealing can compound.
    annealed_lr = learning_rate

    def reset_optimizer(force_init=False):
        # User-supplied optimizer params take precedence over these defaults.
        optimizer_params = {'lr_scheduler': lr_scheduler,
                            'clip_gradient': clip_gradient,
                            'wd': weight_decay}
        optimizer_params.update(optimizer_params_dictionary)
        module.init_optimizer(kvstore=kvstore_option,
                              optimizer=optimizer,
                              optimizer_params=optimizer_params,
                              force_init=force_init)

    # Force a fresh optimizer only when training from scratch.
    reset_optimizer(force_init=(mode == "train"))
    data_train.reset()
    data_train.is_first_epoch = True

    # mxboard setting
    mxlog_dir = cfg.get('common', 'mxboard_log_dir')
    summary_writer = SummaryWriter(mxlog_dir)

    # Close the writer on every exit path so pending events reach disk.
    try:
        while n_epoch < num_epoch:
            loss_metric.reset()
            log.info('---------train---------')
            for nbatch, data_batch in enumerate(data_train):
                module.forward_backward(data_batch)
                module.update()
                # mxboard setting
                if (nbatch + 1) % show_every == 0:
                    module.update_metric(loss_metric, data_batch.label)
                if (nbatch + 1) % save_checkpoint_every_n_batch == 0:
                    log.info('Epoch[%d] Batch[%d] SAVE CHECKPOINT', n_epoch, nbatch)
                    # Checkpoint index within the epoch, 0-based.
                    batch_ckpt_index = (nbatch + 1) // save_checkpoint_every_n_batch - 1
                    module.save_checkpoint(
                        prefix=get_checkpoint_path(args) + "n_epoch" + str(n_epoch) + "n_batch",
                        epoch=batch_ckpt_index,
                        save_optimizer_states=save_optimizer_states)

            log.info('---------validation---------')
            data_val.reset()
            eval_metric.reset()
            for data_batch in data_val:
                # when is_train = False it leads to high cer when batch_norm
                module.forward(data_batch, is_train=True)
                module.update_metric(eval_metric, data_batch.label)

            # mxboard setting
            val_cer, val_n_label, val_l_dist, _ = eval_metric.get_name_value()
            log.info("Epoch[%d] val cer=%f (%d / %d)", n_epoch, val_cer,
                     int(val_n_label - val_l_dist), val_n_label)
            summary_writer.add_scalar('CER validation', val_cer, n_epoch)
            assert val_cer is not None, 'cannot find Acc_exclude_padding in eval metric'

            data_train.reset()
            data_train.is_first_epoch = False

            # mxboard setting
            train_cer, _, _, train_ctc_loss = loss_metric.get_name_value()
            summary_writer.add_scalar('loss epoch', train_ctc_loss, n_epoch)
            summary_writer.add_scalar('CER train', train_cer, n_epoch)

            # save checkpoints
            if n_epoch % save_checkpoint_every_n_epoch == 0:
                log.info('Epoch[%d] SAVE CHECKPOINT', n_epoch)
                module.save_checkpoint(prefix=get_checkpoint_path(args),
                                       epoch=n_epoch,
                                       save_optimizer_states=save_optimizer_states)

            n_epoch += 1

            # Compound the decay: base_rate / annealing would be the same
            # constant every epoch, i.e. the rate would only drop once.
            annealed_lr /= learning_rate_annealing
            lr_scheduler.learning_rate = annealed_lr
    finally:
        summary_writer.close()

    log.info('FINISH')
do_training(args=args, module=model_loaded, data_train=data_train, data_val=data_val, begin_epoch=model_num_epoch + 1) # if mode is 'predict', it predict label from the input by the input model elif mode == 'predict': # predict through data if is_bucketing: max_t_count = args.config.getint('arch', 'max_t_count') load_optimizer_states = args.config.getboolean('load', 'load_optimizer_states') model_file = args.config.get('common', 'model_file') model_name = os.path.splitext(model_file)[0] model_num_epoch = int(model_name[-4:]) model_path = 'checkpoints/' + str(model_name[:-5]) model = STTBucketingModule( sym_gen=model_loaded, default_bucket_key=data_train.default_bucket_key, context=contexts ) model.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label, for_training=True) _, arg_params, aux_params = mx.model.load_checkpoint(model_path, model_num_epoch) model.set_params(arg_params, aux_params) model_loaded = model else: model_loaded.bind(for_training=False, data_shapes=data_train.provide_data, label_shapes=data_train.provide_label) max_t_count = args.config.getint('arch', 'max_t_count') eval_metric = STTMetric(batch_size=batch_size, num_gpu=num_gpu) if is_batchnorm:
def __init__(self):
    """Set up the network, its parameters and a language-model scorer.

    The configuration is parsed from the first command-line argument. The
    constructor seeds the random generators, creates the logger, builds and
    binds an ``STTBucketingModule``, restores its parameters from the
    checkpoint reported by ``load_model`` and finally attaches a scorer
    (``swig_wrapper.Scorer`` when importable, otherwise ``kenlm``).

    Raises:
        Exception: if no command-line argument was given.
    """
    if not len(sys.argv) > 1:
        raise Exception('cfg file path must be provided. ' +
                        'ex)python main.py --configfile examplecfg.cfg')
    self.args = parse_args(sys.argv[1])
    config = self.args.config

    # Random seeds; -1 in the cfg file means "do not pin this seed".
    self.random_seed = config.getint('common', 'random_seed')
    self.mx_random_seed = config.getint('common', 'mx_random_seed')
    if self.random_seed != -1:
        # seed used when shuffling the data list
        np.random.seed(self.random_seed)
    # seed used for parameter initialization; falls back to a time-based hash
    mx_seed = self.mx_random_seed
    if mx_seed == -1:
        mx_seed = hash(datetime.now())
    mx.random.seed(mx_seed)

    # Logger writing to the configured file.
    self.log_filename = config.get('common', 'log_filename')
    self.log = LogUtil(filename=self.log_filename).getlogger()

    # Run mode, devices and architecture switches.
    self.mode = config.get('common', 'mode')
    self.contexts = parse_contexts(self.args)
    self.num_gpu = len(self.contexts)
    self.batch_size = config.getint('common', 'batch_size')
    self.is_batchnorm = config.getboolean('arch', 'is_batchnorm')
    self.is_bucketing = config.getboolean('arch', 'is_bucketing')

    # Dump the effective configuration to the log.
    self.config_logger = ConfigLogger(self.log)
    self.config_logger(config)

    # Fixed sequence/label limits are written back into the config so that
    # later readers of the 'arch' section see them.
    bucket_key = 1600
    config.set('arch', 'max_t_count', str(bucket_key))
    config.set('arch', 'max_label_length', str(100))

    # Label vocabulary; its size becomes the number of output classes.
    self.labelUtil = LabelUtil()
    use_bi_graphemes = config.getboolean('common', 'is_bi_graphemes')
    load_labelutil(self.labelUtil, use_bi_graphemes, language="zh")
    config.set('arch', 'n_classes', str(self.labelUtil.get_count()))
    self.max_t_count = config.getint('arch', 'max_t_count')

    # Build the bucketing module around the loaded symbol generator.
    self.model_loaded, self.model_num_epoch, self.model_path = load_model(
        self.args)
    self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                    default_bucket_key=bucket_key,
                                    context=self.contexts)

    # The architecture module named in the config supplies the extra
    # state inputs that are bound alongside 'data'.
    from importlib import import_module
    arch_module = import_module(config.get('arch', 'arch_file'))
    init_states = arch_module.prepare_data(self.args)
    width = config.getint('data', 'width')
    height = config.getint('data', 'height')

    data_shapes = [
        ('data', (self.batch_size, bucket_key, width * height))
    ] + init_states
    label_shapes = [
        ('label', (self.batch_size,
                   config.getint('arch', 'max_label_length')))
    ]
    self.model.bind(data_shapes=data_shapes,
                    label_shapes=label_shapes,
                    for_training=True)

    # Restore weights; extra or missing parameters are tolerated.
    _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
        self.model_path, self.model_num_epoch)
    self.model.set_params(self.arg_params, self.aux_params,
                          allow_extra=True, allow_missing=True)

    # Prefer the SWIG scorer; fall back to plain kenlm if it is not installed.
    try:
        from swig_wrapper import Scorer

        vocab_list = [
            token.encode("utf-8") for token in self.labelUtil.byList
        ]
        self.log.info("vacab_list len is %d" % len(vocab_list))
        ext_scorer = Scorer(0.26, 0.1, config.get('common', 'kenlm'),
                            vocab_list)
        char_based = ext_scorer.is_character_based()
        max_order = ext_scorer.get_max_order()
        dict_size = ext_scorer.get_dict_size()
        self.log.info("language model: "
                      "is_character_based = %d," % char_based +
                      " max_order = %d," % max_order +
                      " dict_size = %d" % dict_size)
        self.scorer = ext_scorer
    except ImportError:
        import kenlm

        kenlm_model = kenlm.Model(config.get('common', 'kenlm'))
        self.scorer = kenlm_model.score