Example #1
# `etl`, `fit_f`, `eval_f` and `BP_LOSS_F` are assumed to be defined elsewhere
# in the same module.
from tqdm import tqdm
from mxnet import autograd


def numerical_check(_net, cfg):
    _net.initialize()

    datas = etl(cfg)

    # bp_loss_f holds the single loss used for back propagation;
    # loss_function collects every loss that should be monitored.
    bp_loss_f = BP_LOSS_F
    loss_function = {}
    loss_function.update(bp_loss_f)
    from longling.ML.toolkit.monitor import MovingLoss
    from longling.ML.MxnetHelper.glue import module

    loss_monitor = MovingLoss(loss_function)

    # train check
    trainer = module.Module.get_trainer(_net,
                                        optimizer=cfg.optimizer,
                                        optimizer_params=cfg.optimizer_params,
                                        select=cfg.train_select)

    for epoch in range(0, 100):
        for _data in tqdm(datas):
            with autograd.record():
                bp_loss = fit_f(_net, _data, bp_loss_f, loss_function,
                                loss_monitor)
            assert bp_loss is not None
            bp_loss.backward()
            trainer.step(cfg.batch_size)
        print(eval_f(_net, datas))
        print("epoch-%d: %s" % (epoch, list(loss_monitor.items())))
Example #2
    def toolbox_init(
        self,
        evaluation_formatter_parameters=None,
        validation_logger_mode="w",
        informer_silent=False,
    ):

        import os

        from longling.lib.clock import Clock
        from longling.lib.utilog import config_logging
        from longling.ML.toolkit.formatter import MultiClassEvalFormatter \
            as Formatter
        from longling.ML.toolkit.monitor import MovingLoss, \
            ConsoleProgressMonitor as ProgressMonitor

        self.toolbox = {
            "monitor": dict(),
            "timer": None,
            "formatter": dict(),
        }

        mod = self.mod
        cfg = self.mod.cfg

        # 4.1 todo: define the loss functions
        # bp_loss_f defines the loss function used for back propagation;
        # there must be exactly one, and its name must not match the *_\d+ pattern

        assert self.loss_function is not None

        loss_monitor = MovingLoss(self.loss_function)

        # 4.1 todo: initialize the interactive information used during training
        timer = Clock()

        progress_monitor = ProgressMonitor(
            loss_index=[name for name in self.loss_function],
            end_epoch=cfg.end_epoch - 1,
            silent=informer_silent)

        validation_logger = config_logging(
            filename=os.path.join(cfg.model_dir, "result.log"),
            logger="%s-validation" % cfg.model_name,
            mode=validation_logger_mode,
            log_format="%(message)s",
        )

        # set evaluation formatter
        evaluation_formatter_parameters = {} \
            if evaluation_formatter_parameters is None \
            else evaluation_formatter_parameters

        evaluation_formatter = Formatter(
            logger=validation_logger,
            dump_file=mod.cfg.validation_result_file,
            **evaluation_formatter_parameters)

        self.toolbox["monitor"]["loss"] = loss_monitor
        self.toolbox["monitor"]["progress"] = progress_monitor
        self.toolbox["timer"] = timer
        self.toolbox["formatter"]["evaluation"] = evaluation_formatter