def eval_with_image_folder(self, eval_params):
    """Evaluate the network one image at a time over a list of image files.

    For every image the sample is prepared with ``data_process_func``, run
    through ``self.net``, folded into ``metrics`` via ``self.update_metrics``
    and reported to a ``Speedometer`` callback. Optionally the result is
    shown and/or saved.

    Parameters
    ----------
    eval_params : object
        Settings holder. Attributes read here: ``image_list``,
        ``data_process_func``, ``lable_list`` (spelled this way on the
        holder), ``log_interval``, ``metrics``, ``show``, ``show_func``,
        ``save`` and ``save_func``.
    """
    image_list = eval_params.image_list
    data_process_func = eval_params.data_process_func
    # The attribute is spelled `lable_list` on eval_params; keep reading it
    # under that name so existing holders keep working.
    label_list = eval_params.lable_list
    log_interval = eval_params.log_interval
    metrics = eval_params.metrics
    show = eval_params.show
    show_func = eval_params.show_func
    save = eval_params.save
    save_func = eval_params.save_func

    callback = mx.callback.Speedometer(1, log_interval)
    # Not read below, but kept because the locals() snapshot handed to the
    # callback exposes it.
    num_samples = len(image_list)

    for idx, image_file in enumerate(image_list):
        if label_list is not None:
            label_file = label_list[idx]
        else:
            label_file = None

        datas, labels = data_process_func(image=image_file, label=label_file)
        preds = self.net(datas)
        self.update_metrics(preds, labels, metrics)

        # Report the running sample index. A constant nbatch would make the
        # speedometer's `nbatch % log_interval` test fail for every
        # log_interval other than 1, so nothing would ever be logged.
        callback_params = BatchEndParam(epoch=1,
                                        nbatch=idx,
                                        eval_metric=metrics,
                                        locals=locals())
        callback(callback_params)

        if show:
            show_func(image_file, datas, preds, labels)
        if save:
            save_func(image_file, datas, preds, labels)
def _eval_process(self, mod, eval_iter, metric, callback):
    """Run one inference pass over ``eval_iter`` and gather metric and outputs.

    Parameters
    ----------
    mod : module exposing ``prepare``, ``forward`` and ``get_outputs``.
    eval_iter : resettable iterator yielding batches with ``pad`` and ``label``.
    metric : metric object; it is reset first and updated once per batch.
    callback : callable invoked after every batch with a ``BatchEndParam``.

    Returns
    -------
    tuple
        ``(criteria, outputs)`` where ``criteria`` is
        ``metric.get_name_value()`` after the pass and ``outputs`` is a list
        holding, per module output, the batch results concatenated along
        dim 0 (an empty list when the iterator yields nothing).
    """
    # Local names are deliberately unchanged: locals() is passed to `callback`.
    outputs = []
    metric.reset()
    eval_iter.reset()

    for nbatch, batch in enumerate(eval_iter):
        # forward
        mod.prepare(batch)
        mod.forward(batch, is_train=False)

        # remove padded samples
        num_pad = batch.pad
        batch_label = batch.label
        batch_out = mod.get_outputs()
        if num_pad > 0:
            # Outputs that are not longer than the pad count are left
            # untouched (original note: "shorter is loss").
            batch_out = [
                out if len(out) <= num_pad else out[:-num_pad]
                for out in batch_out
            ]
            batch_label = [lab[:-num_pad] for lab in batch_label]

        # metric
        metric.update(batch_label, batch_out)

        # callback
        callback(BatchEndParam(0, nbatch, metric, locals()))

        # save outputs: copy on the first batch, append afterwards
        if nbatch != 0:
            outputs = [
                mx.nd.concat(out, bout, dim=0)
                for out, bout in zip(outputs, batch_out)
            ]
        else:
            outputs = [bout.copy() for bout in batch_out]

    criteria = metric.get_name_value()
    return criteria, outputs
def callback(metric_list, callback_list):
    """Fire batch-end callbacks once for every metric in ``metric_list``.

    ``epoch``, ``i`` and ``batch_end_callback`` are not defined here; they
    are read from the enclosing scope.

    Parameters
    ----------
    metric_list : iterable
        Metrics to report; one ``BatchEndParam`` is built per metric.
    callback_list : callable, list of callables, or None
        Callbacks to invoke. ``None`` falls back to the enclosing
        ``batch_end_callback``.
    """
    # Honour the argument instead of silently ignoring it. The enclosing
    # variable is only consulted when the caller passes None.
    callbacks = batch_end_callback if callback_list is None else callback_list
    for metric in metric_list:
        batch_end_params = BatchEndParam(epoch=epoch,
                                         nbatch=i,
                                         eval_metric=metric,
                                         locals=locals())
        # Distinct loop name so the function's own name is not shadowed.
        for batch_callback in _as_list(callbacks):
            batch_callback(batch_end_params)
def eval_with_dataloader(self, eval_params):
    """Evaluate the network over a data loader, splitting batches across contexts.

    Each batch's first data/label entries are split over ``ctxs``; every
    slice is prepared with ``data_process_func``, cast to ``dtype``, run
    through ``self.forward`` and folded into ``metrics``. Batch-end callbacks
    run once per batch.

    Parameters
    ----------
    eval_params : object
        Settings holder. Attributes read here: ``data_process_func``,
        ``dataloader``, ``metrics``, ``log_interval``, ``batch_size``,
        ``batch_end_callback_list``, ``show``, ``show_func``, ``save``,
        ``save_func``, ``save_dir``, ``ctxs`` and ``dtype``. When
        ``batch_end_callback_list`` is ``None`` a single ``Speedometer`` is
        used.
    """
    data_process_func = eval_params.data_process_func
    dataloader = eval_params.dataloader
    metrics = eval_params.metrics
    log_interval = eval_params.log_interval
    batch_size = eval_params.batch_size
    batch_end_callback_list = eval_params.batch_end_callback_list
    show = eval_params.show
    show_func = eval_params.show_func
    save = eval_params.save
    save_func = eval_params.save_func
    save_dir = eval_params.save_dir
    ctxs = eval_params.ctxs
    dtype = eval_params.dtype

    dataloader.reset()
    if batch_end_callback_list is None:
        batch_end_callback_list = [
            mx.callback.Speedometer(batch_size, log_interval)
        ]

    # dtype is fixed for the whole run, so cast the network once instead of
    # repeating the cast on every batch.
    self.net.cast(dtype)

    # Each item is expected to expose `.data` and `.label` lists; only the
    # first entry of each is evaluated.
    for nbatch, databatch in enumerate(dataloader):
        datas, labels = databatch.data[0], databatch.label[0]
        gpu_datas = mx.gluon.utils.split_and_load(datas, ctxs)
        gpu_labels = mx.gluon.utils.split_and_load(labels, ctxs)

        for datai, labeli in zip(gpu_datas, gpu_labels):
            datai, labeli = data_process_func(datai, labeli)
            datai = datai.astype(dtype)
            predi = self.forward(datai)
            self.update_metrics(predi, labeli, metrics)
            if show:
                show_func(datas, predi, labeli)
            if save:
                save_func(datas,
                          predi,
                          labeli,
                          os.path.join(save_dir, 'eval'),
                          num_count=nbatch * datas.shape[0])

        for batch_end_callback in batch_end_callback_list:
            batch_end_params = BatchEndParam(epoch=1,
                                             nbatch=nbatch,
                                             eval_metric=metrics,
                                             locals=locals())
            batch_end_callback(batch_end_params)
def train_epoch(self, epoch):
    """Train on every batch of ``self.train_dataloader`` once.

    Each batch's first data and label entries are passed to
    ``self.train_batch``. When ``self.batch_end_callback`` is set, every
    callback in it is then invoked with a ``BatchEndParam`` for that batch.

    Parameters
    ----------
    epoch : value forwarded as ``epoch`` in the ``BatchEndParam``.
    """
    # Local names are deliberately unchanged: locals() is passed to callbacks.
    for nbatch, databatch in enumerate(self.train_dataloader):
        datas = databatch.data[0]
        labels = databatch.label[0]
        self.train_batch(datas, labels)

        # batch-end-callback
        if self.batch_end_callback is None:
            continue
        batch_end_params = BatchEndParam(epoch=epoch,
                                         nbatch=nbatch,
                                         eval_metric=self.metrics,
                                         locals=locals())
        for callback in self.batch_end_callback:
            callback(batch_end_params)
