Example #1
    def train_step(self, batch) -> dict:
        if not self.nn_module.training:
            self.nn_module.train()
        self.optimizer.zero_grad()
        input, target, noisy = self.prepare_batch(batch, self.device)
        prediction = self.nn_module(input)
        if self.aux_weights is not None:
            loss = 0
            for pred, weight in zip(prediction, self.aux_weights):
                loss += self.loss(pred, target, noisy) * weight
        else:
            loss = self.loss(prediction, target, noisy)
        if self.use_amp:
            with self.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        return {
            'prediction': self.prediction_transform(prediction[0]),
            'target': target,
            'loss': loss.item(),
            'noisy': noisy
        }
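
This variant supports per-head auxiliary losses: when aux_weights is set, each output of the module is scored against the same target and the weighted losses are summed. A minimal sketch of that weighting in isolation (the two-head setup, L1 criterion, and weight values below are illustrative assumptions, not taken from the snippet):

    import torch

    # Hypothetical stand-ins: two prediction heads, a shared target,
    # and a criterion; weights favor the main head over the auxiliary one.
    criterion = torch.nn.L1Loss()
    predictions = [torch.randn(4, 1, requires_grad=True) for _ in range(2)]
    target = torch.randn(4, 1)
    aux_weights = [1.0, 0.4]

    loss = sum(criterion(pred, target) * weight
               for pred, weight in zip(predictions, aux_weights))
    loss.backward()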
Example #2
    def train_step(self, batch, state: State) -> dict:
        state.input_batch = batch[0]
        state.engine.raise_event(CustomEvents.STEP_START)
        state.batch = None

        self.train()
        self.optimizer.zero_grad()
        input, target = deep_to(batch, device=self.device, non_blocking=True)
        prediction = self.nn_module(input)
        loss = self.loss(prediction, target)
        loss.backward()
        self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)

        state.prediction = prediction
        state.engine.raise_event(CustomEvents.STEP_COMPLETE)
        state.prediction = None

        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
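
The notable part here is the pair of custom events raised around the step: state.input_batch and state.prediction are exposed to handlers only while the corresponding event fires, then cleared. The CustomEvents enum and handler registration live outside this snippet; a toy illustration of the pattern (hypothetical names, not the argus API) might look like:

    import enum

    class CustomEvents(enum.Enum):  # hypothetical definition for illustration
        STEP_START = 'step_start'
        STEP_COMPLETE = 'step_complete'

    class ToyEngine:
        # Keeps a handler list per event; raise_event calls them in order.
        def __init__(self):
            self.handlers = {event: [] for event in CustomEvents}

        def add_event_handler(self, event, handler):
            self.handlers[event].append(handler)

        def raise_event(self, event):
            for handler in self.handlers[event]:
                handler()

    engine = ToyEngine()
    engine.add_event_handler(CustomEvents.STEP_START, lambda: print('step starting'))
    engine.raise_event(CustomEvents.STEP_START)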
Example #3
    def train_step(self, batch, state) -> dict:
        self.train()
        self.optimizer.zero_grad()

        for i, chunk_batch in enumerate(deep_chunk(batch, self.iter_size)):
            input, target = deep_to(chunk_batch,
                                    self.device,
                                    non_blocking=True)
            prediction = self.nn_module(input)
            loss = self.loss(prediction, target, training=True)
            if self.amp is not None:
                delay_unscale = i != (self.iter_size - 1)
                with self.amp.scale_loss(
                        loss, self.optimizer,
                        delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        self.optimizer.step()

        torch.cuda.synchronize()
        if self.model_ema is not None:
            with torch.no_grad():
                self.model_ema.update(self.nn_module)

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)
        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
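
Here the batch is split into iter_size chunks and gradients accumulate across them before a single optimizer step. Note that, unlike Example #4, the per-chunk loss is not divided by iter_size, so gradients are summed rather than averaged, and the reported loss and prediction come from the last chunk only. deep_chunk itself is not shown; a plausible sketch of such a helper built on torch.chunk (an assumption about its behavior, not the library source):

    import torch

    def deep_chunk_sketch(value, chunks):
        # Split tensors along dim 0; recurse into lists/tuples and regroup
        # so iteration yields one (input, target)-style chunk at a time.
        if torch.is_tensor(value):
            return torch.chunk(value, chunks, dim=0)
        if isinstance(value, (list, tuple)):
            parts = [deep_chunk_sketch(item, chunks) for item in value]
            return list(zip(*parts))
        raise TypeError(f'unsupported batch element: {type(value)}')

    batch = (torch.randn(8, 3), torch.randn(8, 1))
    for input, target in deep_chunk_sketch(batch, 2):
        assert input.shape[0] == 4 and target.shape[0] == 4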
Example #4
    def train_step(self, batch, state) -> dict:
        self.train()
        self.optimizer.zero_grad()

        # Gradient accumulation
        for i, chunk_batch in enumerate(deep_chunk(batch, self.iter_size)):
            input, target = deep_to(chunk_batch, self.device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=self.amp):
                prediction = self.nn_module(input)
                loss = self.loss(prediction, target)
                loss = loss / self.iter_size

            if self.amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

        if self.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)
        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
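
This is the same accumulation scheme ported to PyTorch's native AMP: autocast wraps the forward pass, and a GradScaler handles the backward/step (the snippet implies a self.scaler created elsewhere, and enabled=self.amp lets the same code path run with AMP switched off). A self-contained sketch of the scale/step/update cycle (toy model and data; requires a CUDA device):

    import torch

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler()  # created once, reused every step

    for _ in range(2):
        optimizer.zero_grad()
        inputs = torch.randn(4, 8, device='cuda')
        targets = torch.randn(4, 1, device='cuda')
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)         # unscales, skips the step on inf/nan grads
        scaler.update()                # adapts the scale factor for the next step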
Example #5
    def train_step(self, batch, state) -> dict:
        self.train()
        self.optimizer.zero_grad()
        input, target = self.prepare_batch(batch, self.device)
        prediction = self.nn_module(input)
        loss = self.loss(prediction, target)
        loss.backward()
        self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)
        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
Example #6
def test_deep_detach(list_of_tensors, dict_of_tensors):
    def all_grad_is_none(sequence):
        return all(tensor.grad is None for tensor in sequence)

    assert all_grad_is_none(list_of_tensors)
    assert all_grad_is_none(dict_of_tensors.values())

    list_of_grad_tensors = [tensor * 2 for tensor in list_of_tensors]
    dict_of_grad_tensors = {key: tensor * 2 for key, tensor in dict_of_tensors.items()}
    loss = torch.tensor(0.)
    for tensor in [*list_of_grad_tensors, *dict_of_grad_tensors.values()]:
        loss += tensor.sum()
    loss.backward()

    assert all_grad_is_none(deep_detach(list_of_tensors))
    assert all_grad_is_none(deep_detach(dict_of_tensors).values())

    assert 'qwerty' == deep_detach('qwerty')
    assert None is deep_detach(None)
    assert deep_detach(True)
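
This test pins down the deep_detach contract: tensors nested in lists and dicts come back detached (grad is None on the copies), while non-tensor values such as strings, None, and bools pass through unchanged. A minimal sketch matching that contract (illustrative only, not the library source):

    import torch

    def deep_detach_sketch(value):
        if torch.is_tensor(value):
            return value.detach()
        if isinstance(value, dict):
            return {key: deep_detach_sketch(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(deep_detach_sketch(item) for item in value)
        return value  # str, None, bool, ... are returned as-is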
Example #7
    def train_step(self, batch, state: State) -> dict:
        """Perform a single train step.

        The method is used by :class:`argus.engine.Engine`.
        The train step includes transferring the input and target tensors to
        the model device, a forward pass, loss evaluation, a backward pass,
        and transforming the train batch predictions with the
        prediction_transform.

        Args:
            batch (tuple of 2 torch.Tensors: (input, target)): The input and
                target tensors to process.
            state (:class:`argus.engine.State`): The argus model state.

        Returns:
            dict: The train step results::

                {
                    'prediction': The train batch predictions,
                    'target': The train batch target data on the model device,
                    'loss': The loss function value
                }

        """
        self.train()
        self.optimizer.zero_grad()
        input, target = deep_to(batch, device=self.device, non_blocking=True)
        prediction = self.nn_module(input)
        loss = self.loss(prediction, target)
        loss.backward()
        self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        prediction = self.prediction_transform(prediction)
        return {
            'prediction': prediction,
            'target': target,
            'loss': loss.item()
        }
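
The docstring spells out the contract the other examples follow implicitly. As a self-contained rehearsal of the same sequence in plain PyTorch (toy module, optimizer, and batch; not the argus classes):

    import torch

    nn_module = torch.nn.Linear(3, 1)
    optimizer = torch.optim.SGD(nn_module.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()
    batch = (torch.randn(8, 3), torch.randn(8, 1))

    nn_module.train()                   # switch to train mode
    optimizer.zero_grad()
    input, target = batch               # already on the (CPU) model device
    prediction = nn_module(input)       # forward pass
    loss = loss_fn(prediction, target)  # loss evaluation
    loss.backward()                     # backward pass
    optimizer.step()

    result = {
        'prediction': prediction.detach(),
        'target': target,
        'loss': loss.item(),
    }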