def fit(
    self,
    net,
    begin_epoch,
    end_epoch,
    train_data,
    trainer,
    bp_loss_f,
    loss_function,
    eval_data=None,
    ctx=None,
    toolbox=None,
    **kwargs
):
    """
    Training entry point: optionally place ``net`` on a device, then run
    ``self.epoch_loop``.

    Parameters
    ----------
    net:
        The network, already initialized or loaded from an existing model.
        When ``ctx`` is not ``None`` it is replaced by
        ``set_device(net, ctx)`` before the epoch loop starts.
    begin_epoch: int
        First epoch of this training run.
    end_epoch: int
        End epoch of this training run.
    train_data: Iterable
        Training data; expected to be already divided into batches.
    trainer:
        The object used to update the parameters of the network.
    bp_loss_f: dict
        Loss used for back propagation (documented upstream as a dict with
        exactly one key and one value).
    loss_function: dict of function
        Additional measurements reported besides ``bp_loss_f``.
    eval_data: Iterable or None
        Data for evaluation; forwarded to ``epoch_loop`` under the keyword
        ``test_data``. Defaults to ``None``.
    ctx:
        Device specification handed to ``set_device``. Defaults to ``None``,
        in which case the network is used as given. Note that ``ctx`` itself
        is not forwarded to ``epoch_loop``.
    toolbox: dict or None
        Forwarded unchanged to ``epoch_loop``. Defaults to ``None``.
    kwargs
        Any extra keyword arguments, forwarded unchanged to ``epoch_loop``.

    Returns
    -------
    Whatever ``self.epoch_loop`` returns.
    """
    # This method can be used as-is (no override needed).
    prepared_net = net if ctx is None else set_device(net, ctx)

    return self.epoch_loop(
        net=prepared_net,
        begin_epoch=begin_epoch,
        end_epoch=end_epoch,
        train_data=train_data,
        trainer=trainer,
        bp_loss_f=bp_loss_f,
        loss_function=loss_function,
        test_data=eval_data,
        toolbox=toolbox,
        **kwargs
    )
def eval_f(_net, test_data, ctx=None):
    """
    Run ``_net`` over batched ``test_data`` and compute classification metrics.

    Parameters
    ----------
    _net:
        Callable network; invoked as ``_net(data, data_mask)`` and expected to
        return a pair whose first element is the output tensor.
    test_data: Iterable
        Yields 5-tuples ``(data, data_mask, label, pick_index, label_mask)``.
    ctx:
        When not ``None``, ``_net`` is first passed through
        ``set_device(_net, ctx)``. Defaults to ``None``.

    Returns
    -------
    dict
        ``precision_<k>``, ``recall_<k>`` and ``f1_<k>`` for each entry ``k``
        returned by ``precision_recall_fscore_support``, plus ``auc``.
    """
    if ctx is not None:
        _net = set_device(_net, ctx)

    y_true = []
    y_score = []
    y_label = []

    for batch in tqdm(test_data, "evaluating"):
        data, data_mask, label, pick_index, label_mask = batch

        # NOTE(review): no_grad only disables gradient tracking; whether the
        # network is in eval mode is left to the caller -- confirm.
        with torch.no_grad():
            output, _ = _net(data, data_mask)
            # Drop the final position along dim 1 before picking.
            output = output[:, :-1]
            picked = pick(output, pick_index.to(output.device))

        scores = tensor2list(picked)
        targets = tensor2list(label)

        # presumably label_mask holds the valid length of each sequence and
        # lives on the CPU (it is read through .numpy()) -- TODO confirm
        for row, raw_length in enumerate(label_mask.numpy().tolist()):
            valid = int(raw_length)
            y_true.extend(targets[row][:valid])
            row_scores = scores[row][:valid]
            y_score.extend(row_scores)
            # Threshold the scores at 0.5 to obtain hard labels.
            y_label.extend([0 if p < 0.5 else 1 for p in row_scores])

    auc = roc_auc_score(y_true, y_score)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_label
    )

    evaluation_result = {}
    per_metric = (
        ("precision_%d", precision),
        ("recall_%d", recall),
        ("f1_%d", f1),
    )
    for key_template, values in per_metric:
        for idx in range(len(values)):
            evaluation_result[key_template % idx] = values[idx]
    evaluation_result["auc"] = auc

    return evaluation_result
