def evaluate(self, loader, verbose=False, write_summary=False, epoch=None):
    """Run the model over ``loader`` without gradients and return the metric.

    Switches the module to eval mode via ``self.eval()`` and does not switch
    it back. For every batch, the model output ``outputs["class_logits"]``
    (squeezed on the last axis) is scored with ``focal_loss``, and sigmoid
    probabilities plus labels are collected. The metric is
    ``compute_inverse_eer`` over all collected samples.

    Args:
        loader: iterable of dict batches with ``"signal"`` and ``"labels"``
            tensors (presumably a ``torch.utils.data.DataLoader``).
        verbose: if true, print the validation loss and metric.
        write_summary: if true, log loss and metric through
            ``self.add_scalar_summaries`` to ``self.valid_writer`` at
            ``self.global_step``.
        epoch: unused; kept for interface compatibility.

    Returns:
        Whatever ``compute_inverse_eer(all_labels, all_class_probs)`` returns.
    """
    self.eval()
    loss_sum = 0.0
    n_seen = 0
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for sample in loader:
            signal = sample["signal"].to(self.device)
            labels = sample["labels"].to(self.device).float()
            outputs = self(signal)
            class_logits = outputs["class_logits"].squeeze(-1)

            # Weight each batch's loss by its size and divide by the number of
            # samples actually seen (not len(loader.dataset)), so the mean stays
            # correct with drop_last, subset or distributed samplers.
            batch_size = len(labels)
            loss_sum += focal_loss(class_logits, labels).item() * batch_size
            n_seen += batch_size

            # Nothing requires grad under no_grad(), so no detach is needed.
            prob_chunks.append(torch.sigmoid(class_logits).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    valid_loss = loss_sum / n_seen if n_seen else 0.0
    # Concatenate whole batches instead of extending a list element-by-element;
    # fall back to an empty array (as before) when the loader yielded nothing.
    all_class_probs = np.concatenate(prob_chunks) if prob_chunks else np.asarray([])
    all_labels = np.concatenate(label_chunks) if label_chunks else np.asarray([])
    metric = compute_inverse_eer(all_labels, all_class_probs)

    if write_summary:
        self.add_scalar_summaries(
            valid_loss,
            metric,
            writer=self.valid_writer,
            global_step=self.global_step,
        )
    if verbose:
        print("\nValidation loss: {:.4f}".format(valid_loss))
        print("Validation metric: {:.4f}".format(metric))
    return metric
def train_epoch(self, train_loader, epoch, log_interval, write_summary=True):
    """Train the model for one epoch with gradient accumulation.

    Calls ``self.train()``. For every batch it:

    * increments ``self.global_step`` and calls ``make_step`` on
      ``self.scheduler``;
    * computes ``focal_loss`` on the squeezed ``outputs["class_logits"]``;
    * backpropagates ``loss / config.train.accumulation_steps``.

    The optimizer steps once every ``accumulation_steps`` batches, and once
    more on the final batch if a partial window is left over. A per-batch
    ``compute_inverse_eer`` metric is shown in a tqdm progress bar as a
    rolling mean over the last 30 batches.

    Args:
        train_loader: sized iterable of dict batches with ``"signal"`` and
            ``"labels"`` tensors (``len()`` is used for the progress bar and
            for the final-batch flush).
        epoch: epoch number, only used in the printed header.
        log_interval: write scalar summaries every ``log_interval`` batches.
        write_summary: if false, skip all TensorBoard-style summaries.

    Note:
        The loss shown in the progress bar and written to summaries is the
        unscaled ``focal_loss`` value, comparable to the loss from
        ``evaluate``. Only the backpropagated value is divided by
        ``accumulation_steps``.
    """
    self.train()
    print("\n" + " " * 10 + "****** Epoch {epoch} ******\n".format(epoch=epoch))
    history = deque(maxlen=30)  # rolling window for the progress-bar metric
    accumulation_steps = self.config.train.accumulation_steps
    n_batches = len(train_loader)
    self.optimizer.zero_grad()
    with tqdm(total=n_batches, ncols=80) as pb:
        for batch_idx, sample in enumerate(train_loader):
            self.global_step += 1
            make_step(self.scheduler, step=self.global_step)

            signal = sample["signal"].to(self.device)
            labels = sample["labels"].to(self.device).float()
            outputs = self(signal)
            class_logits = outputs["class_logits"].squeeze(-1)

            loss = focal_loss(class_logits, labels)
            # Scale only the gradient contribution so the gradients summed over
            # an accumulation window average out over its batches.
            (loss / accumulation_steps).backward()

            # Step after every `accumulation_steps` batches, counting from 1
            # (the previous `batch_idx % steps == 0` stepped after batch 0
            # alone). Also flush a leftover partial window on the last batch,
            # otherwise the next epoch's zero_grad() would discard it.
            is_last_batch = batch_idx + 1 == n_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                self.optimizer.step()
                self.optimizer.zero_grad()

            class_logits = take_first_column(class_logits)  # human is 1
            labels = take_first_column(labels)
            probs = torch.sigmoid(class_logits).detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            metric = compute_inverse_eer(labels, probs)
            history.append(metric)

            loss_value = loss.item()
            pb.update()
            pb.set_description(
                "Loss: {:.4f}, Metric: {:.4f}".format(loss_value, np.mean(history))
            )

            if write_summary:
                if batch_idx % log_interval == 0:
                    self.add_scalar_summaries(
                        loss_value, metric, self.train_writer, self.global_step
                    )
                if batch_idx == 0:
                    self.add_image_summaries(
                        signal, self.global_step, self.train_writer
                    )