def __init__(self,
                 *,
                 model: nn.Module,
                 num_batches: int,
                 cur_epoch=0,
                 device="cpu") -> None:
        """Store the model and epoch settings, then create meters and the progress indicator.

        Args:
            model: the module kept on the instance as ``_model``.
            num_batches: length of the ``range`` handed to the progress indicator.
            cur_epoch: value stored as ``_cur_epoch`` (defaults to 0).
            device: a ``torch.device`` or anything ``torch.device(...)`` accepts.
        """
        super().__init__()
        self._model = model

        # Accept either a ready-made torch.device or a spec to build one from.
        if isinstance(device, torch.device):
            self._device = device
        else:
            self._device = torch.device(device)

        self._num_batches = num_batches
        self._cur_epoch = cur_epoch

        self.meters = MeterInterface()

        # The indicator is disabled whenever on_master() is falsy.
        batch_range = range(self._num_batches)
        is_master = self.on_master()
        self.indicator = tqdm(batch_range, disable=not is_master)

        # Nothing is initialised or bound to a trainer yet.
        self.__epocher_initialized__ = False
        self.__bind_trainer_done__ = False
        self._trainer = None
    def test_meter(self):
        """Feed two "loss" meters (one registered under ``focus_on("reg")``) and print the statistics.

        This is a smoke test: it only prints what ``statistics()`` yields and
        makes no assertions.
        """
        interface = MeterInterface()

        # One "loss" meter registered directly, another inside the "reg" focus.
        interface.register_meter("loss", AverageValueMeter())
        with interface.focus_on("reg"):
            interface.register_meter("loss", AverageValueMeter())

        with interface:
            for _ in range(10):
                interface["loss"].add(1.0)
                with interface.focus_on("reg"):
                    interface["loss"].add(10)

        # statistics() yields pairs; print each pair as-is.
        for group_name, group_stats in interface.statistics():
            print(group_name, group_stats)
 def configure_meters(self, meters: MeterInterface) -> MeterInterface:
     """Register an ``AverageValueMeter`` named "meter1" on *meters* and return *meters*."""
     average_meter = AverageValueMeter()
     meters.register_meter("meter1", average_meter)
     return meters
Example #4
0
 def configure_meters(self, meters: MeterInterface) -> MeterInterface:
     """Register a "loss" meter, then an "acc" meter under ``focus_on("acc")``; return *meters*."""
     loss_meter = AverageValueMeter()
     meters.register_meter("loss", loss_meter)

     # The "acc" meter is registered while the "acc" focus is active.
     acc_focus = meters.focus_on("acc")
     with acc_focus:
         meters.register_meter("acc", AverageValueMeter())

     return meters
    def test_storage(self):
        """Run ten mock epochs, pushing each epoch's meter statistics into a ``Storage``.

        This is a smoke test: it prints the storage summary before and after
        the loop and the final statistics, without asserting anything.
        """
        store = Storage()
        interface = MeterInterface()

        # "loss" registered directly; "loss" and "loss2" under the "reg" focus.
        interface.register_meter("loss", AverageValueMeter())
        with interface.focus_on("reg"):
            interface.register_meter("loss", AverageValueMeter())
            interface.register_meter("loss2", AverageValueMeter())

        print(store.summary())

        with store:
            for epoch in range(10):
                with interface:
                    for _ in range(100):
                        interface["loss"].add(epoch)
                        with interface.focus_on("reg"):
                            interface["loss"].add(epoch + 10)
                            interface["loss2"].add(epoch + 5)

                # Record every (group, result) pair produced for this epoch.
                for group_name, group_result in interface.statistics():
                    store.put_group(group_name,
                                    epoch_result=group_result,
                                    epoch=epoch)
            print(store.summary())

        for group_name, group_stats in interface.statistics():
            print(group_name, group_stats)
    def test_del_meter(self):
        """Register two "loss" meters, delete both again, and print ``groups()``.

        This is a smoke test: the result is only printed, not asserted.
        """
        interface = MeterInterface()

        # Register: one directly, one under the "reg" focus.
        interface.register_meter("loss", AverageValueMeter())
        with interface.focus_on("reg"):
            interface.register_meter("loss", AverageValueMeter())

        # Delete in the same order the meters were registered.
        interface.delete_meter("loss")
        with interface.focus_on("reg"):
            interface.delete_meter("loss")

        remaining_groups = interface.groups()
        print(remaining_groups)