def numerical_check(_net, _cfg: Configuration, train_data, test_data,
                    dump_result=False, reporthook=None,
                    final_reporthook=None):  # pragma: no cover
    """
    Train ``_net`` for the epochs configured in ``_cfg``, evaluating after
    every epoch.

    Parameters
    ----------
    _net:
        The network to train; it is passed through ``set_device`` with
        ``_cfg.ctx`` first.
    _cfg: Configuration
        Supplies ``ctx``, ``loss_params``, ``optimizer``, ``optimizer_params``,
        ``train_select``, ``begin_epoch`` and ``end_epoch``; when
        ``dump_result`` is true also ``model_dir``, ``model_name`` and
        ``validation_result_file``.
    train_data: Iterable
        Batches handed one at a time to ``fit_f``.
    test_data: Iterable
        Batches handed to ``eval_f`` at the end of each epoch.
    dump_result: bool
        When true, a ``result.log`` logger under ``_cfg.model_dir`` and the
        dump file ``_cfg.validation_result_file`` are given to the formatter,
        and ``dump=True`` is passed on each report.
    reporthook: callable or None
        Called with the formatter's ``data`` after every epoch.
    final_reporthook: callable or None
        Called without arguments once all epochs are done.
    """
    ctx = _cfg.ctx
    _net = set_device(_net, ctx)

    bp_loss_f = get_bp_loss(ctx, **_cfg.loss_params)
    # Only the back-propagation loss is monitored here.
    loss_function = dict(bp_loss_f)

    from longling.ML.toolkit import EpochEvalFMT as Formatter, MovingLoss
    from tqdm import tqdm

    loss_monitor = MovingLoss(loss_function)

    formatter_kwargs = {}
    if dump_result:
        from longling import config_logging

        formatter_kwargs["logger"] = config_logging(
            filename=path_append(_cfg.model_dir, "result.log"),
            logger="%s-validation" % _cfg.model_name,
            mode="w",
            log_format="%(message)s",
        )
        formatter_kwargs["dump_file"] = _cfg.validation_result_file
    evaluation_formatter = Formatter(**formatter_kwargs)

    # train check
    from longling.ML.PytorchHelper.toolkit.optimizer import get_trainer

    trainer = get_trainer(
        _net,
        optimizer=_cfg.optimizer,
        optimizer_params=_cfg.optimizer_params,
        select=_cfg.train_select,
    )

    for epoch in range(_cfg.begin_epoch, _cfg.end_epoch):
        for batch_data in tqdm(train_data, "Epoch: %s" % epoch):
            fit_f(
                net=_net,
                batch_data=batch_data,
                trainer=trainer,
                bp_loss_f=bp_loss_f,
                loss_function=loss_function,
                loss_monitor=loss_monitor,
            )

        # Evaluation and reporting happen after every epoch.
        epoch_losses = dict(loss_monitor.items())
        epoch_metrics = eval_f(_net, test_data, ctx=ctx)
        msg, data = evaluation_formatter(
            epoch=epoch,
            loss_name_value=epoch_losses,
            eval_name_value=epoch_metrics,
            extra_info=None,
            dump=dump_result,
        )
        print(msg)
        if reporthook is not None:
            reporthook(data)

        # optional, whether reset the loss at the end of each epoch
        loss_monitor.reset()

    if final_reporthook is not None:
        final_reporthook()
def net_initialize(net, model_ctx="cpu", **kwargs):
    """
    Initialize the network by handing it to ``TKT.shared.set_device``.

    Parameters
    ----------
    net:
        The network to initialize.
    model_ctx:
        Device specification passed to ``set_device``. Defaults to ``"cpu"``.
    kwargs
        Extra keyword arguments forwarded unchanged to ``set_device``.

    Returns
    -------
    Whatever ``set_device(net, model_ctx, **kwargs)`` returns (presumably the
    network placed on the requested device -- confirm against
    ``TKT.shared.set_device``).
    """
    # Kept as a call-time import, as in the original code. The docstring must
    # precede it: a string placed after this statement is not a docstring.
    from TKT.shared import set_device

    return set_device(net, model_ctx, **kwargs)