def __init__(self, model, opt, loss_func, data, metrics=None, callbacks=None, device=None):
    """Set up the learner's core components and attach its callbacks.

    Args:
        model: PyTorch model.
        opt: optimizer.
        loss_func: loss function.
        data: DataBunch.
        metrics: a metric or a list of metrics; forwarded to
            ``cbks.AvgStatsCallback``.
        callbacks: a callback object or class, or a list of them. Anything
            that is not a ``cbks.Callback`` instance is called with no
            arguments to build one. A list passed here is not modified.
        device: CPU or GPU device; forwarded to ``cbks.CudaCallback``.
    """
    self.model, self.opt, self.data, self.loss_func = model, opt, data, loss_func
    self.state = 'train'  # one of 'train', 'val', 'test'
    self.messages = {}  # information shared between different callbacks
    self.epoch = 0
    self.best_loss = float('inf')

    # Take a private copy: the appends below must never mutate a list the
    # caller passed in (listify presumably returns a list argument as-is).
    cbs = list(tbox.listify(callbacks))
    # Callbacks the learner always needs, appended after the user's ones.
    cbs.append(cbks.Recorder(as_attr=True))
    cbs += [
        cbks.StatesCallback(),
        cbks.CudaCallback(device=device),
        cbks.ProgressBarCallback(),
        cbks.AvgStatsCallback(metrics=metrics)
    ]

    cb_list = []
    for cb in cbs:
        # Accept ready-made callback instances as well as callback classes.
        if isinstance(cb, cbks.Callback):
            cb_obj = cb
        else:
            cb_obj = cb()
        # Callbacks flagged with as_attr are also exposed as learner attributes
        # under their own name (e.g. the Recorder created above).
        if cb_obj.as_attr:
            setattr(self, cb_obj.name, cb_obj)
        cb_obj.set_learner(self)
        cb_list.append(cb_obj)
    self.cbs = cb_list
def __init__(self, metrics, state):
    """Store the metrics and state, and derive the metric display names.

    ``metrics`` is passed through ``tbox.listify`` before being stored.
    ``metric_names`` starts with ``'loss'`` and is followed by the
    ``__name__`` of every metric, in order.
    """
    metric_list = tbox.listify(metrics)
    self.metrics = metric_list
    self.state = state
    # Build the metric names: the loss first, then one entry per metric.
    names = ['loss']
    names.extend(metric.__name__ for metric in metric_list)
    self.metric_names = names
def __init__(self, items):
    """Hold ``items`` after normalising them with ``tbox.listify``."""
    normalised = tbox.listify(items)
    self.items = normalised