def _train_process(self,
                   mod,
                   epoch,
                   train_iter,
                   valid_iter,
                   metric,
                   optimizer,
                   batch_end_callback,
                   epoch_end_callback,
                   early_stop=30):
    """Train ``mod`` for up to ``epoch`` epochs with validation and early stop.

    Parameters
    ----------
    mod : module to bind, initialise and train.
    epoch : number of epochs to run at most.
    train_iter, valid_iter : training and validation iterators.
    metric : metric shared by training and validation; reset every epoch.
    optimizer : optimizer handed to ``mod.init_optimizer`` and ``adapted_lr``.
    batch_end_callback : callable invoked with a ``BatchEndParam`` per batch.
    epoch_end_callback : pair ``(checkpoint, adapted_lr)``. ``checkpoint``
        exposes ``rule`` and ``criteria_name`` and is called to produce the
        updated ``best`` record; ``adapted_lr`` is called with the optimizer
        and the number of epochs since the best one.
    early_stop : training stops once more than this many epochs have passed
        since the best epoch.
    """
    # Local names are deliberately unchanged: locals() is passed to callbacks.
    # params extracting
    checkpoint, adapted_lr = epoch_end_callback
    if self.net_params is not None:
        arg_params = self.net_params['args']
        aux_params = self.net_params['auxs']
    else:
        arg_params = None
        aux_params = None

    # model initialization
    mod.bind(train_iter.provide_data,
             train_iter.provide_label,
             for_training=True)
    mod.init_params(initializer=mx.init.Xavier(),
                    arg_params=arg_params,
                    aux_params=aux_params,
                    allow_missing=False)
    mod.init_optimizer(optimizer=optimizer)

    # training loops
    # Starting value is the worst case for the comparison rule in use.
    best = {
        'epoch': 0,
        'name': checkpoint.criteria_name,
        'value': 0 if checkpoint.rule == "greater" else 1e10,
    }
    for nepoch in range(epoch):
        metric.reset()
        train_iter = iter(train_iter)
        for nbatch, batch in enumerate(train_iter):
            mod.forward(batch, is_train=True)
            mod.update_metric(metric, batch.label)
            mod.backward()
            mod.update()
            batch_end_callback(
                BatchEndParam(nepoch, nbatch, metric, locals()))

        # training result present
        result = metric.get_name_value()
        logging_line = 'Epoch[%d]\t' % nepoch + '\t'.join(
            ['Train-%s=%f' % (n, v) for n, v in result])
        logging.info(logging_line)

        # sync aux params across devices
        arg_params, aux_params = mod.get_params()
        mod.set_params(arg_params, aux_params)

        # evaluation
        res, _ = self._eval_process(mod, valid_iter, metric,
                                    batch_end_callback)

        # epoch end callback
        best = checkpoint(best, res, nepoch,
                          (mod.symbol, arg_params, aux_params))
        # adapted learning rate
        adapted_lr(optimizer, nepoch - best['epoch'])

        # early stop
        logging.info('Validation Best performance: Epoch[%d] %s=%f' %
                     (best['epoch'], best['name'], best['value']))
        if nepoch - best['epoch'] > early_stop:
            break

        # reset train iter
        train_iter.reset()
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None, eval_batch_end_callback=None,
        initializer=Uniform(0.01), arg_params=None, aux_params=None,
        allow_missing=False, force_rebind=False, force_init=False,
        begin_epoch=0, num_epoch=None, best_model_callbacks=None,
        eval_interval=None, validation_metric=None, monitor=None):
    """Train the module, with optional mid-epoch validation and best-model checkpoints.

    Besides the usual bind / init / train-loop steps this variant:

    * resets ``eval_metric`` after every batch, so the values logged at the
      end of an epoch are the ones captured on the last batch;
    * runs validation every ``eval_interval`` batches inside an epoch when
      ``eval_interval`` and ``eval_data`` are given;
    * consults ``best_model_callbacks`` after validation. Each callback is
      used through ``is_best(validation_metric)`` and
      ``checkpoint_if_only_best(validation_metric, symbol, arg_params,
      aux_params)``.

    Parameters
    ----------
    train_data : iterator providing ``provide_data``, ``provide_label``,
        ``reset()`` and ``len()``.
    eval_data : optional validation iterator.
    eval_metric : str or EvalMetric; created via ``metric.create`` if needed.
    validation_metric : defaults to a deep copy of ``eval_metric``.
    num_epoch : int, required.
    best_model_callbacks : optional callback or list of callbacks (see above).
    eval_interval : optional int, number of batches between validations.

    The remaining arguments are forwarded unchanged to ``bind``,
    ``init_params``, ``init_optimizer``, ``score`` and the callbacks.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)
    if validation_metric is None:
        validation_metric = copy.deepcopy(eval_metric)
    # NOTE(review): epoch_metric and the swa_* variables are assigned but
    # never read in this function; they remain visible to callbacks through
    # the locals() snapshot.
    epoch_metric = copy.deepcopy(eval_metric)

    swa_arg_params = None
    swa_aux_params = None
    swa_cnt = 0

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic_epoch = time.time()
        eval_metric.reset()
        nbatch = 0
        end_of_batch = False
        data_iter = iter(train_data)
        next_data_batch = next(data_iter)
        name_values = []
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)
            if end_of_batch:
                # Capture the values now; the metric is reset further down.
                name_values = eval_metric.get_name_value()

            if monitor is not None:
                monitor.toc_print()

            # nbatch is incremented before the callbacks, so they see a
            # 1-based batch count.
            nbatch += 1
            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            # NOTE(review): placed at loop level (reset after every batch);
            # nesting reconstructed from a flattened source - confirm.
            eval_metric.reset()

            # ----------------------------------------
            # evaluation on validation set
            to_go = eval_interval is not None and nbatch % eval_interval == 0
            if to_go and eval_data:
                res = self.score(
                    eval_data,
                    validation_metric,
                    score_end_callback=eval_end_callback,
                    batch_end_callback=eval_batch_end_callback,
                    epoch=epoch)
                for name, val in res:
                    self.logger.info(
                        'Epoch[%d] Batch[%d] Validation-%s=%f', epoch,
                        nbatch, name, val)

                if best_model_callbacks is not None:
                    for callback in _as_list(best_model_callbacks):
                        if callback.is_best(validation_metric):
                            # sync aux params across devices
                            arg_params, aux_params = self.get_params()
                            # NOTE(review): sync_made is never read.
                            sync_made = True
                            callback.checkpoint_if_only_best(
                                validation_metric, self.symbol, arg_params,
                                aux_params)
                            # Stop after the first callback that reports a
                            # new best.
                            break

        # one epoch of training is finished
        # Epoch-end messages are numbered epoch + 1, unlike the mid-epoch
        # validation messages above which use epoch.
        for name, val in name_values:
            self.logger.info('Epoch[%d] Train-%s=%f', epoch + 1, name, val)
        toc_epoch = time.time()
        elapsed = (toc_epoch - tic_epoch)
        # NOTE(review): labelled samples/sec, but the numerator is
        # len(train_data) - confirm what that length counts.
        avg_speed = float(len(train_data)) / (toc_epoch - tic_epoch)
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch + 1, elapsed)
        self.logger.info('Epoch[%d] Average speed=%.3f samples/sec',
                         epoch + 1, avg_speed)

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch + 1)
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch + 1,
                                 name, val)

            if best_model_callbacks is not None:
                # Unlike the mid-epoch path, every callback is invoked here
                # without an is_best() check.
                for callback in _as_list(best_model_callbacks):
                    callback.checkpoint_if_only_best(
                        validation_metric, self.symbol, arg_params,
                        aux_params)

        # end of epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(self, train_data, plateau_lr, plateau_metric=None, fn_curr_model=None,
        plateau_backtrace=True, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None, eval_batch_end_callback=None,
        initializer=mx.init.Uniform(0.01), arg_params=None, aux_params=None,
        allow_missing=False, force_rebind=False, force_init=False,
        begin_epoch=0, num_epoch=None, validation_metric=None,
        validation_period=1, monitor=None):
    '''
    overrides fit() in base_module.

    Trains the module while driving the learning rate with a plateau
    scheduler instead of the optimizer's own ``lr_scheduler`` (which is set
    to ``None`` here).

    Parameters specific to this override
    ------------------------------------
    plateau_lr : scheduler object. ``reset(lr)`` is called with the initial
        learning rate, and ``update_lr(plateau_metric)`` must return
        ``(new_lr, is_good)``.
    plateau_metric : str or EvalMetric, optional. Defaults to the training
        ``eval_metric``; strings are created with ``mx.metric.create``.
    fn_curr_model : optional file name. Parameters are saved there when
        ``is_good`` is true and reloaded after a learning-rate decrease when
        ``plateau_backtrace`` is true.
    plateau_backtrace : bool, whether to reload the saved parameters after
        the learning rate is lowered.
    validation_period : int, validation runs when
        ``epoch % validation_period == 0``.

    Training terminates early when the scheduler returns a learning rate of
    ``0.0``. All other arguments are forwarded unchanged to ``bind``,
    ``init_params``, ``init_optimizer``, ``score`` and the callbacks.
    '''
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                     allow_missing=allow_missing, force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    # validation_metric takes the caller's eval_metric as given (possibly
    # still a string), before eval_metric is turned into an EvalMetric.
    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # we will use plateau lr scheduler.
    self.logger.info('Initial base lr = %.5f', self._optimizer.lr)
    plateau_lr.reset(self._optimizer.lr)
    self._optimizer.lr_scheduler = None
    if plateau_metric is None:
        plateau_metric = eval_metric
    elif not isinstance(plateau_metric, mx.metric.EvalMetric):
        plateau_metric = mx.metric.create(plateau_metric)

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        # Avoid resetting/updating the same object twice when the plateau
        # metric is the training metric itself.
        if plateau_metric is not eval_metric:
            plateau_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)
            if plateau_metric is not eval_metric:
                self.update_metric(plateau_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch % validation_period == 0:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback, epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

        # update lr using plateau algorithm
        # NOTE(review): placed at epoch level (runs every epoch, independent
        # of validation); nesting reconstructed from a flattened source -
        # confirm.
        new_lr, is_good = plateau_lr.update_lr(plateau_metric)
        if is_good and fn_curr_model is not None:
            super(PlateauModule, self).save_params(fn_curr_model)
        if new_lr == 0.0:
            self.logger.info('Minimum LR reached. Terminate training.')
            train_data.reset()
            break
        if new_lr < self._optimizer.lr:
            self.logger.info('Update lr, from %.6f to %.6f', self._optimizer.lr, new_lr)
            self._optimizer.lr = new_lr
            if fn_curr_model is not None and plateau_backtrace:
                self.logger.info('Reset network parameters to the previous best result.')
                super(PlateauModule, self).load_params(fn_curr_model)
        else:
            # NOTE(review): `else` paired with the lr-decrease test above -
            # confirm against the original layout.
            self.logger.info('Current lr = %.6f', self._optimizer.lr)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(
    self,
    train_data,
    eval_data=None,
    eval_metric="acc",
    epoch_end_callback=None,
    batch_end_callback=None,
    kvstore="local",
    optimizer="sgd",
    optimizer_params=(("learning_rate", 0.01),),
    eval_end_callback=None,
    eval_batch_end_callback=None,
    initializer=Uniform(0.01),
    arg_params=None,
    aux_params=None,
    allow_missing=False,
    force_rebind=False,
    force_init=False,
    begin_epoch=0,
    num_epoch=None,
    validation_metric=None,
    monitor=None,
    prefix=None,
):
    """Train the module parameters.

    Each training batch is processed ``config.network.TRAIN_ITER_SIZE``
    times: after every pass except the last, the batch is rebuilt from the
    current predictions by ``batchUpdaterPyMulti.forward``. When
    ``config.TRAIN.TENSORBOARD_LOG`` is set, learning rate and metrics are
    also written through an ``mxboard.SummaryWriter``.

    Parameters
    ----------
    train_data : DataIter
    eval_data : DataIter
        If not `None`, will be used as validation set and evaluate the performance
        after each epoch.
    eval_metric : str or EvalMetric
        Default `'acc'`. The performance measure used to display during training.
    epoch_end_callback : function or list of function
        Each callback will be called with the current `epoch`, `symbol`, `arg_params`
        and `aux_params`. The callbacks are additionally invoked once before
        training starts, with `epoch` set to `-1`.
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Default `'local'`.
    optimizer : str or Optimizer
        Default `'sgd'`
    optimizer_params : dict
        Default `(('learning_rate', 0.01),)`. The parameters for the optimizer
        constructor. The default value is not a `dict`, just to avoid pylint warning
        on dangerous default values.
    eval_end_callback : function or list of function
        These will be called at the end of each full evaluation, with the metrics over
        the entire evaluation set.
    eval_batch_end_callback : function or list of function
        These will be called at the end of each minibatch during evaluation
    initializer : Initializer
        Will be called to initialize the module parameters if not already initialized.
    arg_params : dict
        Default `None`, if not `None`, should be existing parameters from a trained
        model or loaded from a checkpoint (previously saved model). In this case,
        the value here will be used to initialize the module parameters, unless they
        are already initialized by the user via a call to `init_params` or `fit`.
        `arg_params` has higher priority to `initializer`.
    aux_params : dict
        Default `None`. Similar to `arg_params`, except for auxiliary states.
    allow_missing : bool
        Default `False`. Indicate whether we allow missing parameters when `arg_params`
        and `aux_params` are not `None`. If this is `True`, then the missing parameters
        will be initialized via the `initializer`.
    force_rebind : bool
        Default `False`. Whether to force rebinding the executors if already binded.
    force_init : bool
        Default `False`. Indicate whether we should force initialization even if the
        parameters are already initialized.
    begin_epoch : int
        Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
        checkpoint saved at a previous training phase at epoch N, then we should specify
        this value as N+1.
    num_epoch : int
        Number of epochs to run training.
    validation_metric : str or EvalMetric
        Default `None`, in which case the `eval_metric` argument as passed in
        is used for validation.
    monitor : Monitor
        Default `None`. If given, it is installed on the module and its
        `tic()` / `toc_print()` are called around every batch.
    prefix : str
        Default `None`. Logged together with the weight norms, and its
        directory name is the parent of the tensorboard log directory; it
        must be a path string when `config.TRAIN.TENSORBOARD_LOG` is enabled.

    Examples
    --------
    An example of using fit for training::
        >>> #Assume training dataIter and validation dataIter are ready
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
                    optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
                    num_epoch=10)
    """
    assert num_epoch is not None, "please specify number of epochs"

    self.bind(
        data_shapes=train_data.provide_data,
        label_shapes=train_data.provide_label,
        for_training=True,
        force_rebind=force_rebind,
    )
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=allow_missing,
        force_init=force_init,
    )
    self.init_optimizer(
        kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params
    )

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    ################################################################################
    # training loop
    ################################################################################
    # epoch 0
    # Fire the epoch-end callbacks once before any training, tagged epoch -1.
    if epoch_end_callback is not None:
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)
        for callback in _as_list(epoch_end_callback):
            callback(-1, self.symbol, arg_params, aux_params)

    from lib.pair_matching.batch_updater_py_multi import batchUpdaterPyMulti

    config = self.config
    if config.TRAIN.TENSORBOARD_LOG:
        from mxboard import SummaryWriter

        # One log directory per run, named by start time (minute resolution).
        tf_log_dir = os.path.join(
            os.path.dirname(prefix),
            "logs/{}".format(time.strftime("%Y-%m-%d-%H-%M")),
        )
        summ_writer = SummaryWriter(logdir=tf_log_dir)

    # NOTE(review): 480 and 640 look like an image height/width - confirm
    # against batchUpdaterPyMulti.
    interBatchUpdater = batchUpdaterPyMulti(config, 480, 640)
    last_lr = 0
    cur_step = 0

    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(train_data):
            if monitor is not None:
                monitor.tic()

            # disp weights L2 norm
            cur_lr = self._curr_module._optimizer._get_lr(0)
            # NOTE(review): `/` is true division on Python 3, so this test
            # compares against a float period - confirm intended interpreter.
            if nbatch % (4000 / train_data.batch_size) == 0:
                all_params = self._curr_module.get_params()[0]
                all_param_names = all_params.keys()
                all_param_names = sorted(all_param_names)
                print_and_log(prefix, self.logger)
                weight_str = ""
                for view_name in all_param_names:
                    weight_str += "{}: {} ".format(
                        view_name, nd.norm(all_params[view_name]).asnumpy()
                    )
                print_and_log(weight_str, self.logger)
                print_and_log(
                    "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                )
                if config.TRAIN.TENSORBOARD_LOG:
                    summ_writer.add_scalar(
                        tag="learning_rate", value=cur_lr, global_step=cur_step
                    )
            # Also report the learning rate whenever it changes.
            if cur_lr != last_lr:
                print_and_log(
                    "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                )
                last_lr = cur_lr
                if config.TRAIN.TENSORBOARD_LOG:
                    summ_writer.add_scalar(
                        tag="learning_rate", value=cur_lr, global_step=cur_step
                    )

            train_iter_size = config.network.TRAIN_ITER_SIZE
            for iter_idx in range(train_iter_size):
                self.forward_backward(data_batch)
                preds = self._curr_module.get_outputs(False)
                self.update()
                # Rebuild the batch from the predictions for the next pass;
                # the last pass keeps its batch for the metric update.
                if iter_idx != train_iter_size - 1:
                    data_batch = interBatchUpdater.forward(
                        data_batch, preds, config
                    )
            # NOTE(review): placed at batch level (one step per batch);
            # nesting reconstructed from a flattened source - confirm.
            cur_step += 1
            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(
                    epoch=epoch,
                    nbatch=nbatch,
                    eval_metric=eval_metric,
                    locals=locals(),
                )
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            if config.TRAIN.TENSORBOARD_LOG:
                for name, val in eval_metric.get_name_value():
                    summ_writer.add_scalar(
                        tag="BatchTrain-{}".format(name),
                        value=val,
                        global_step=cur_step,
                    )

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info("Epoch[%d] Train-%s=%f", epoch, name, val)
            if config.TRAIN.TENSORBOARD_LOG:
                summ_writer.add_scalar(
                    tag="EpochTrain-{}".format(name), value=val, global_step=epoch
                )
        toc = time.time()
        self.logger.info("Epoch[%d] Time cost=%.3f", epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        # ----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(
                eval_data,
                validation_metric,
                score_end_callback=eval_end_callback,
                batch_end_callback=eval_batch_end_callback,
                epoch=epoch,
            )
            # TODO: pull this into default
            for name, val in res:
                self.logger.info("Epoch[%d] Validation-%s=%f", epoch, name, val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None, eval_batch_end_callback=None,
        initializer=Uniform(0.01), arg_params=None, aux_params=None,
        allow_missing=False, force_rebind=False, force_init=False,
        begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None,
        summary_writer = None):
    """Train the module while tracking a classification and an lmnn metric.

    Two metrics are maintained during training: an ``IgnoreAccuracy`` on
    ``softmax_output`` and a ``metric.Loss`` on ``lmnn_output``, both against
    ``softmax_label``. They are logged at the end of every epoch and, when a
    summary writer is supplied, written as scalars every tenth epoch.

    Parameters
    ----------
    train_data : iterator providing ``provide_data``, ``provide_label`` and
        ``reset()``.
    eval_data : optional validation iterator.
    eval_metric : only forwarded to ``batch_end_callback`` inside the
        ``BatchEndParam``; it is never created, reset or updated here.
    validation_metric : defaults to the lmnn loss metric.
    num_epoch : int, required.
    summary_writer : optional object with ``add_scalar(tag, value, step)``.
        Stored on ``self.writer``; scalar logging is skipped when ``None``.

    The remaining arguments are forwarded unchanged to ``bind``,
    ``init_params``, ``init_optimizer``, ``score`` and the callbacks.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.num_batch = 0
    self.writer = summary_writer

    self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                     allow_missing=allow_missing, force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    acc_metric = IgnoreAccuracy(output_names=['softmax_output'], label_names=['softmax_label'])
    # Alternative kept for reference:
    # acc_metric = metric.Accuracy(output_names=['softmax_output'], label_names=['softmax_label'])
    lmnn_metric = metric.Loss(output_names=['lmnn_output'], label_names=['softmax_label'])

    if validation_metric is None:
        validation_metric = lmnn_metric

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        acc_metric.reset()
        lmnn_metric.reset()
        for nbatch, data_batch in enumerate(train_data):
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            self.update_metric(acc_metric, data_batch.label)
            self.update_metric(lmnn_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                # NOTE(review): eval_metric is passed through as received
                # (the default is the string 'acc'); callbacks that read the
                # metric will not see acc_metric / lmnn_metric - confirm.
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)

        # one epoch of training is finished
        for name, val in acc_metric.get_name_value():
            self.logger.info('Epoch[%d] Accuracy Train-%s=%f', epoch, name, val)
        for name, val in lmnn_metric.get_name_value():
            self.logger.info('Epoch[%d] Lmnn Train-%s=%f', epoch, name, val)

        # Scalars are written only when a writer was supplied. The default
        # summary_writer=None would otherwise fail on the very first epoch.
        if self.writer is not None and self.num_batch % 10 == 0:
            self.writer.add_scalar('{}/cls_acc'.format('Train'),
                                   acc_metric.sum_metric / acc_metric.num_inst,
                                   self.num_batch)
            self.writer.add_scalar('{}/lmnn_loss'.format('Train'),
                                   lmnn_metric.sum_metric / lmnn_metric.num_inst,
                                   self.num_batch)
        # Despite its name this counter advances once per epoch.
        self.num_batch += 1

        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback, epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(self, train_data, eval_data=None, eval_metric='acc', validate_metric=None,
        work_load_list=None, epoch_end_callback=None, batch_end_callback=None,
        fixed_param_prefix=None, initializer=None, arg_params=None,
        aux_params=None, allow_missing=False, optimizer=None,
        optimizer_params=None, begin_epoch=0, num_epoch=None,
        kvstore='device', teacher_modules=None):
    """Train ``self.module``, optionally appending teacher outputs to the labels.

    For every batch, each teacher module is run forward and its outputs are
    appended to ``data_batch.label`` before the student module's
    forward / backward / update step.

    Parameters
    ----------
    train_data : iterator of training batches providing ``reset()``.
    eval_data : optional validation iterator.
    eval_metric : str or EvalMetric; created via ``metric.create`` if needed.
    validate_metric : defaults to the ``eval_metric`` argument as passed in.
    num_epoch : int, required.
    teacher_modules : a module, a list/tuple of modules, or ``None``. An
        empty collection, or one whose first entry is ``None``, disables the
        teacher step.
    work_load_list, fixed_param_prefix : accepted but not used here.

    The remaining arguments are forwarded unchanged to ``self.module``'s
    ``init_params`` / ``init_optimizer`` and to the callbacks.
    """
    # Same precondition as the other fit() variants; without it the failure
    # only surfaces later as a TypeError from range().
    assert num_epoch is not None, 'please specify number of epochs'

    # isinstance keeps tuples and list subclasses intact; an exact type test
    # would wrap them into a one-element list and break the teacher loop.
    if not isinstance(teacher_modules, (list, tuple)):
        teacher_modules = [teacher_modules]
    # An empty collection and a leading None both mean "no teachers".
    use_teachers = len(teacher_modules) > 0 and teacher_modules[0] is not None

    self.module.bind(data_shapes=self.data_shapes, label_shapes=self.label_shapes,
                     for_training=True)
    self.module.init_params(initializer=initializer, arg_params=arg_params,
                            aux_params=aux_params, allow_missing=allow_missing)
    self.module.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                               optimizer_params=optimizer_params)

    if validate_metric is None:
        validate_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    # training loop
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch

            if use_teachers:
                for teacher_module in teacher_modules:
                    teacher_module.forward(data_batch=data_batch, is_train=True)
                    transfer_label = teacher_module.get_outputs()
                    # Each teacher's outputs are appended after the labels
                    # already present on the batch.
                    data_batch.label = data_batch.label + transfer_label

            self.module.forward(data_batch, is_train=True)
            self.module.backward()
            self.module.update()

            try:
                # Fetch the next batch now so the end of data is known
                # before the callbacks run.
                next_data_batch = next(data_iter)
            except StopIteration:
                end_of_batch = True

            self.module.update_metric(eval_metric, data_batch.label)

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # Read the parameters back and set them again on the module.
        arg_params, aux_params = self.module.get_params()
        self.module.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        if eval_data:
            res = self.module.score(eval_data, validate_metric,
                                    score_end_callback=None,
                                    batch_end_callback=None,
                                    reset=True, epoch=epoch)
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

        train_data.reset()
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None, eval_batch_end_callback=None,
        initializer=Uniform(0.01), arg_params=None, aux_params=None,
        allow_missing=False, force_rebind=False, force_init=False,
        begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None):
    """Train the module parameters.

    Binds the module to the shapes provided by ``train_data``, initialises
    parameters and optimizer, then for each epoch in
    ``range(begin_epoch, num_epoch)`` runs forward/backward/update on every
    batch (pre-fetching the next batch before the metric update), logs the
    training metric and time cost, fires the epoch-end callbacks and, when
    ``eval_data`` is given, scores the validation set.

    Parameters
    ----------
    train_data : iterator providing ``provide_data``, ``provide_label`` and
        ``reset()``.
    eval_data : optional validation iterator.
    eval_metric : str or EvalMetric; created via ``metric.create`` if needed.
    validation_metric : defaults to the ``eval_metric`` argument as passed in.
    num_epoch : int, required.

    The remaining arguments are forwarded unchanged to ``bind``,
    ``init_params``, ``init_optimizer``, ``score`` and the callbacks.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    ################################################################################
    # training loop
    ################################################################################
    # NOTE(review): only referenced by the disabled gradient-debug code below.
    last_grad_debug = None
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)

            # Disabled gradient min/max debug dump, kept for reference:
            # grad_array = [[grad.copyto(grad.context) if grad is not None else None for grad in grads] for grads in
            #               self._curr_module._exec_group.grad_arrays]
            #
            # for exec_ in self._curr_module._exec_group.execs:
            #     grad_dict = exec_.grad_dict
            #
            # grad_debug = dict()
            # for k, v in grad_dict.items():
            #     if v is not None:
            #         v_np = v.asnumpy()
            #         grad_debug[k] = (np.min(v_np), np.max(v_np))
            # print 'rpn_conv_cls_weight:', grad_debug['rpn_conv_cls_weight']
            # print 'rcnn_fc_cls_weight:', grad_debug['rcnn_fc_cls_weight']

            self.update()
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                 val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
# train logging.info("Training started ... ") for epoch in range(args.epochs): # train total_loss = 0.0 nbatch = 0 for batch in train_data: module.forward(batch) module.backward() module.update(max_norm=args.clip * bptt * batch_size) # update metric outputs = module.get_loss() total_loss += mx.nd.sum(outputs[0]).asscalar() speedometer_param = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=None, locals=locals()) speedometer(speedometer_param) if nbatch % args.log_interval == 0 and nbatch > 0: cur_loss = total_loss / bptt / batch_size / args.log_interval logging.info('Iter[%d] Batch [%d]\tLoss: %.7f,\tPerplexity:\t%.7f' % \ (epoch, nbatch, cur_loss, math.exp(cur_loss))) total_loss = 0.0 nbatch += 1 # validation valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt, batch_size) if valid_loss < best_loss: best_loss = valid_loss # test test_loss = evaluate(module, test_data, epoch, 'Test', bptt,
def fit(self,
        train_data,
        ogdb,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        iter_size=1,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None):
    """Ke's revision: add iter_size. Trains the module parameters.

    On top of the stock training loop this variant (a) binds with
    ``grad_req='add'`` and only calls ``update()`` (then zeroes the gradient
    arrays) on batches where ``nbatch % iter_size == 0``, and (b) re-trains
    on a resampled roidb ("annealing") while the value returned by the
    epoch-end callbacks stays below an internal target.

    Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_
    to see a end-to-end use-case.

    Parameters
    ----------
    train_data : DataIter
        Train DataIter. It is rebuilt as an ``AnchorLoader`` from its own
        attributes (``feat_sym``, ``roidb``, ``batch_size``, ...) whenever an
        annealing step is triggered.
    ogdb : roidb
        Database handed to ``sample_roidb`` to draw new samples, and used
        as-is as the training roidb once the validation values stagnate
        (presumably the original, full roidb -- confirm with the caller).
    eval_data : DataIter
        If not ``None``, will be used as validation set and the performance
        after each epoch will be evaluated.
    eval_metric : str or EvalMetric
        Defaults to 'accuracy'. The performance measure used to display
        during training. Other possible predefined metrics are:
        'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
    epoch_end_callback : function or list of functions
        Each callback will be called with the current `epoch`, `symbol`,
        `arg_params` and `aux_params`. A non-``None`` return value is taken
        as the current validation value: it is compared with the target and
        recorded for the stagnation check.
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Defaults to 'local'.
    optimizer : str or Optimizer
        Defaults to 'sgd'.
    optimizer_params : dict
        Defaults to ``(('learning_rate', 0.01),)``. The parameters for the
        optimizer constructor. The default value is not a dict, just to
        avoid pylint warning on dangerous default values.
    eval_end_callback : function or list of function
        These will be called at the end of each full evaluation, with the
        metrics over the entire evaluation set.
    iter_size : int
        Defaults to 1. The optimizer step is taken only on batches whose
        index is a multiple of this value; gradients accumulate in between.
    eval_batch_end_callback : function or list of function
        These will be called at the end of each mini-batch during evaluation.
    initializer : Initializer
        The initializer is called to initialize the module parameters when
        they are not already initialized.
    arg_params : dict
        Defaults to ``None``, if not ``None``, should be existing parameters
        from a trained model or loaded from a checkpoint (previously saved
        model). In this case, the value here will be used to initialize the
        module parameters, unless they are already initialized by the user
        via a call to `init_params` or `fit`. `arg_params` has a higher
        priority than `initializer`.
    aux_params : dict
        Defaults to ``None``. Similar to `arg_params`, except for auxiliary
        states.
    allow_missing : bool
        Defaults to ``False``. Indicates whether to allow missing parameters
        when `arg_params` and `aux_params` are not ``None``. If this is
        ``True``, then the missing parameters will be initialized via the
        `initializer`.
    force_rebind : bool
        Defaults to ``False``. Whether to force rebinding the executors if
        already bound.
    force_init : bool
        Defaults to ``False``. Indicates whether to force initialization
        even if the parameters are already initialized.
    begin_epoch : int
        Defaults to 0. Indicates the starting epoch. Usually, if resumed
        from a checkpoint saved at a previous training phase at epoch N,
        then this value should be N+1.
    num_epoch : int
        Number of epochs for training.

    Examples
    --------
    >>> # An example of using fit for training.
    >>> # Assume training dataIter and validation dataIter are ready
    >>> # Assume loading a previously checkpointed model
    >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
    >>> mod.fit(train_data=train_dataiter, ogdb=roidb, eval_data=val_dataiter,
    ...         optimizer='sgd',
    ...         optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
    ...         arg_params=arg_params, aux_params=aux_params,
    ...         eval_metric='acc', num_epoch=10, begin_epoch=3)
    """
    assert num_epoch is not None, 'please specify number of epochs'

    # grad_req='add' lets gradients accumulate across batches (see iter_size).
    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind,
              grad_req='add')
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    annealing_steps = 0  # number of current annealing steps in current epoch
    redo_training = 0  # Flag to redo training / resample
    val_list = []  # list of validation results per annealing step
    cur_val = 0
    # Note: we want to identify the best cluster of images / training sets
    # with a low percentage. The target is raised by 5 every time an epoch
    # starts without a pending redo (including the very first epoch).
    target_prec = 50

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        if redo_training:
            annealing_steps = annealing_steps + 1
            self.logger.info('Redoing training to meet criteria = %d',
                             annealing_steps)
            atick = time.time()

            # Check if we've stagnated: compare the newest validation value
            # against the mean of the last three.
            iterdiff = 1.0
            if len(val_list) > 2:
                itermean = (val_list[-1] + val_list[-2] + val_list[-3]) / 3
                iterdiff = abs(itermean - val_list[-1])
                self.logger.info('Last 3 samples have diff of: %f', iterdiff)

            if iterdiff < 0.01:
                self.logger.info(
                    'Reached a stagnated annealing criteria, dumping current samples'
                )
                # Do something drastic: instantly use the original db.
                sroidb = ogdb
            else:
                # Continue as usual: select a new random subset (without
                # removal, this is 10%) and append it to the current roidb.
                newroidb = sample_roidb(ogdb, 15)
                sroidb = append_roidb(train_data.roidb, newroidb)

            # Create new training data instance by passing most of previous
            # arguments and the new db, then overwrite the old train_data.
            train_data2 = AnchorLoader(
                train_data.feat_sym,
                sroidb,
                train_data.batch_size,
                train_data.shuffle,
                train_data.ctx,
                train_data.work_load_list,
                train_data.feat_stride,
                train_data.anchor_scales,
                train_data.anchor_ratios,
                train_data.aspect_grouping,
                nThreads=default.prefetch_thread_num)
            train_data = train_data2
            data_iter = iter(train_data)
            atock = time.time()
            self.logger.info('Annealing[%d] Time cost=%.3f', annealing_steps,
                             (atock - atick))
        else:
            data_iter = iter(train_data)
            annealing_steps = 0
            val_list = []
            target_prec = target_prec + 5

        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward(data_batch, is_train=True, grad_req='add')
            self.backward()
            if nbatch % iter_size == 0:  # update every iter_size batches
                self.update()
                # Gradients are accumulated ('add'), so clear them manually
                # after each optimizer step.
                for g in self._curr_module._exec_group.grad_arrays:
                    for g1 in g:
                        if g1 is not None:
                            g1[:] = 0.

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                returned = callback(epoch, self.symbol, arg_params, aux_params)
                # Record the value the callback returned, not the leftover
                # training-metric `val` from the loop above. Callbacks that
                # return nothing (e.g. checkpointing) keep the previous
                # cur_val and leave val_list untouched.
                if returned is not None:
                    cur_val = returned
                    self.logger.info('Returned Validation=%f', cur_val)
                    val_list.append(cur_val)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            self.logger.info('Evaluating data')
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                 val)

        #----------
        # Check epoch if it falls within the validation threshold
        if cur_val < target_prec:
            print(eval_data)
            redo_training = 1
            self.logger.info('Retraining data=%f', val)
        else:
            redo_training = 0
            self.logger.info('Annealing steps=%f', annealing_steps)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(self,
        train_data,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        sparse_row_id_fn=None,
        accuracy_target=1.0,
        eval_frequency=1,
        eval_offset=0,
        logger=None):
    """Train the module with `mll` run/block/epoch/eval event logging.

    Binds and initializes the module, then runs epochs from `begin_epoch`
    to `num_epoch`. Validation runs on epochs where
    ``epoch % eval_frequency == eval_offset`` and on the last epoch; the
    function returns as soon as the measured accuracy exceeds
    `accuracy_target` (after logging ``run_stop(status='success')``).
    Otherwise it finishes by logging ``run_stop(status='aborted')``.

    Environment variables read here:

    * ``RESNET50_PROFILING`` -- any non-empty value turns the MXNet profiler
      on (only where the local `rank` variable is 0).
    * ``RESNET50_STOP_ITERATION`` -- if > 0, training stops once that many
      batches of an epoch have been processed.

    Parameters beyond the stock ``fit`` signature
    ---------------------------------------------
    sparse_row_id_fn : callable
        Forwarded to ``self.prepare`` for each pre-fetched batch.
    accuracy_target : float
        Accuracy above which training stops successfully.
    eval_frequency, eval_offset : int
        Select the epochs on which validation is run (see above).
    logger :
        Accepted but not referenced in this body (``self.logger`` is used).

    Notes
    -----
    `epoch_end_callback` is accepted but never invoked by this body.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    # `kvstore` is tested with `in`, i.e. it is used as a string/container
    # here. `local_rank` is assigned but not used below.
    if 'horovod' in kvstore:
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = 0

    # os.getenv returns a string, so any non-empty value enables profiling.
    profiler_on = os.getenv('RESNET50_PROFILING', False) and (rank == 0)
    if profiler_on:
        self.logger.info("Profiling is enabled")

    stop_iter = int(os.getenv('RESNET50_STOP_ITERATION', '0'))
    if stop_iter > 0:
        self.logger.info(
            "Training will stop at iteration {} of the first epoch".format(
                stop_iter))

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # A "block" is the run of epochs between two evaluations; compute how
    # many epochs the first block spans given where we start.
    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_frequency)
    if block_epoch_count < 0:
        block_epoch_count += eval_frequency
    mll.block_start(block_epoch_start + 1, count=block_epoch_count)

    if profiler_on:
        mx.profiler.set_config(profile_symbolic=True,
                               profile_imperative=True,
                               profile_memory=False,
                               profile_api=True,
                               filename='resnet50_profile.json',
                               aggregate_stats=True)
        mx.profiler.set_state('run')

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        # mll events use 1-based epoch numbers.
        mll.epoch_start(epoch + 1)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        early_stop = False
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            # Snapshot the metric before the batch-end callbacks run, so the
            # epoch summary below does not depend on what they do to it.
            if end_of_batch:
                eval_name_vals = eval_metric.get_global_name_value()

            if batch_end_callback is not None:
                # locals() exposes every local variable of this frame to the
                # callbacks; local names are therefore part of the contract.
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

            if stop_iter > 0 and nbatch >= stop_iter:
                early_stop = True
                self.logger.info(
                    "Training stopped at {} iteration. Clear RESNET50_STOP_ITERATION if it's not itended."
                    .format(stop_iter))
                break

        # Leave the epoch loop too; skips epoch_stop, logging and evaluation.
        if early_stop:
            break

        mll.epoch_stop(epoch + 1)
        # one epoch of training is finished
        # NOTE(review): nesting of the timing log under `rank == 0` was
        # reconstructed from whitespace-collapsed source -- confirm.
        if rank == 0:
            for name, val in eval_name_vals:
                self.logger.info('Rank[%d] Epoch[%d] Train-%s=%f', rank,
                                 epoch, name, val)
            toc = time.time()
            self.logger.info('Rank[%d] Epoch[%d] Time cost=%.3f', rank, epoch,
                             (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data is not None and ((epoch % eval_frequency == eval_offset)
                                      or (epoch + 1 == num_epoch)):
            mll.eval_start(epoch + 1, sync=True)
            res = self.score(
                eval_data,
                [validation_metric,
                 CorrectCount(), TotalCount()],
                score_end_callback=eval_end_callback,
                batch_end_callback=eval_batch_end_callback,
                epoch=epoch)
            #TODO: pull this into default
            if rank == 0:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)

            # temporarily add these two metrics for debugging, can be removed before submission
            res = dict(res)
            correct_count = res['correct-count']
            total_count = res['total-count']
            # Under horovod the counts are combined across workers first.
            if 'horovod' in kvstore:
                correct_count = allreduce(correct_count)
                total_count = allreduce(total_count)

            acc = correct_count / total_count
            mll.eval_stop(epoch + 1)
            mll.eval_accuracy(epoch + 1, acc)

            mll.block_stop(block_epoch_start + 1)
            if acc > accuracy_target:
                mll.run_stop(status='success')
                return

            # Open the next block unless this was the final epoch; its length
            # is capped at eval_frequency.
            if epoch < num_epoch - 1:
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_frequency:
                    block_epoch_count = eval_frequency
                mll.block_start(block_epoch_start + 1,
                                count=block_epoch_count)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    if profiler_on:
        mx.profiler.set_state('stop')
        print(mx.profiler.dumps())

    # Reached only when the accuracy target was never exceeded (or training
    # was stopped early).
    mll.run_stop(status='aborted')
def __Train(self):
    """Run the hand-rolled training loop on ``self.Model``.

    Trains for epochs 0..999 over ``self.Data_Train_Iter``, reporting
    throughput through a ``Speedometer`` every 10 batches and saving a
    checkpoint under ``Epoch_Saved/ocr_checkpoint`` after every epoch. When
    ``self.Data_Val_Iter`` is truthy, the model is scored on it at the end
    of each epoch using the same metric as training.

    Reads ``self.Model``, ``self.Model_Batch_Size``, ``self.Data_Train_Iter``
    and ``self.Data_Val_Iter``. Returns ``None``.
    """
    begin_epoch = 0
    num_epoch = 1000
    eval_metric = mx.metric.np(Accuracy, allow_extra_outputs=True)
    validation_metric = None
    # No trailing comma here: a 1-tuple would be wrapped by `_as_list` as
    # `[(speedometer,)]` (it only passes real lists through) and the loop
    # below would then try to call the tuple.
    batch_end_callback = mx.callback.Speedometer(self.Model_Batch_Size,
                                                 frequent=10)
    epoch_end_callback = mx.callback.do_checkpoint(
        "Epoch_Saved/ocr_checkpoint", period=1)
    eval_end_callback = None
    eval_batch_end_callback = None

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    ################################################################################
    # Training Loop                                                                #
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(self.Data_Train_Iter)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            self.Model.forward_backward(data_batch)
            self.Model.update()
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.Model.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.Model.update_metric(eval_metric, data_batch.label)

            if batch_end_callback is not None:
                # locals() hands every local of this frame to the callbacks.
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in mx.module.base_module._as_list(
                        batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.Model.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.Model.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                               (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.Model.get_params()
        self.Model.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in mx.module.base_module._as_list(
                    epoch_end_callback):
                callback(epoch, self.Model.symbol, arg_params, aux_params)

        # ----------------------------------------
        # evaluation on validation set
        if self.Data_Val_Iter:
            res = self.Model.score(
                self.Data_Val_Iter,
                validation_metric,
                score_end_callback=eval_end_callback,
                batch_end_callback=eval_batch_end_callback,
                epoch=epoch)
            # TODO: pull this into default
            for name, val in res:
                self.Model.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                       name, val)

        # end of 1 epoch, reset the data-iter for another epoch
        self.Data_Train_Iter.reset()
logging.info(msg, param.epoch, count, speed, *sum(name_value, ())) else: logging.info( "Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", param.epoch, count, speed) self.tic = time.time() else: # print(f'[3] init: {self.init} {self.last_count} {count} {self.frequent}') self.init = True self.tic = time.time() if __name__ == '__main__': import mxnet as mx from mxnet.model import BatchEndParam acc = mx.metric.Accuracy() metric = mx.metric.MSE() params = BatchEndParam(epoch=0, nbatch=1, eval_metric=acc, locals=locals()) batch_size = 2 acc.update(labels=mx.nd.uniform(shape=(batch_size, 512)), preds=mx.nd.uniform(shape=(batch_size, 512))) cb = Speedometer(batch_size=batch_size, frequent=1) cb(params) cb(params) cb(params)
def fit(self,
        train_data,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        sparse_row_id_fn=None,
        profile=False):
    """Trains the module parameters.

    Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_
    to see a end-to-end use-case.

    Note: this variant does not run a validation pass. `eval_data`,
    `validation_metric`, `eval_end_callback` and `eval_batch_end_callback`
    are accepted for signature compatibility but are not used by the loop.

    Parameters
    ----------
    train_data : DataIter
        Train DataIter.
    eval_data : DataIter
        Accepted for compatibility; not evaluated by this variant.
    eval_metric : str or EvalMetric
        Defaults to 'accuracy'. The performance measure used to display
        during training. Other possible predefined metrics are:
        'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
    epoch_end_callback : function or list of functions
        Each callback will be called with the current `epoch`, `symbol`,
        `arg_params` and `aux_params`. Callbacks run only on the worker
        whose kvstore rank is 0, or on every call when the module has no
        kvstore.
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Defaults to 'local'.
    optimizer : str or Optimizer
        Defaults to 'sgd'.
    optimizer_params : dict
        Defaults to ``(('learning_rate', 0.01),)``. The parameters for the
        optimizer constructor. The default value is not a dict, just to
        avoid pylint warning on dangerous default values.
    eval_end_callback : function or list of function
        Accepted for compatibility; unused here.
    eval_batch_end_callback : function or list of function
        Accepted for compatibility; unused here.
    initializer : Initializer
        The initializer is called to initialize the module parameters when
        they are not already initialized.
    arg_params : dict
        Defaults to ``None``, if not ``None``, should be existing parameters
        from a trained model or loaded from a checkpoint (previously saved
        model). In this case, the value here will be used to initialize the
        module parameters, unless they are already initialized by the user
        via a call to `init_params` or `fit`. `arg_params` has a higher
        priority than `initializer`.
    aux_params : dict
        Defaults to ``None``. Similar to `arg_params`, except for auxiliary
        states.
    allow_missing : bool
        Defaults to ``False``. Indicates whether to allow missing parameters
        when `arg_params` and `aux_params` are not ``None``. If this is
        ``True``, then the missing parameters will be initialized via the
        `initializer`.
    force_rebind : bool
        Defaults to ``False``. Whether to force rebinding the executors if
        already bound.
    force_init : bool
        Defaults to ``False``. Indicates whether to force initialization
        even if the parameters are already initialized.
    begin_epoch : int
        Defaults to 0. Indicates the starting epoch. Usually, if resumed
        from a checkpoint saved at a previous training phase at epoch N,
        then this value should be N+1.
    num_epoch : int
        Number of epochs for training.
    sparse_row_id_fn : A callback function
        The function takes `data_batch` as an input and returns a dict of
        str -> NDArray. The resulting dict is used for pulling row_sparse
        parameters from the kvstore, where the str key is the name of the
        param, and the value is the row id of the param to pull.
    profile : bool
        Defaults to ``False``. When exactly ``True``, ``mx.profiler.dump()``
        is called once the 10th batch of an epoch has been processed. The
        profiler itself is not started here.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            # Snapshot the metric before the batch-end callbacks run, so the
            # epoch summary does not depend on what they do to it.
            if end_of_batch:
                eval_name_vals = eval_metric.get_name_value()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

            if profile is True and nbatch == 10:
                self.logger.info("Profiling ends")
                import mxnet as mx
                mx.profiler.dump()

        # one epoch of training is finished
        for name, val in eval_name_vals:
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            # The module has no kvstore object for a single-device,
            # non-distributed setup (and some module types have no
            # `_kvstore` attribute at all); treat that as "this is the only
            # worker" instead of failing on `None.rank`.
            kv = getattr(self, '_kvstore', None)
            if kv is None or kv.rank == 0:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def mlperf_fit(self,
               train_data,
               eval_data=None,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):
    """Train the module with MLPerf logging and accuracy-based early exit.

    Same flow as the stock ``fit`` (bind, init params, init optimizer, epoch
    loop), with these differences visible in the body:

    * `validation_metric` is wrapped into a ``CompositeEvalMetric`` to which
      ``CorrectCount`` and ``TotalCount`` are appended.
    * Validation runs only on epochs where
      ``epoch % eval_period == eval_offset`` (and `eval_data` is truthy).
      Accuracy is ``correct-count / total-count`` after ``all_reduce``.
    * Training metrics are not logged at epoch end, only the time cost.
    * Per-epoch log lines are emitted only on the worker with rank 0.

    Parameters
    ----------
    kvstore : str, KVStore or None
        Passed to ``init_optimizer``. If it exposes ``rank``, only rank 0
        logs; strings and ``None`` are treated as a single worker.
    sparse_row_id_fn : callable
        Forwarded to ``self.prepare`` for each pre-fetched batch.
    eval_offset, eval_period : int
        Select the epochs on which validation is run (see above).
    accuracy_threshold : float
        Training stops as soon as the measured accuracy exceeds this value.

    The remaining parameters have the same meaning as in ``fit``.

    Returns
    -------
    int
        The epoch index at which accuracy first exceeded
        `accuracy_threshold`, or `num_epoch` if it never did.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # Log only on rank 0. `kvstore` may be a KVStore object, None, or a type
    # string such as the default 'local'; a string has no `.rank`, so fall
    # back to rank 0 rather than raising AttributeError.
    log_on_this_worker = (not kvstore) or getattr(kvstore, 'rank', 0) == 0

    mx_resnet_print(key=mlperf_log.TRAIN_LOOP)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_log.TRAIN_EPOCH, val=epoch)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        toc = time.time()
        if log_on_this_worker:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch % eval_period == eval_offset:
            mx_resnet_print(key=mlperf_log.EVAL_START)
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if log_on_this_worker:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)

            # Combine the raw counts across workers, then derive accuracy.
            res = dict(res)
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_log.EVAL_ACCURACY,
                            val={
                                "epoch": epoch,
                                "value": acc
                            })
            mx_resnet_print(key=mlperf_log.EVAL_STOP)
            if acc > accuracy_threshold:
                return epoch

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    return num_epoch
def fit(self,
        train_data,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None):
    """Train the module while reshaping network bandwidth around each step.

    Follows the stock ``fit`` flow (bind, init params, init optimizer, epoch
    loop with optional validation) but splits ``forward_backward`` into
    ``forward`` and ``backward`` so that, when the ``TASK_LIMIT`` environment
    variable equals 1, Linux ``tc`` HTB class rates on device ``ens3`` can be
    changed twice per batch: once after the forward pass and once
    ``DELAY_TIME`` seconds after ``update()`` has been issued.

    Environment variables read here:

    * ``TASK_LIMIT`` (default 0) -- 0: no per-task limit; 1: per-task limit,
      re-applied every batch; 2: per-task limit, fixed (only the initial
      ``tc.sh -l 1`` setup is run).
    * ``DELAY_TIME`` (default 0.8) -- seconds to sleep before applying the
      second bandwidth setting.
    * ``PS_UPLOAD_BANDWIDTH1`` / ``WORKER_UPLOAD_BANDWIDTH1`` and
      ``PS_UPLOAD_BANDWIDTH2`` / ``WORKER_UPLOAD_BANDWIDTH2`` (default 2000
      each) -- mbit rates for classes 1:3 / 1:4 in the first and second
      phase respectively.

    Side effects: runs ``/home/ubuntu/tc.sh`` and ``sudo tc`` through
    ``os.system``. Parameters have the same meaning as in the stock ``fit``.
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    #### chris_arg: traffic-control setup
    # TASK_LIMIT: 0 = no per-task limit; 1 = per-task limit, updated every
    # round; 2 = per-task limit, but fixed.
    if int(os.getenv("TASK_LIMIT", 0)) != 0:
        get_task_cmd = "sh /home/ubuntu/tc.sh -l 1"
    else:
        self.logger.info("no_task_bandwidth_limit")
        get_task_cmd = "sh /home/ubuntu/tc.sh -l 0"
    os.system(get_task_cmd)
    delay_time = float(os.getenv("DELAY_TIME", 0.8))
    ps_upload_bandwidth_part1 = int(os.getenv("PS_UPLOAD_BANDWIDTH1", 2000))
    worker_upload_bandwidth_part1 = int(
        os.getenv("WORKER_UPLOAD_BANDWIDTH1", 2000))
    ps_upload_bandwidth_part2 = int(os.getenv("PS_UPLOAD_BANDWIDTH2", 2000))
    worker_upload_bandwidth_part2 = int(
        os.getenv("WORKER_UPLOAD_BANDWIDTH2", 2000))
    # Placeholders: device, rate, ceil for class 1:3, then the same for 1:4.
    tc_command = "sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit"

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward(data_batch, is_train=True)
            if int(os.getenv("TASK_LIMIT", 0)) == 1:
                ## first part bandwidth allocation
                # Block until queued work (the forward pass) has finished so
                # the rate change lines up with the start of backward.
                ndarray.waitall()
                x = str(ps_upload_bandwidth_part1)
                y = str(worker_upload_bandwidth_part1)
                cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                # Built but not executed (the ifb0 call is disabled below).
                cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                os.system(cmd_up)
                # os.system(cmd_down)
            self.backward()
            self.update()  # executed asynchronously
            if int(os.getenv("TASK_LIMIT", 0)) == 1:
                x = str(ps_upload_bandwidth_part2)
                y = str(worker_upload_bandwidth_part2)
                cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                # Built but not executed (the ifb0 call is disabled below).
                cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                # Wait before switching to the second-phase rates.
                time.sleep(delay_time)
                ## second part bandwidth allocation
                os.system(cmd_up)
                # os.system(cmd_down)
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                # locals() exposes every local variable of this frame to the
                # callbacks; local names are therefore part of the contract.
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                 val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def mlperf_fit(self, args, data_loader, epoch_size, eval_metric='acc',
               epoch_end_callback=None, batch_end_callback=None,
               kvstore='local', optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               explorer='linear', explorer_params=None,
               eval_end_callback=None, eval_batch_end_callback=None,
               initializer=Uniform(0.01), arg_params=None, aux_params=None,
               allow_missing=False, force_rebind=False, force_init=False,
               begin_epoch=0, num_epoch=None, validation_metric=None,
               monitor=None, sparse_row_id_fn=None, eval_offset=0,
               eval_period=1, accuracy_threshold=1.0):
    """Train the module with MLPerf logging and explorer-driven schedules.

    Parameters
    ----------
    args : object
        Run configuration. This method reads ``args.kv_store`` and
        ``args.use_dali`` and forwards ``args`` to ``data_loader``.
    data_loader : callable
        Called as ``data_loader(args, kvstore)``; must return the tuple
        ``(train_data, eval_data, cval_data)``.
    epoch_size : int
        Number of batches per epoch used to resize the training iterator
        for synchronous distributed runs.
    eval_metric : str or EvalMetric
        Training metric; created through ``mx.metric.create`` when it is not
        already an ``EvalMetric``.
    epoch_end_callback : function or list of functions
        Called as ``callback(epoch, self.symbol, arg_params, aux_params)``
        with the parameters fetched from the module at the end of the epoch.
    batch_end_callback : function or list of functions
        Called with a ``BatchEndParam`` after every training batch.
    kvstore : str or KVStore
        Forwarded to ``init_optimizer`` and ``data_loader``. When it exposes
        a ``rank`` attribute, only rank 0 writes the per-epoch log lines.
    optimizer, optimizer_params
        Forwarded to ``init_optimizer``.
    explorer : str
        Explorer name passed to ``Explorer.create_explorer``.
    explorer_params : dict
        Required. Keys read here: ``'explore_freq'``,
        ``'explore_start_epoch'``, ``'add_one_fwd_epoch'``,
        ``'no_augument_epoch'`` and ``'smooth_decay'``.
    eval_end_callback, eval_batch_end_callback
        Forwarded to ``self.score``.
    initializer, allow_missing, force_rebind, force_init
        Accepted for signature compatibility; not used by this method.
    arg_params, aux_params
        Accepted for signature compatibility; the values handed to
        ``epoch_end_callback`` come from ``self.get_params()``.
    begin_epoch : int
        First epoch index of the loop.
    num_epoch : int
        Exclusive upper bound of the epoch loop. Required.
    validation_metric : str, EvalMetric or list
        Metric(s) for evaluation; defaults to ``eval_metric``. Correct-count
        and total-count metrics are always appended.
    monitor : Monitor
        If given, installed on the module and ticked around every batch.
    sparse_row_id_fn
        Forwarded to ``self.prepare`` for the pre-fetched batch.
    eval_offset, eval_period : int
        Evaluation runs on epochs where ``epoch >= eval_offset`` and
        ``(epoch - eval_offset) % eval_period == 0``.
    accuracy_threshold : float
        Training stops successfully once the all-reduced accuracy exceeds
        this value.

    Returns
    -------
    int
        The epoch index at which ``accuracy_threshold`` was exceeded, or
        ``num_epoch`` when the loop ran to completion without reaching it.

    Raises
    ------
    TypeError
        If ``explorer_params`` is ``None``.
    """
    assert num_epoch is not None, 'please specify number of epochs'
    if explorer_params is None:
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not subscriptable" further down.
        raise TypeError('explorer_params must be a dict, got None')

    if monitor is not None:
        self.install_monitor(monitor)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    explorer = Explorer.create_explorer(name=explorer,
                                        optimizer=self._optimizer,
                                        explorer_params=explorer_params)
    # This mxnet can not use several optimizers without sgd series
    explorer.set_best_coeff(0)
    explorer.set_best_wd_coeff(0)
    explorer.set_best_cg(0)

    exp_freq = explorer_params['explore_freq']
    exp_start_epoch = explorer_params['explore_start_epoch']

    if validation_metric is None:
        validation_metric = eval_metric

    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]
    validation_metric = mx.metric.create(validation_metric)
    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm
    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    # Outside an Open MPI launch the variables are missing (KeyError); a
    # malformed value gives ValueError. Both fall back to a single process.
    try:
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except (KeyError, ValueError):
        world_rank = 0
        world_size = 1

    # Not read again below, but kept because every local is exposed to the
    # batch-end callbacks through ``locals()``.
    use_cval_data = explorer_params['add_one_fwd_epoch'] < num_epoch \
        or explorer_params['no_augument_epoch'] < num_epoch
    best_rank = 0

    # A plain-string kvstore (e.g. the default 'local') has no ``rank``;
    # treat it like rank 0 so logging works instead of raising.
    log_on_this_worker = (not kvstore) or getattr(kvstore, 'rank', 0) == 0

    self.prepare_states()

    mx_resnet_print(key=mlperf_constants.INIT_STOP, sync=True)
    mx_resnet_print(key=mlperf_constants.RUN_START, sync=True)

    # data iterators
    (train_data, eval_data, cval_data) = data_loader(args, kvstore)
    if 'dist' in args.kv_store and 'async' not in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches
        # per epoch; if not, dist_sync can hang at the end with one machine
        # waiting for other machines
        if not args.use_dali:
            # The result must replace ``train_data``; binding it to another
            # name would leave the training iterator unresized.
            train_data = mx.io.ResizeIter(train_data, epoch_size)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_period)
    if block_epoch_count < 0:
        block_epoch_count += eval_period
    mx_resnet_print(key=mlperf_constants.BLOCK_START,
                    metadata={
                        'first_epoch_num': block_epoch_start + 1,
                        'epoch_count': block_epoch_count
                    })

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_constants.EPOCH_START,
                        metadata={'epoch_num': epoch + 1})
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        # From 'no_augument_epoch' onwards the batches come from cval_data.
        use_normal_data_batch = epoch < explorer_params['no_augument_epoch']
        if not use_normal_data_batch:
            if world_rank == 0:
                self.logger.info('use non-augumented batch')

        end_of_batch = False
        if use_normal_data_batch:
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
        else:
            cval_iter = iter(cval_data)
            next_cval_batch = next(cval_iter)

        smooth_decay = explorer_params['smooth_decay']
        # NOTE(review): the nesting of the explorer calls here and inside the
        # batch loop was reconstructed from whitespace-collapsed source;
        # confirm which calls belong under each ``if``.
        if not smooth_decay:
            explorer.apply_lr_decay_epoch(epoch)
            explorer.apply_wd_decay_epoch(epoch)
        explorer.set_mom(epoch)

        while not end_of_batch:
            if use_normal_data_batch:
                data_batch = next_data_batch
            else:
                cval_batch = next_cval_batch
            if monitor is not None:
                monitor.tic()

            if use_normal_data_batch:
                self.forward_backward(data_batch)
            else:
                self.forward_backward(cval_batch)

            if smooth_decay:
                explorer.apply_lr_decay_iter()
                explorer.apply_wd_decay_iter()
            explorer.apply_wd_warmup()
            explorer.apply_burn_in()

            use_explorer = (epoch == 0 and nbatch == 0) or (
                epoch >= exp_start_epoch and nbatch % exp_freq == 0)
            if use_explorer:
                explorer.set_tmp_coeff(world_rank)
                explorer.set_tmp_wd_coeff(world_rank)
                explorer.set_tmp_cg(world_rank)
            explorer.set_best_coeff(0)
            explorer.set_best_wd_coeff(0)
            explorer.set_best_cg(world_rank)

            self.update()

            if use_normal_data_batch:
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)
            else:
                if isinstance(cval_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in cval_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, cval_batch.label)

            # pre fetch next batch from whichever iterator is active
            try:
                if use_normal_data_batch:
                    next_data_batch = next(data_iter)
                else:
                    next_cval_batch = next(cval_iter)
            except StopIteration:
                end_of_batch = True

            if not end_of_batch:
                if use_normal_data_batch:
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
                else:
                    self.prepare(next_cval_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        mx_resnet_print(key=mlperf_constants.EPOCH_STOP,
                        metadata={"epoch_num": epoch + 1})
        # one epoch of training is finished
        toc = time.time()
        if log_on_this_worker:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        if epoch_end_callback is not None:
            # Hand the callbacks the module's current parameters rather than
            # the (possibly None) ``arg_params``/``aux_params`` arguments.
            arg_params, aux_params = self.get_params()
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch >= eval_offset and (
                epoch - eval_offset) % eval_period == 0:
            mx_resnet_print(key=mlperf_constants.EVAL_START,
                            metadata={'epoch_num': epoch + 1})
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if log_on_this_worker:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                     name, val)
            res = dict(res)

            # Accuracy is derived from the all-reduced correct/total counts.
            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch + 1})
            mx_resnet_print(key=mlperf_constants.EVAL_ACCURACY,
                            val=acc,
                            metadata={'epoch_num': epoch + 1})
            mx_resnet_print(
                key=mlperf_constants.BLOCK_STOP,
                metadata={'first_epoch_num': block_epoch_start + 1})
            if acc > accuracy_threshold:
                mx_resnet_print(key=mlperf_constants.RUN_STOP,
                                metadata={'status': 'success'})
                return epoch

            if epoch < (num_epoch - 1):
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_period:
                    block_epoch_count = eval_period
                mx_resnet_print(key=mlperf_constants.BLOCK_START,
                                metadata={
                                    'first_epoch_num': block_epoch_start + 1,
                                    'epoch_count': block_epoch_count
                                })

        # end of 1 epoch, reset the data-iter for another epoch
        if use_normal_data_batch:
            train_data.reset()
        else:
            cval_data.reset()

    mx_resnet_print(key=mlperf_constants.RUN_STOP,
                    metadata={'status': 'aborted'})
    return num_epoch
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None, initializer=Uniform(0.01),
        arg_params=None, aux_params=None, allow_missing=False,
        force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
        validation_metric=None, monitor=None, prefix=None,
        batches_checkpoint=None, num_batches_save_ckpt=2000):
    """Train the module parameters.

    Parameters
    ----------
    train_data : DataIter
    eval_data : DataIter
        If not `None`, will be used as validation set and evaluate the
        performance after each epoch.
    eval_metric : str or EvalMetric
        Default `'acc'`. The performance measure used to display during
        training.
    epoch_end_callback : function or list of function
        Each callback will be called with the current `epoch`, `symbol`,
        `arg_params` and `aux_params`. The same callbacks are also invoked
        mid-epoch when batch checkpointing is enabled (see
        `batches_checkpoint`).
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Default `'local'`.
    optimizer : str or Optimizer
        Default `'sgd'`
    optimizer_params : dict
        Default `(('learning_rate', 0.01),)`. The parameters for the
        optimizer constructor. The default value is not a `dict`, just to
        avoid pylint warning on dangerous default values.
    eval_end_callback : function or list of function
        These will be called at the end of each full evaluation, with the
        metrics over the entire evaluation set.
    eval_batch_end_callback : function or list of function
        These will be called at the end of each minibatch during evaluation
    initializer : Initializer
        Will be called to initialize the module parameters if not already
        initialized.
    arg_params : dict
        Default `None`, if not `None`, should be existing parameters from a
        trained model or loaded from a checkpoint (previously saved model).
        In this case, the value here will be used to initialize the module
        parameters, unless they are already initialized by the user via a
        call to `init_params` or `fit`. `arg_params` has higher priority to
        `initializer`.
    aux_params : dict
        Default `None`. Similar to `arg_params`, except for auxiliary states.
    allow_missing : bool
        Default `False`. Indicate whether we allow missing parameters when
        `arg_params` and `aux_params` are not `None`. If this is `True`, then
        the missing parameters will be initialized via the `initializer`.
    force_rebind : bool
        Default `False`. Whether to force rebinding the executors if already
        binded.
    force_init : bool
        Default `False`. Indicate whether we should force initialization even
        if the parameters are already initialized.
    begin_epoch : int
        Default `0`. Indicate the starting epoch. Usually, if we are resuming
        from a checkpoint saved at a previous training phase at epoch N, then
        we should specify this value as N+1.
    num_epoch : int
        Number of epochs to run training.
    validation_metric : str or EvalMetric
        Metric passed to `score` for evaluation. Default `None`, in which
        case `eval_metric` is used.
    monitor : Monitor
        Default `None`. If given, it is installed on the module and
        `tic()` / `toc_print()` are called around every training batch.
    prefix : str
        Default `None`. If given, `self._curr_module.save_checkpoint` is
        called with this prefix (and optimizer states) after every epoch.
    batches_checkpoint : object
        Default `None`. Any non-`None` value enables mid-epoch checkpointing
        through `epoch_end_callback`.
    num_batches_save_ckpt : int
        Default `2000`. With mid-epoch checkpointing enabled, the epoch-end
        callbacks run on every non-zero batch index that is a multiple of
        this value.

    Examples
    --------
    An example of using fit for training::
        >>> #Assume training dataIter and validation dataIter are ready
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
                    optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
                    num_epoch=10)
    """
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params,
                     aux_params=aux_params, allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    # Mid-epoch checkpoints reuse the epoch-end callbacks, so they are only
    # possible when such callbacks were actually supplied.
    checkpoint_batches = (batches_checkpoint is not None
                          and epoch_end_callback is not None)

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        for nbatch, data_batch in enumerate(train_data):
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)

            if (checkpoint_batches and nbatch != 0
                    and nbatch % num_batches_save_ckpt == 0):
                # Fetch the parameters as they are now; otherwise the
                # callbacks would receive the caller-supplied arguments
                # (possibly None) or the previous epoch's snapshot.
                arg_params, aux_params = self.get_params()
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        if prefix is not None:
            self._curr_module.save_checkpoint(prefix, epoch + 1,
                                              save_optimizer_states=True)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                 val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None, iter_size=1,
        eval_batch_end_callback=None, initializer=Uniform(0.01),
        arg_params=None, aux_params=None, allow_missing=False,
        force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
        validation_metric=None, monitor=None):
    """Ke's revision: add iter_size. Trains the module parameters.

    Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_
    to see a end-to-end use-case.

    Parameters
    ----------
    train_data : DataIter
        Train DataIter.
    eval_data : DataIter
        If not ``None``, will be used as validation set and the performance
        after each epoch will be evaluated.
    eval_metric : str or EvalMetric
        Defaults to 'acc' (accuracy). The performance measure used to display
        during training. Other possible predefined metrics are:
        'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
    epoch_end_callback : function or list of functions
        Each callback will be called with the current `epoch`, `symbol`,
        `arg_params` and `aux_params`.
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam`.
    kvstore : str or KVStore
        Defaults to 'local'.
    optimizer : str or Optimizer
        Defaults to 'sgd'.
    optimizer_params : dict
        Defaults to ``(('learning_rate', 0.01),)``. The parameters for
        the optimizer constructor.
        The default value is not a dict, just to avoid pylint warning on
        dangerous default values.
    eval_end_callback : function or list of function
        These will be called at the end of each full evaluation, with the
        metrics over the entire evaluation set.
    iter_size : int
        Defaults to 1. Number of consecutive batches whose gradients are
        accumulated (the module is bound with ``grad_req='add'``) before one
        optimizer update is applied and the gradient arrays are zeroed.
        The batch count carries over epoch boundaries, so every update covers
        exactly `iter_size` batches. Must be at least 1.
    eval_batch_end_callback : function or list of function
        These will be called at the end of each mini-batch during evaluation.
    initializer : Initializer
        The initializer is called to initialize the module parameters when
        they are not already initialized.
    arg_params : dict
        Defaults to ``None``, if not ``None``, should be existing parameters
        from a trained model or loaded from a checkpoint (previously saved
        model). In this case, the value here will be used to initialize the
        module parameters, unless they are already initialized by the user
        via a call to `init_params` or `fit`. `arg_params` has a higher
        priority than `initializer`.
    aux_params : dict
        Defaults to ``None``. Similar to `arg_params`, except for auxiliary
        states.
    allow_missing : bool
        Defaults to ``False``. Indicates whether to allow missing parameters
        when `arg_params` and `aux_params` are not ``None``. If this is
        ``True``, then the missing parameters will be initialized via the
        `initializer`.
    force_rebind : bool
        Defaults to ``False``. Whether to force rebinding the executors if
        already bound.
    force_init : bool
        Defaults to ``False``. Indicates whether to force initialization even
        if the parameters are already initialized.
    begin_epoch : int
        Defaults to 0. Indicates the starting epoch. Usually, if resumed from
        a checkpoint saved at a previous training phase at epoch N, then this
        value should be N+1.
    num_epoch : int
        Number of epochs for training.
    validation_metric : str or EvalMetric
        Metric passed to `score` for evaluation. Defaults to ``None``, in
        which case `eval_metric` is used.
    monitor : Monitor
        Defaults to ``None``. If given, it is installed on the module and
        `tic()` / `toc_print()` are called around every training batch.

    Raises
    ------
    ValueError
        If `iter_size` is smaller than 1.

    Examples
    --------
    >>> # An example of using fit for training.
    >>> # Assume training dataIter and validation dataIter are ready
    >>> # Assume loading a previously checkpointed model
    >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
    >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
    ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
    ...     arg_params=arg_params, aux_params=aux_params,
    ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
    """
    assert num_epoch is not None, 'please specify number of epochs'
    if iter_size < 1:
        # Reject it up front rather than failing (or silently misbehaving)
        # deep inside the training loop.
        raise ValueError('iter_size must be >= 1, got %r' % (iter_size,))

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind, grad_req='add')
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params,
                     aux_params=aux_params, allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    # Batches accumulated since the last optimizer update. Deliberately not
    # reset per epoch so that accumulation windows never shrink or grow at
    # epoch boundaries.
    num_accumulated = 0

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            # Separate forward/backward so gradients add up across batches.
            self.forward(data_batch, is_train=True, grad_req='add')
            self.backward()
            num_accumulated += 1
            if num_accumulated == iter_size:
                # update every iter_size batches, then clear the accumulated
                # gradients ('add' mode never overwrites them by itself)
                self.update()
                for g in self._curr_module._exec_group.grad_arrays:
                    for g1 in g:
                        if g1 is not None:
                            g1[:] = 0.
                num_accumulated = 0

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                 val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()