def evaluate(self, loader, verbose=False, write_summary=False, epoch=None):
        """Run the model over ``loader`` without gradients and score it.

        Switches the module to eval mode, computes the focal loss for every
        batch, collects sigmoid probabilities and labels for all samples and
        evaluates ``compute_inverse_eer`` on the complete set.

        Args:
            loader: iterable yielding dicts with ``"signal"`` and ``"labels"``
                tensors (e.g. a ``torch.utils.data.DataLoader``).
            verbose: when True, print the validation loss and metric.
            write_summary: when True, log loss and metric through
                ``self.add_scalar_summaries`` to ``self.valid_writer`` at
                ``self.global_step``.
            epoch: not used by this method; kept for interface compatibility.

        Returns:
            The value of ``compute_inverse_eer(all_labels, all_class_probs)``.
        """
        self.eval()

        loss_sum = 0.0
        n_samples = 0

        all_class_probs = []
        all_labels = []

        with torch.no_grad():
            for sample in loader:

                signal = sample["signal"].to(self.device)
                labels = sample["labels"].to(self.device).float()

                outputs = self(signal)

                class_logits = outputs["class_logits"].squeeze(-1)

                # Weight each batch by its own size and normalise by the number
                # of samples actually evaluated. This keeps the result a
                # per-sample mean (presumably focal_loss returns a batch mean)
                # even with drop_last, subset samplers, or datasets without
                # __len__, unlike dividing by len(loader.dataset).
                batch_size = len(labels)
                loss_sum += focal_loss(class_logits, labels).item() * batch_size
                n_samples += batch_size

                class_probs = torch.sigmoid(class_logits).detach().cpu().numpy()
                labels = labels.detach().cpu().numpy()

                all_class_probs.extend(class_probs)
                all_labels.extend(labels)

            # An empty loader yields a loss of 0 instead of dividing by zero.
            valid_loss = loss_sum / n_samples if n_samples else 0.0

            all_class_probs = np.asarray(all_class_probs)
            all_labels = np.asarray(all_labels)

            # NOTE(review): train_epoch passes logits/labels through
            # take_first_column before computing this metric; this path does
            # not. Confirm the shapes agree for multi-column outputs.
            metric = compute_inverse_eer(all_labels, all_class_probs)

            if write_summary:
                self.add_scalar_summaries(
                    valid_loss,
                    metric,
                    writer=self.valid_writer, global_step=self.global_step
                )

            if verbose:
                print("\nValidation loss: {:.4f}".format(valid_loss))
                print("Validation metric: {:.4f}".format(metric))

            return metric
    # Example #2
    def train_epoch(self,
                    train_loader,
                    epoch,
                    log_interval,
                    write_summary=True):
        """Train the model for one pass over ``train_loader``.

        For every batch this method:
        - increments ``self.global_step``;
        - calls ``make_step(self.scheduler, step=self.global_step)``;
        - computes the focal loss;
        - back-propagates it scaled by ``1 / self.config.train.accumulation_steps``.

        The optimizer steps once every ``accumulation_steps`` batches. A tqdm
        bar shows the batch loss and the mean of ``compute_inverse_eer`` over
        the last 30 batches.

        Args:
            train_loader: iterable of dicts with ``"signal"`` and ``"labels"``
                tensors; must support ``len()`` (used for the progress bar).
            epoch: epoch number, used only for the printed header.
            log_interval: write scalar summaries every this many batches.
            write_summary: when False, skip scalar and image summaries.
        """
        self.train()

        print("\n" + " " * 10 +
              "****** Epoch {epoch} ******\n".format(epoch=epoch))

        history = deque(maxlen=30)
        accumulation_steps = self.config.train.accumulation_steps

        self.optimizer.zero_grad()

        with tqdm(total=len(train_loader), ncols=80) as pb:

            for batch_idx, sample in enumerate(train_loader):

                self.global_step += 1

                make_step(self.scheduler, step=self.global_step)

                signal, labels = (
                    sample["signal"].to(self.device),
                    sample["labels"].to(self.device).float(),
                )

                outputs = self(signal)

                class_logits = outputs["class_logits"].squeeze(-1)

                # Keep the unscaled batch loss for reporting, so train and
                # validation losses are on the same scale. Only the gradient
                # is divided, so gradients summed over `accumulation_steps`
                # batches form an average.
                batch_loss = focal_loss(class_logits, labels)
                (batch_loss / accumulation_steps).backward()
                loss_value = batch_loss.item()

                # batch_idx is 0-based: step after each full window of
                # `accumulation_steps` batches, not after the very first one.
                # NOTE: gradients of a trailing partial window are discarded
                # by the zero_grad() at the start of the next epoch.
                if (batch_idx + 1) % accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                class_logits = take_first_column(class_logits)  # human is 1
                labels = take_first_column(labels)

                probs = torch.sigmoid(class_logits).detach().cpu().numpy()
                labels = labels.detach().cpu().numpy()

                metric = compute_inverse_eer(labels, probs)
                history.append(metric)

                pb.update()
                pb.set_description("Loss: {:.4f}, Metric: {:.4f}".format(
                    loss_value, np.mean(history)))

                if write_summary and batch_idx % log_interval == 0:
                    self.add_scalar_summaries(loss_value, metric,
                                              self.train_writer,
                                              self.global_step)

                if write_summary and batch_idx == 0:
                    self.add_image_summaries(signal, self.global_step,
                                             self.train_writer)