def test_meter(self):
    meters = MeterInterface()
    meters.register_meter("loss", AverageValueMeter())
    with meters.focus_on("reg"):
        meters.register_meter("loss", AverageValueMeter())
    with meters:
        for i in range(10):
            meters["loss"].add(1.0)
            with meters.focus_on("reg"):
                meters["loss"].add(10)
    meter_generator = meters.statistics()
    for group, group_result in meter_generator:  # avoid shadowing `meters`
        print(group, group_result)
def configure_meters(self, meters: MeterInterface) -> MeterInterface:
    meters.register_meter("meter1", AverageValueMeter())
    return meters
def configure_meters(self, meters: MeterInterface) -> MeterInterface:
    meters.register_meter("loss", AverageValueMeter())
    with meters.focus_on("acc"):
        meters.register_meter("acc", AverageValueMeter())
    return meters
def test_storage(self):
    storage = Storage()
    meters = MeterInterface()
    meters.register_meter("loss", AverageValueMeter())
    with meters.focus_on("reg"):
        meters.register_meter("loss", AverageValueMeter())
        meters.register_meter("loss2", AverageValueMeter())
    print(storage.summary())
    with storage:
        for epoch in range(10):
            with meters:
                for i in range(100):
                    meters["loss"].add(epoch)
                    with meters.focus_on("reg"):
                        meters["loss"].add(epoch + 10)
                        meters["loss2"].add(epoch + 5)
            statistics = meters.statistics()
            for g, group_dict in statistics:
                storage.put_group(g, epoch_result=group_dict, epoch=epoch)
    print(storage.summary())
    meter_generator = meters.statistics()
    for group, group_result in meter_generator:  # avoid shadowing `meters`
        print(group, group_result)
def test_del_meter(self):
    meters = MeterInterface()
    meters.register_meter("loss", AverageValueMeter())
    with meters.focus_on("reg"):
        meters.register_meter("loss", AverageValueMeter())
    meters.delete_meter("loss")
    with meters.focus_on("reg"):
        meters.delete_meter("loss")
    print(meters.groups())
import weakref
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import Union

import torch
from torch import nn

# MeterInterface, AverageValueListMeter, the tqdm wrapper (with
# `set_desc_from_epocher`/`print_result`), and _DDPMixin are project-internal
# names assumed importable from elsewhere in this repo.


class _Epocher(_DDPMixin, metaclass=ABCMeta):
    def __init__(self, *, model: nn.Module, num_batches: int, cur_epoch=0, device="cpu") -> None:
        super().__init__()
        self._model = model
        self._device = device if isinstance(device, torch.device) else torch.device(device)
        self._num_batches = num_batches
        self._cur_epoch = cur_epoch
        self.meters = MeterInterface()
        self.indicator = tqdm(range(self._num_batches), disable=not self.on_master())
        self.__epocher_initialized__ = False
        self.__bind_trainer_done__ = False
        self._trainer = None

    @property
    def device(self):
        return self._device

    def _init(self, **kwargs):
        pass

    def init(self, **kwargs):
        self._init(**kwargs)
        self.configure_meters(self.meters)
        self.__epocher_initialized__ = True

    @contextmanager
    def _register_indicator(self):
        assert isinstance(
            self._num_batches, int
        ), f"self._num_batches must be provided as an integer, given {self._num_batches}."
        self.indicator.set_desc_from_epocher(self)
        yield
        self.indicator.close()
        self.indicator.print_result()

    @contextmanager
    def _register_meters(self):
        meters = self.meters
        yield meters
        meters.join()

    @abstractmethod
    def configure_meters(self, meters: MeterInterface) -> MeterInterface:
        meters.register_meter("lr", AverageValueListMeter())
        return meters

    @abstractmethod
    def _run(self, **kwargs):
        pass

    def run(self, **kwargs):
        if not self.__epocher_initialized__:
            raise RuntimeError(
                f"{self.__class__.__name__} should be initialized by calling `init()` first."
            )
        self.to(self.device)  # move every nn.Module attribute onto the same device
        with self._register_meters(), self._register_indicator():
            return self._run(**kwargs)

    def get_metric(self):
        if not self.__epocher_initialized__:
            raise RuntimeError(
                f"{self.__class__.__name__} should be initialized by calling `init()` first."
            )
        return dict(self.meters.statistics())

    def get_score(self):
        raise NotImplementedError()

    def to(self, device: Union[torch.device, str] = torch.device("cpu")):
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        for n, m in self.__dict__.items():
            if isinstance(m, nn.Module):
                m.to(device)
        self._device = device

    def set_trainer(self, trainer):
        self._trainer = weakref.proxy(trainer)
        self.__bind_trainer_done__ = True

    @property
    def trainer(self):
        if not self.__bind_trainer_done__:
            raise RuntimeError(f"{self.__class__.__name__} should call `set_trainer` first")
        return self._trainer
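# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the repo): a minimal concrete
# _Epocher subclass showing the intended lifecycle init() -> run() ->
# get_metric(). `DemoEpocher`, `optimizer`, and `loader` are hypothetical
# names introduced for this example; everything else mirrors the class above.
import torch.nn.functional as F


class DemoEpocher(_Epocher):
    def __init__(self, *, model, optimizer, loader, num_batches, cur_epoch=0, device="cpu"):
        super().__init__(model=model, num_batches=num_batches, cur_epoch=cur_epoch, device=device)
        self._optimizer = optimizer
        self._loader = loader  # assumed: an iterable yielding (input, target) batches

    def configure_meters(self, meters: MeterInterface) -> MeterInterface:
        meters = super().configure_meters(meters)  # keeps the base "lr" meter
        meters.register_meter("loss", AverageValueMeter())
        return meters

    def _run(self, **kwargs):
        self._model.train()
        loader_iter = iter(self._loader)
        for _ in self.indicator:  # the tqdm bar over range(self._num_batches)
            data, target = next(loader_iter)
            loss = F.cross_entropy(self._model(data.to(self.device)), target.to(self.device))
            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()
            self.meters["loss"].add(loss.item())


# epocher = DemoEpocher(model=model, optimizer=optim, loader=loader, num_batches=100)
# epocher.init()                 # required: configures meters before run()
# epocher.run()                  # executes _run under the meter/indicator contexts
# print(epocher.get_metric())    # per-group statistics gathered by the meters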