Example #1
0
    def train_step(self, batch, state: State) -> dict:
        """Run a single optimization step on ``batch`` and return its outputs.

        ``batch[0]`` is exposed as ``state.input_batch`` only while
        ``CustomEvents.STEP_START`` is raised. The transformed prediction is
        exposed as ``state.prediction`` only while
        ``CustomEvents.STEP_COMPLETE`` is raised. Each attribute is reset to
        ``None`` right after its event.

        Args:
            batch: Moved to ``self.device`` with ``deep_to`` and then unpacked
                into ``(inputs, target)``.
            state: Engine state; ``state.engine.raise_event`` is used to fire
                the step events.

        Returns:
            dict with keys ``'prediction'`` (after ``deep_detach`` and
            ``self.prediction_transform``), ``'target'`` (after
            ``deep_detach``) and ``'loss'`` (``loss.item()``).
        """
        # Publish the raw input only for the duration of STEP_START, then clear
        # that same attribute (mirrors the handling of ``state.prediction``
        # below). Clearing ``state.batch`` here instead would leave
        # ``input_batch`` stale and wipe ``state.batch``.
        # NOTE(review): ``state.batch`` is presumably owned by the engine loop --
        # confirm that nothing relied on it being reset to None here.
        state.input_batch = batch[0]
        state.engine.raise_event(CustomEvents.STEP_START)
        state.input_batch = None

        self.train()
        self.optimizer.zero_grad()
        # ``inputs`` rather than ``input`` to avoid shadowing the builtin.
        inputs, target = deep_to(batch, device=self.device, non_blocking=True)
        prediction = self.nn_module(inputs)
        loss = self.loss(prediction, target)
        loss.backward()
        self.optimizer.step()

        # deep_detach is assumed to detach (nested) tensors from the autograd
        # graph so the returned values do not keep it alive.
        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)

        # Same publish/clear pattern as input_batch above.
        state.prediction = prediction
        state.engine.raise_event(CustomEvents.STEP_COMPLETE)
        state.prediction = None

        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
Example #2
0
    def test_val_step(self, linear_argus_model_instance, poly_batch):
        """val_step returns a dict whose prediction and target are
        torch tensors of shape (batch_size, 1)."""
        expected_shape = [poly_batch[0].shape[0], 1]
        step_output = linear_argus_model_instance.val_step(poly_batch, State())

        assert isinstance(step_output, dict)
        tensors = [step_output['prediction'], step_output['target']]
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor)
            assert list(tensor.shape) == expected_shape
Example #3
0
    def test_train_step(self, linear_argus_model_instance, poly_batch):
        """train_step returns (batch_size, 1) prediction/target tensors and
        a Python float loss."""
        expected_shape = [poly_batch[0].shape[0], 1]
        step_state = State(linear_argus_model_instance.test_step)
        result = linear_argus_model_instance.train_step(poly_batch, step_state)

        assert isinstance(result, dict)
        prediction, target, loss = (
            result['prediction'], result['target'], result['loss'])
        for tensor in (prediction, target):
            assert isinstance(tensor, torch.Tensor)
            assert list(tensor.shape) == expected_shape
        assert isinstance(loss, float)
Example #4
0
 def epoch_complete(self, state: State):
     """Update the early-stopping counter from the monitored metric.

     If ``self.better_comp(current, self.best_value)`` is true for the value
     ``state.metrics[self.monitor]``, that value becomes the new
     ``best_value`` and ``self.wait`` is reset to 0. Otherwise ``self.wait``
     is incremented. Once it reaches ``self.patience``, ``state.stopped`` is
     set to True and a message is logged via ``state.logger.info``.

     Raises:
         AssertionError: if ``self.monitor`` is not a key of ``state.metrics``.
     """
     # Explicit check instead of an ``assert`` statement: asserts are stripped
     # under ``python -O``, which would degrade this into a bare KeyError.
     # AssertionError is kept so callers/tests see the same exception type.
     if self.monitor not in state.metrics:
         raise AssertionError(
             f"Monitor '{self.monitor}' metric not found in state")

     current_value = state.metrics[self.monitor]
     if self.better_comp(current_value, self.best_value):
         self.best_value = current_value
         self.wait = 0
         return

     self.wait += 1
     if self.wait >= self.patience:
         state.stopped = True
         state.logger.info(
             f"Epoch {state.epoch}: Early stopping triggered, "
             f"'{self.monitor}' didn't improve score {self.wait} epochs"
         )
Example #5
0
 def iteration_complete(self, state: State):
     """Copy this iteration's ``state.step_output`` to ``state.saved_step_output``."""
     latest_output = state.step_output
     state.saved_step_output = latest_output
Example #6
0
def test_state_update(linear_argus_model_instance):
    """Keyword arguments passed to State() and State.update() become attributes."""
    initial = {'qwerty': 42}
    extra = {'asdf': 12}
    st = State(linear_argus_model_instance.test_step, **initial)
    assert st.qwerty == initial['qwerty']
    st.update(**extra)
    assert st.asdf == extra['asdf']
Example #7
0
def test_state_update():
    """A bare State keeps constructor kwargs and accepts more via update()."""
    st = State(qwerty=42)
    assert getattr(st, 'qwerty') == 42
    st.update(asdf=12)
    assert getattr(st, 'asdf') == 12