class _Epocher(_DDPMixin, metaclass=ABCMeta):
    """Abstract base that wraps a subclass-provided ``_run`` with meters and a progress indicator.

    Expected call order: construct the object, call :meth:`init` (which runs
    :meth:`configure_meters`), then call :meth:`run`. Subclasses must implement
    :meth:`configure_meters` and :meth:`_run`.
    """

    def __init__(self,
                 *,
                 model: nn.Module,
                 num_batches: int,
                 cur_epoch=0,
                 device="cpu") -> None:
        """Store the model and epoch settings, then create meters and the progress indicator.

        Args:
            model: the module kept on the instance as ``_model``.
            num_batches: length of the ``range`` handed to the progress indicator.
            cur_epoch: value stored as ``_cur_epoch`` (defaults to 0).
            device: a ``torch.device`` or anything ``torch.device(...)`` accepts.
        """
        super().__init__()
        self._model = model
        self._device = device if isinstance(
            device, torch.device) else torch.device(device)
        self._num_batches = num_batches
        self._cur_epoch = cur_epoch

        self.meters = MeterInterface()
        # The indicator is disabled whenever on_master() is falsy.
        self.indicator = tqdm(range(self._num_batches),
                              disable=not self.on_master())
        self.__epocher_initialized__ = False
        self.__bind_trainer_done__ = False
        self._trainer = None

    @property
    def device(self):
        """The ``torch.device`` currently recorded for this object."""
        return self._device

    def _init(self, **kwargs):
        """Hook called by :meth:`init` before the meters are configured; no-op by default."""
        pass

    def init(self, **kwargs):
        """Run the ``_init`` hook, configure the meters and mark the object as initialised.

        Args:
            **kwargs: forwarded unchanged to :meth:`_init`.
        """
        self._init(**kwargs)
        self.configure_meters(self.meters)
        self.__epocher_initialized__ = True

    @contextmanager
    def _register_indicator(self):
        """Context manager that labels the progress indicator and closes it afterwards.

        The indicator is closed even when the wrapped body raises; the result
        is printed only after the body finished without an exception.
        """
        assert isinstance(
            self._num_batches, int
        ), f"self._num_batches must be provided as an integer, given {self._num_batches}."

        self.indicator.set_desc_from_epocher(self)
        try:
            yield
        finally:
            # Always release the indicator so a failing epoch does not leave it open.
            self.indicator.close()
        # Only reached on a clean exit: an exception propagates out of the `finally`.
        self.indicator.print_result()

    @contextmanager
    def _register_meters(self):
        """Context manager yielding ``self.meters`` and calling ``join()`` on it after the body."""
        meters = self.meters
        yield meters
        meters.join()

    @abstractmethod
    def configure_meters(self, meters: MeterInterface) -> MeterInterface:
        """Register the meters used by this object on *meters* and return it.

        The base implementation registers an ``AverageValueListMeter`` named
        "lr"; overriding subclasses may call it through ``super()``.
        """
        meters.register_meter("lr", AverageValueListMeter())
        return meters

    @abstractmethod
    def _run(self, **kwargs):
        """Subclass-provided body executed by :meth:`run`."""
        pass

    def run(self, **kwargs):
        """Move modules to ``self.device`` and execute ``_run`` inside the meter/indicator contexts.

        Args:
            **kwargs: forwarded unchanged to :meth:`_run`.

        Returns:
            Whatever :meth:`_run` returns.

        Raises:
            RuntimeError: if :meth:`init` has not been called yet.
        """
        if not self.__epocher_initialized__:
            raise RuntimeError(
                f"{self.__class__.__name__} should be initialized by calling `init()` before."
            )
        self.to(self.device)  # put all things into the same device
        with self._register_meters(), self._register_indicator():
            return self._run(**kwargs)

    def get_metric(self):
        """Return ``self.meters.statistics()`` converted to a ``dict``.

        Raises:
            RuntimeError: if :meth:`init` has not been called yet.
        """
        if not self.__epocher_initialized__:
            raise RuntimeError(
                f"{self.__class__.__name__} should be initialized by calling `init()` before."
            )
        return dict(self.meters.statistics())

    def get_score(self):
        """Not implemented in the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def to(self, device: Union[torch.device, str] = torch.device("cpu")):
        """Move every ``nn.Module`` stored directly on the instance to *device* and record it.

        Only attributes found in ``self.__dict__`` are inspected; modules nested
        inside other containers are not visited.

        Args:
            device: target device, either a ``torch.device`` or a string spec.
        """
        if isinstance(device, str):
            device = torch.device(device)
        assert isinstance(device, torch.device)
        for member in self.__dict__.values():
            if isinstance(member, nn.Module):
                member.to(device)
        self._device = device

    def set_trainer(self, trainer):
        """Bind *trainer* to this object through a weak proxy."""
        # A weak proxy is stored, so this object holds no strong reference to the trainer.
        self._trainer = weakref.proxy(trainer)
        self.__bind_trainer_done__ = True

    @property
    def trainer(self):
        """The weak proxy stored by :meth:`set_trainer`.

        Raises:
            RuntimeError: if :meth:`set_trainer` has not been called yet.
        """
        if not self.__bind_trainer_done__:
            raise RuntimeError(
                f"{self.__class__.__name__} should call `set_trainer` first")
        return self._trainer