示例#1
0
    def __init__(self,
                 model,
                 opt,
                 loss_func,
                 data,
                 metrics=None,
                 callbacks=None,
                 device=None):
        """Set up the learner state and attach its callbacks.

        Args:
            model: pytorch model.
            opt: optimizer.
            loss_func: loss function.
            data: DataBunch.
            metrics: a performance metric or a list of metrics.
            callbacks: a callback object or callback class, or a list of them.
            device: cpu or gpu device.
        """

        self.model, self.opt, self.data, self.loss_func = model, opt, data, loss_func
        self.state = 'train'  # 'train', 'val', 'test'
        self.messages = {}  # information shared between different callbacks
        self.epoch = 0
        self.best_loss = float('inf')

        # Copy the result of listify: when ``callbacks`` is already a list,
        # listify may hand back that very object, and adding the built-in
        # callbacks below would then mutate the caller's list (duplicating the
        # defaults every time that list is reused for another learner).
        cbs = list(tbox.listify(callbacks))
        # Append the required built-in callbacks after the user-supplied ones
        # (constructed in the same order as before: Recorder first).
        cbs += [
            cbks.Recorder(as_attr=True),
            cbks.StatesCallback(),
            cbks.CudaCallback(device=device),
            cbks.ProgressBarCallback(),
            cbks.AvgStatsCallback(metrics=metrics),
        ]

        cb_list = []
        for cb in cbs:
            # Entries may be callback instances or callback classes;
            # classes are instantiated with no arguments.
            cb_obj = cb if isinstance(cb, cbks.Callback) else cb()

            if cb_obj.as_attr:
                # Expose the callback as ``self.<name>`` (e.g. the Recorder,
                # which is created with as_attr=True above).
                setattr(self, cb_obj.name, cb_obj)
            cb_obj.set_learner(self)
            cb_list.append(cb_obj)

        # Assigned only after every callback is bound, as in the original order.
        self.cbs = cb_list
示例#2
0
 def __init__(self, metrics, state):
     """Hold the metric callables and the state they are tracked for.

     Args:
         metrics: a metric callable or a list of them; normalized to a list
             via ``tbox.listify``.
         state: state label this object is associated with (presumably
             'train' / 'val' / 'test' — confirm against callers).
     """
     self.metrics, self.state = tbox.listify(metrics), state
     # Build metric display names: 'loss' always comes first, followed by one
     # name per metric. Callables without ``__name__`` (e.g. functools.partial
     # objects or instances defining __call__) fall back to their class name
     # instead of raising AttributeError.
     self.metric_names = ['loss'] + [
         getattr(m, '__name__', type(m).__name__) for m in self.metrics
     ]
示例#3
0
 def __init__(self, items):
     """Store ``items`` normalized to a list by ``tbox.listify``."""
     normalized = tbox.listify(items)
     self.items = normalized