def evaluate(self, loader, verbose=False, write_summary=False, epoch=None):
        """Evaluate the model on ``loader`` and return the lwlrap metric.

        Switches the model to eval mode (it is left in eval mode on return)
        and runs inference under ``torch.no_grad()``. Two quantities are
        computed over everything the loader yields:

        * the validation loss: the per-batch ``lsep_loss`` weighted by batch
          size and divided by the number of samples actually seen;
        * the lwlrap metric on the sigmoid probabilities of all samples.

        Args:
            loader: iterable yielding dict samples with "signal" and "labels"
                tensors.
            verbose: when True, print the validation loss and metric.
            write_summary: when True, write scalar summaries to
                ``self.valid_writer`` at ``self.global_step``.
            epoch: unused; kept for interface compatibility.

        Returns:
            The lwlrap metric computed over all samples from ``loader``.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        self.eval()

        weighted_loss_sum = 0.0
        num_samples = 0

        all_class_probs = []
        all_labels = []

        with torch.no_grad():
            for sample in loader:
                signal = sample["signal"].to(self.device)
                labels = sample["labels"].to(self.device).float()

                class_logits = self(signal)["class_logits"]

                batch_size = len(labels)
                # NOTE(review): lsep_loss without average=False is presumably
                # a batch mean (train_epoch asks for per-sample values
                # explicitly) — confirm against lsep_loss.
                batch_loss = lsep_loss(class_logits, labels).item()
                weighted_loss_sum += batch_loss * batch_size
                num_samples += batch_size

                all_class_probs.append(
                    torch.sigmoid(class_logits).detach().cpu().numpy())
                all_labels.append(labels.detach().cpu().numpy())

        if num_samples == 0:
            raise ValueError("evaluate() received a loader with no samples")

        # Normalize by the samples actually iterated rather than by
        # len(loader.dataset); the two differ when the loader uses a sampler,
        # a dataset subset, or drop_last.
        valid_loss = weighted_loss_sum / num_samples

        all_class_probs = np.concatenate(all_class_probs)
        all_labels = np.concatenate(all_labels)

        metric = lwlrap(all_labels, all_class_probs)

        if write_summary:
            self.add_scalar_summaries(
                valid_loss,
                metric,
                writer=self.valid_writer, global_step=self.global_step
            )

        if verbose:
            print("\nValidation loss: {:.4f}".format(valid_loss))
            print("Validation metric: {:.4f}".format(metric))

        return metric
    def train_epoch(self, train_loader,
                    epoch, log_interval, write_summary=True):
        """Run one training epoch over ``train_loader``.

        The scheduler is advanced once per batch via ``make_step``.
        Gradients are accumulated over ``self.config.train.accumulation_steps``
        consecutive batches before each optimizer step. The final batch of
        the epoch always triggers a step, so a trailing partial window is
        applied instead of being discarded.

        Args:
            train_loader: iterable yielding dict samples with "signal" and
                "labels" tensors. Its ``len()`` sizes the progress bar and
                identifies the final batch.
            epoch: epoch number, used only in the printed banner.
            log_interval: write scalar summaries every ``log_interval``
                batches.
            write_summary: when False, no scalar/image/histogram summaries
                are written.
        """
        self.train()

        print(
            "\n" + " " * 10 + "****** Epoch {epoch} ******\n"
            .format(epoch=epoch)
        )

        accumulation_steps = self.config.train.accumulation_steps
        num_batches = len(train_loader)

        training_losses = []
        # Sliding window of recent per-batch metrics shown in the progress bar.
        history = deque(maxlen=30)

        self.optimizer.zero_grad()

        with tqdm(total=num_batches, ncols=80) as pb:

            for batch_idx, sample in enumerate(train_loader):

                self.global_step += 1

                make_step(self.scheduler, step=self.global_step)

                signal = sample["signal"].to(self.device)
                labels = sample["labels"].to(self.device).float()

                outputs = self(signal)
                class_logits = outputs["class_logits"]

                # average=False yields per-sample losses (the original extended
                # a list with them); presumably shape (batch,) — confirm.
                sample_losses = lsep_loss(class_logits, labels, average=False)
                training_losses.extend(sample_losses.detach().cpu().numpy())

                batch_loss = sample_losses.mean()
                # Scale only the backward pass so the accumulated gradient is
                # the average over the window; logged losses stay unscaled.
                (batch_loss / accumulation_steps).backward()

                # Step once every `accumulation_steps` batches (counting from
                # 1, not 0) and on the final batch so no gradient is dropped.
                is_window_end = (batch_idx + 1) % accumulation_steps == 0
                is_last_batch = batch_idx + 1 == num_batches
                if is_window_end or is_last_batch:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                probs = torch.sigmoid(class_logits).detach().cpu().numpy()
                labels = labels.detach().cpu().numpy()

                metric = lwlrap(labels, probs)
                history.append(metric)

                loss_value = batch_loss.item()

                pb.update()
                pb.set_description(
                    "Loss: {:.4f}, Metric: {:.4f}".format(
                        loss_value, np.mean(history)))

                if write_summary and batch_idx % log_interval == 0:
                    self.add_scalar_summaries(
                        loss_value, metric, self.train_writer, self.global_step)

                if write_summary and batch_idx == 0:
                    self.add_image_summaries(
                        signal, self.global_step, self.train_writer)

        if write_summary:
            self.add_histogram_summaries(
                training_losses, self.train_writer, self.global_step)