def numerical_check(_net, cfg, epoch_num=100):
    """Run a short training loop on ``_net`` to sanity-check its numerics.

    The network is initialized, a trainer is built from ``cfg``, and the data
    produced by ``etl(cfg)`` is iterated for ``epoch_num`` epochs. Each batch
    does forward under ``autograd.record()``, then backward, then a trainer
    step. After every epoch, the result of ``eval_f`` and the current moving
    losses are printed.

    Parameters
    ----------
    _net:
        The network to check. ``_net.initialize()`` is called on it.
    cfg:
        Configuration object. This function reads ``optimizer``,
        ``optimizer_params``, ``train_select`` and ``batch_size``; it is also
        passed to ``etl``.
    epoch_num: int
        Number of epochs to run (default 100, matching the previous
        hard-coded value).
    """
    # Initialize the network that was passed in. The previous code called
    # ``net.initialize()``, which silently relied on a module-level ``net``.
    _net.initialize()

    datas = etl(cfg)

    bp_loss_f = BP_LOSS_F
    # Copy, so the monitor's loss table is independent of BP_LOSS_F.
    loss_function = dict(bp_loss_f)

    from longling.ML.toolkit.monitor import MovingLoss
    from longling.ML.MxnetHelper.glue import module

    loss_monitor = MovingLoss(loss_function)

    # train check
    trainer = module.Module.get_trainer(
        _net,
        optimizer=cfg.optimizer,
        optimizer_params=cfg.optimizer_params,
        select=cfg.train_select,
    )

    for epoch in range(0, epoch_num):
        for _data in tqdm(datas):
            with autograd.record():
                bp_loss = fit_f(
                    _net, _data, bp_loss_f, loss_function, loss_monitor
                )
            # fit_f must return the loss to back-propagate.
            assert bp_loss is not None
            bp_loss.backward()
            trainer.step(cfg.batch_size)
        print(eval_f(_net, datas))
        print("epoch-%d: %s" % (epoch, list(loss_monitor.items())))
def toolbox_init(
        self,
        evaluation_formatter_parameters=None,
        validation_logger_mode="w",
        informer_silent=False,
):
    """Create the training helpers and store them in ``self.toolbox``.

    After this call ``self.toolbox`` is a dict with:

    * ``"monitor"``: ``{"loss": MovingLoss, "progress": ConsoleProgressMonitor}``
    * ``"timer"``: a ``Clock`` instance
    * ``"formatter"``: ``{"evaluation": MultiClassEvalFormatter}``

    Parameters
    ----------
    evaluation_formatter_parameters: dict or None
        Extra keyword arguments forwarded to the evaluation formatter;
        ``None`` means no extra arguments.
    validation_logger_mode: str
        File mode passed to ``config_logging`` for ``result.log`` under
        ``cfg.model_dir``.
    informer_silent: bool
        Passed as ``silent`` to the progress monitor.

    Raises
    ------
    AssertionError
        If ``self.loss_function`` is ``None``.
    """
    from longling.lib.clock import Clock
    from longling.lib.utilog import config_logging
    from longling.ML.toolkit.formatter import MultiClassEvalFormatter \
        as Formatter
    from longling.ML.toolkit.monitor import MovingLoss, \
        ConsoleProgressMonitor as ProgressMonitor

    # Reset the toolbox first, so it exists even if a later step fails.
    toolbox = {"monitor": {}, "timer": None, "formatter": {}}
    self.toolbox = toolbox

    cfg = self.mod.cfg

    # 4.1 TODO: define the loss functions.
    # NOTE: bp_loss_f is the loss used for back propagation; there must be
    # exactly one, and its name must not match the pattern *_\d+.
    assert self.loss_function is not None
    moving_loss = MovingLoss(self.loss_function)

    # 4.1 TODO: initialize helpers that report on the training process.
    clock = Clock()
    progress = ProgressMonitor(
        loss_index=list(self.loss_function),
        end_epoch=cfg.end_epoch - 1,
        silent=informer_silent,
    )

    result_logger = config_logging(
        filename=os.path.join(cfg.model_dir, "result.log"),
        logger="%s-validation" % cfg.model_name,
        mode=validation_logger_mode,
        log_format="%(message)s",
    )

    # Build the evaluation formatter; None means "no extra options".
    if evaluation_formatter_parameters is None:
        formatter_kwargs = {}
    else:
        formatter_kwargs = evaluation_formatter_parameters
    eval_formatter = Formatter(
        logger=result_logger,
        dump_file=cfg.validation_result_file,
        **formatter_kwargs
    )

    toolbox["monitor"]["loss"] = moving_loss
    toolbox["monitor"]["progress"] = progress
    toolbox["timer"] = clock
    toolbox["formatter"]["evaluation"] = eval_formatter