Example #1
@dataclass
class FakeLoop:
    callbacks: List[Callback]
    loss: float = float('inf')
    metric: float = -float('inf')
    epoch: int = 0
    stage: str = 'train'

    def __post_init__(self):
        self.event_log = []
        self.callbacks = CallbackManager(*self.callbacks)

    def handle_stage(self, stage):
        pass

    def fire(self, event):
        self.callbacks.on_event(self, event)
        self.event_log.append(event)

    def __call__(self, epochs):
        for _ in range(epochs):
            self.callbacks.on_epoch_start(self)
            for stage in ['train', 'val']:
                self.stage = stage
                self.callbacks.on_stage_start(self)
                self.handle_stage(stage)
                self.callbacks.on_stage_end(self)
            self.callbacks.on_epoch_end(self)
            self.epoch += 1
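
A short illustration of how this FakeLoop might be driven. This is a sketch only: the LoggingCallback class is an assumption, and it relies on the Callback base class providing no-op defaults for the standard hooks, which the tests below suggest but this excerpt does not show.

class LoggingCallback(Callback):
    # only override the hooks we care about; the rest fall back to the
    # (assumed) no-op defaults on the Callback base class
    def on_epoch_end(self, loop):
        print(f'epoch {loop.epoch} done, stage={loop.stage}')

    def on_event(self, loop, event):
        print(f'custom event fired: {event}')


loop = FakeLoop(callbacks=[LoggingCallback()])
loop(epochs=2)           # runs the train/val stages and the epoch hooks
loop.fire('checkpoint')  # forwarded to every callback's on_event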
Example #2
    def __init__(
        self,
        model: nn.Module,
        optimizer: LazyOptimizer,
        loss_fn: Callable,
        metrics: Optional[Union[Callable, MetricStack,
                                Sequence[Callable]]] = None,
        callbacks: Sequence[Callback] = (),
        history: Optional[History] = None,
        device: Union[torch.device, str] = 'cpu',
    ):
        self.model = model
        self.loss_fn = loss_fn
        self.to(device)
        self.optimizer = optimizer

        self.metrics = metrics
        self.n_batches = 0
        self.batches_seen = 0
        self.stage = self.stages[0]
        self.should_stop = False
        self.history = history if history is not None else History()
        self.callbacks = CallbackManager(self.history, *callbacks)
        self.callbacks.on_registration(self)
        self.epoch = self.history.last_epoch + 1 if len(self.history) else 0
Example #3
def test_on_event(mocker):
    bim, bam = Bim(), Bam()
    bim_spy = mocker.spy(bim, 'on_event')
    bam_spy = mocker.spy(bam, 'on_event')

    callbacks = CallbackManager(bim, bam)

    fake_loop = 'loop'
    fake_event = 'fake'

    callbacks.on_event(fake_loop, fake_event)

    bim_spy.assert_called_once_with(fake_loop, fake_event)
    bam_spy.assert_called_once_with(fake_loop, fake_event)
Example #4
def test_getitem():
    a, b, c, d = Bim(), Bam(), Boop(), Bam()
    callbacks = CallbackManager(a, b, c, d)
    assert callbacks[0] is a
    assert callbacks[-1] is d

    assert isinstance(callbacks[:2], CallbackManager)
    assert len(callbacks[:2]) == 2
    assert callbacks[:2][-1] is b
Example #5
def test_standard_callback_methods(mocker, method):
    bim, bam = Bim(), Bam()
    bim_spy = mocker.spy(bim, method)
    bam_spy = mocker.spy(bam, method)

    callbacks = CallbackManager(bim, bam)
    method_ = getattr(callbacks, method)

    method_(1)  # not actually a loop... but whatever

    bim_spy.assert_called_once_with(1)
    bam_spy.assert_called_once_with(1)
Example #6
def test_bad_init():
    expected_msg = 'expected Callback type but got <class \'str\'>.'
    with pytest.raises(TypeError, match=expected_msg):
        CallbackManager(Bim(), 'bad')
Example #7
def test_len():
    callbacks = CallbackManager(Bim(), Bam(), Bam())
    assert len(callbacks) == 3
Example #8
def test_valid_init():
    CallbackManager(Bim(), Bam())
Example #9
    def __post_init__(self):
        self.event_log = []
        self.callbacks = CallbackManager(*self.callbacks)
Example #10
class Loop:
    """The simplest kind of loop for basic supervised learning.

    Note:
        If you have more custom things you'd like to do that can't be handled
        in callbacks, it's recommended to subclass this and override the
        ``handle_batch`` method.
    """

    stages = ('train', 'val')

    def __init__(
        self,
        model: nn.Module,
        optimizer: LazyOptimizer,
        loss_fn: Callable,
        metrics: Optional[Union[Callable, MetricStack,
                                Sequence[Callable]]] = None,
        callbacks: Sequence[Callback] = (),
        history: Optional[History] = None,
        device: Union[torch.device, str] = 'cpu',
    ):
        self.model = model
        self.loss_fn = loss_fn
        self.to(device)
        self.optimizer = optimizer

        self.metrics = metrics
        self.n_batches = 0
        self.batches_seen = 0
        self.stage = self.stages[0]
        self.should_stop = False
        self.history = history if history is not None else History()
        self.callbacks = CallbackManager(self.history, *callbacks)
        self.callbacks.on_registration(self)
        self.epoch = self.history.last_epoch + 1 if len(self.history) else 0

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        if isinstance(optimizer, LazyOptimizer) and not optimizer.initialized:
            optimizer.add_model(self.model)
            # add parameters from loss function if any exist...
            # this will be a Running object so we need to access inner function...
            optimizer.add_model(self.loss_fn.fn)
        self._optimizer = optimizer

    def to(self, device: Union[torch.device, str]):
        self.device = device
        self.model.to(self.device)
        self._loss_fn.to(self.device)

    @property
    def metrics(self):
        return self._metrics

    @metrics.setter
    def metrics(self,
                metrics: Optional[Union[Callable, MetricStack,
                                        Sequence[Callable]]] = None):
        if metrics:
            if not isinstance(metrics, MetricStack):
                if isinstance(metrics, dict):
                    metrics = MetricStack(**metrics)
                elif isinstance(metrics, Sequence):
                    metrics = MetricStack(*metrics)
                else:
                    metrics = MetricStack(metrics)

            self._metrics = Running(metrics)
            self._has_metrics = True
        else:
            self._metrics = Running(MetricStack())
            self._has_metrics = False

    @property
    def loss_fn(self):
        return self._loss_fn

    @loss_fn.setter
    def loss_fn(self, loss_fn: Callable):
        self._is_multihead_loss = isinstance(loss_fn, MultiHeadLoss)
        self._loss_agg_key = loss_fn.aggregate_key if self._is_multihead_loss else None
        self._loss_fn = Running(loss_fn)

    @property
    def loss(self):
        return self.loss_fn.average

    @property
    def metric(self):
        if self._has_metrics:
            return self.metrics.average
        return None

    def grad_context(self):
        if self.stage == 'train':
            return nullcontext()
        return torch.no_grad()

    def _requires_backward(self) -> bool:
        return self.stage == 'train'

    def optimizer_step(self):
        self.optimizer.step()

    def _optimizer_step(self):
        self.callbacks.on_step_start(self)
        self.optimizer_step()
        self.callbacks.on_step_end(self)

    def compute_loss(self, yhat, ytru, **kwargs):
        return self.loss_fn(yhat, ytru, **kwargs)

    def _compute_loss(self, yhat, ytru, **kwargs):
        self.callbacks.on_loss_start(self)
        loss = self.compute_loss(yhat, ytru, **kwargs)
        self.callbacks.on_loss_end(self)
        if self._is_multihead_loss:
            return loss[self._loss_agg_key]
        return loss

    def compute_metric(self, yhat, ytru, **kwargs):
        return self.metrics(yhat, ytru, **kwargs)

    def _compute_metric(self, yhat, ytru, **kwargs):
        with torch.no_grad():
            self.callbacks.on_metric_start(self)
            metric = self.compute_metric(yhat, ytru, **kwargs)
            self.callbacks.on_metric_end(self)
        return metric

    def _backward(self, loss, **kwargs):
        self.callbacks.on_backward_start(self)
        self.backward(loss, **kwargs)
        self.callbacks.on_backward_end(self)

    def backward(self, loss, **kwargs):
        loss.backward()

    def _forward(self, x, **kwargs):
        with self.grad_context():
            return self.forward(x, **kwargs)

    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)

    def handle_batch(self, batch):
        # do forward pass and get loss
        self.optimizer.zero_grad()
        # unpack the batch... you can override this if your batch differs...
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)

        y_hat = self._forward(x)
        loss = self._compute_loss(y_hat, y)
        if self._requires_backward():
            self._backward(loss)
            self._optimizer_step()
        if self._has_metrics:
            self._compute_metric(y_hat, y)

    def handle_batches(self, batches):
        self.n_batches = len(batches)
        self.batches_seen = 0
        for batch in batches:
            self.callbacks.on_batch_start(self)
            self.handle_batch(batch)
            self.batches_seen += 1
            self.callbacks.on_batch_end(self)

    def handle_stage(self, stage, batches):
        self.stage = stage
        if self.stage == 'train':
            self.model.train()
        else:
            self.model.eval()
        if self._has_metrics:
            self.metrics.reset()
        self.loss_fn.reset()
        self.callbacks.on_stage_start(self)
        self.handle_batches(batches)
        self.callbacks.on_stage_end(self)

    def fire(self, event):
        self.callbacks.on_event(self, event)

    def __call__(self, train, val, epochs: int = 1):
        self.should_stop = False
        for _ in range(epochs):
            self.callbacks.on_epoch_start(self)
            for stage, batches in zip(self.stages, (train, val)):
                self.handle_stage(stage, batches)
            self.callbacks.on_epoch_end(self)
            self.epoch += 1
            if self.should_stop:
                break

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(model={self.model},'
                f' optimizer={self.optimizer},'
                f' loss_fn={self.loss_fn!r},'
                f' metrics={self.metrics!r},'
                f' callbacks={self.callbacks})')
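
The docstring above recommends subclassing ``Loop`` and overriding ``handle_batch`` when a use case doesn't fit the callback hooks. Below is a minimal sketch of that, assuming a dict-style batch; the DictBatchLoop name and the 'inputs'/'targets' keys are illustrative assumptions, not part of the library.

class DictBatchLoop(Loop):
    def handle_batch(self, batch):
        # same skeleton as Loop.handle_batch, but unpacks a dict batch
        self.optimizer.zero_grad()
        x = batch['inputs'].to(self.device)
        y = batch['targets'].to(self.device)

        y_hat = self._forward(x)
        loss = self._compute_loss(y_hat, y)
        if self._requires_backward():
            self._backward(loss)
            self._optimizer_step()
        if self._has_metrics:
            self._compute_metric(y_hat, y)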