def train_step(
        self,
        train_loader: DataLoader,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        device: str,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]],
        acc_fn: Callable[
            [List[Tuple[str, torch.Tensor]], torch.Tensor],
            List[Tuple[str, torch.Tensor]],
        ],
        use_tqdm: bool,
    ) -> List[Tuple[str, torch.Tensor]]:
        """
        Performs a single training step (one full pass over ``train_loader``).

        Note:

            The losses and accuracies returned by ``loss_fn`` and ``acc_fn`` are divided by the \
            number of batches in the dataset while recording them for the epoch (averaging). So make \
            sure any reduction in your functions is ``mean``.

        Args:
            train_loader (DataLoader): The ``DataLoader`` for the training data.
            optimizer (Optimizer): The optimizer to use.
            scheduler (LRScheduler): The LR scheduler to use.
            device (str): A valid pytorch device string.
            loss_fn (Callable[[torch.Tensor, torch.Tensor], List[Tuple[str, torch.Tensor]]]): The loss function to use. \
                It should take the predicted output of the model and the target output from the dataset as its two \
                arguments and return a list of tuples, where the first element of each tuple is a label for the \
                loss and the second element is the loss value.
            acc_fn (Callable[[List[Tuple[str, torch.Tensor]], torch.Tensor], List[Tuple[str, torch.Tensor]]]): The accuracy \
                function to use. It should take two arguments: first, the list of ``(label, loss)`` tuples returned by \
                ``loss_fn``, and second, the target output from the dataset. It should return a list of tuples, where \
                the first element of each tuple is a label for the accuracy and the second element is the accuracy value.
            use_tqdm (bool): If ``True``, uses ``tqdm`` instead of a Keras-style progress bar (``pkbar``).

        Returns:
            List[Tuple[str, torch.Tensor]]: A list containing tuples in which the first element of the tuple is the label \
                describing the value and the second element is the value itself.
        """

        # setting model in train mode
        self.model.train()

        # creating progress bar
        if use_tqdm:
            pbar = tqdm(train_loader)
            iterator = pbar
        else:
            pbar = Kbar(len(train_loader), stateful_metrics=["loss", "accuracy"])
            iterator = train_loader

        # defining variables
        correct = 0
        processed = 0
        train_losses: Optional[np.ndarray] = None
        train_accs: Optional[np.ndarray] = None
        for batch_idx, (data, target) in enumerate(iterator):
            # casting to device
            data, target = data.to(device), target.to(device)

            # zeroing out accumulated gradients
            optimizer.zero_grad()

            # forward prop
            y_pred = self.model(data)

            # calculating loss (look at function documentation for details on what is returned by
            # the loss_fn)
            losses_data: List[Tuple[str, torch.Tensor]] = loss_fn(y_pred, target)
            if train_losses is None:
                train_losses = np.fromiter(
                    [x[-1] for x in losses_data], dtype=np.float32
                )
            else:
                train_losses = train_losses + np.fromiter(
                    [x[-1] for x in losses_data], dtype=np.float32
                )

            # backpropagation
            for _, loss in losses_data:
                loss.backward()
            optimizer.step()

            # calculating the accuracies (look at function documentation for details on what is returned by
            # the acc_fn)
            acc_data: List[Tuple[str, torch.Tensor]] = acc_fn(losses_data, target)
            if train_accs is None:
                train_accs = np.fromiter([x[-1] for x in acc_data], dtype=np.float32)
            else:
                train_accs = train_accs + np.fromiter(
                    [x[-1] for x in acc_data], dtype=np.float32
                )

            # updating progress bar with instantaneous losses and accuracies
            if use_tqdm:
                losses_desc = " - ".join(
                    [f"{name}: {value:0.4f}" for name, value in losses_data]
                )
                accs_desc = " - ".join(
                    [f"{name}: {value:0.4f}" for name, value in acc_data]
                )
                pbar.set_description(
                    desc=f"Batch_id: {batch_idx + 1} - {losses_desc} - {accs_desc}"
                )
            else:
                pbar.update(batch_idx, values=[*losses_data, *acc_data])

            if isinstance(scheduler, OneCycleLR):
                scheduler.step()

        if not use_tqdm:
            # required for pkbar
            pbar.add(1, values=[*losses_data, *acc_data])

        return [
            *zip(
                # getting the labels of each loss value
                [x[0] for x in losses_data],
                # dividing each loss by the number of batches in the dataset (epoch average)
                [loss / len(train_loader) for loss in train_losses],
            ),
            *zip(
                # getting the labels of each accuracy value
                [x[0] for x in acc_data],
                # dividing each accuracy by the number of batches in the dataset (epoch average)
                [acc / len(train_loader) for acc in train_accs],
            ),
        ]
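
The docstring above spells out the contract for ``loss_fn`` and ``acc_fn``: both return a list of ``(label, value)`` tuples, and ``acc_fn`` receives the list produced by ``loss_fn`` rather than the raw predictions. Below is a minimal sketch of a matching pair for a regression setup; the MSE/RMSE choice and the ``trainer`` call site are illustrative assumptions, not part of the original code.

from typing import List, Tuple

import torch
import torch.nn.functional as F

def loss_fn(y_pred: torch.Tensor, target: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
    # reduction is "mean", as required by the Note in the docstring
    return [("mse_loss", F.mse_loss(y_pred, target, reduction="mean"))]

def acc_fn(
    losses_data: List[Tuple[str, torch.Tensor]], target: torch.Tensor
) -> List[Tuple[str, torch.Tensor]]:
    # derives an RMSE metric from the already-computed MSE loss; target is unused here
    mse = dict(losses_data)["mse_loss"]
    return [("rmse", torch.sqrt(mse.detach()))]

# hypothetical call site, assuming a `trainer` object that exposes train_step
# epoch_metrics = trainer.train_step(
#     train_loader, optimizer, scheduler, "cuda", loss_fn, acc_fn, use_tqdm=True
# )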
Example #2
 def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
     # synchronize all horovod optimizers.
     for optimizer in self.lightning_module.trainer.optimizers:
         optimizer.synchronize()
Example #3
 def post_backward(self, closure_loss: torch.Tensor,
                   should_accumulate: bool, optimizer: Optimizer,
                   opt_idx: int):
     # synchronize gradients across Horovod workers before the optimizer steps
     optimizer.synchronize()
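
Both ``post_backward`` examples above rely on ``optimizer.synchronize()``, the hook Horovod's ``DistributedOptimizer`` exposes to force the gradient allreduce before stepping. A minimal sketch of the same pattern in a plain (non-Lightning) loop is shown below, e.g. for clipping gradients between ``backward()`` and ``step()``; the toy model and data are placeholders, and the sketch assumes ``horovod.torch`` is installed.

import horovod.torch as hvd
import torch
import torch.nn.functional as F

hvd.init()
model = torch.nn.Linear(10, 2)
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01),
    named_parameters=model.named_parameters(),
)

data, target = torch.randn(8, 10), torch.randint(0, 2, (8,))
optimizer.zero_grad()
loss = F.cross_entropy(model(data), target)
loss.backward()
optimizer.synchronize()                      # force the gradient allreduce now
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
with optimizer.skip_synchronize():           # step without synchronizing a second time
    optimizer.step()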