# Imports assumed by this section; lsep_loss, lwlrap and make_step are
# repo-local helpers defined elsewhere.
from collections import deque

import numpy as np
import torch
from tqdm import tqdm


def evaluate(self, loader, verbose=False, write_summary=False, epoch=None):
    self.eval()
    valid_loss = 0
    all_class_probs = []
    all_labels = []
    with torch.no_grad():
        for sample in loader:
            signal, labels = (
                sample["signal"].to(self.device),
                sample["labels"].to(self.device).float()
            )
            outputs = self(signal)
            class_logits = outputs["class_logits"]
            loss = lsep_loss(class_logits, labels).item()
            # Weight each batch by its share of the dataset so the running
            # sum equals the dataset-mean loss even if the last batch is short.
            multiplier = len(labels) / len(loader.dataset)
            valid_loss += loss * multiplier
            all_class_probs.extend(torch.sigmoid(class_logits).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    all_class_probs = np.asarray(all_class_probs)
    all_labels = np.asarray(all_labels)
    metric = lwlrap(all_labels, all_class_probs)
    if write_summary:
        self.add_scalar_summaries(
            valid_loss, metric,
            writer=self.valid_writer, global_step=self.global_step)
    if verbose:
        print("\nValidation loss: {:.4f}".format(valid_loss))
        print("Validation metric: {:.4f}".format(metric))
    return metric
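# `evaluate` and `train_epoch` rely on `lsep_loss`, which is defined
# elsewhere in the repo. The sketch below is a minimal LSEP (log-sum-exp
# pairwise) ranking loss, after Li et al. (CVPR 2017), written to match the
# call sites here -- a scalar mean by default, per-sample losses with
# `average=False`. It is an assumption about the real implementation, not a
# copy of it.
def lsep_loss(logits, labels, average=True):
    # logits, labels: (batch, num_classes); labels are {0, 1} floats.
    positives = labels.unsqueeze(2)        # (B, C, 1): class i is positive
    negatives = (1 - labels).unsqueeze(1)  # (B, 1, C): class j is negative
    pair_mask = positives * negatives      # valid (positive i, negative j) pairs
    # differences[b, i, j] = logits[b, j] - logits[b, i]  (negative minus positive)
    differences = logits.unsqueeze(1) - logits.unsqueeze(2)
    # log(1 + sum over pairs of exp(s_neg - s_pos)), per sample
    loss = torch.log1p((pair_mask * torch.exp(differences)).sum(dim=(1, 2)))
    return loss.mean() if average else loss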
def train_epoch(self, train_loader, epoch, log_interval, write_summary=True):
    self.train()
    print(
        "\n" + " " * 10 + "****** Epoch {epoch} ******\n"
        .format(epoch=epoch)
    )

    training_losses = []
    history = deque(maxlen=30)  # rolling window for the progress-bar metric

    self.optimizer.zero_grad()

    with tqdm(total=len(train_loader), ncols=80) as pb:
        for batch_idx, sample in enumerate(train_loader):
            self.global_step += 1
            make_step(self.scheduler, step=self.global_step)

            # is_noisy is unpacked for parity with the dataset sample;
            # it is not used by this loss.
            signal, labels, is_noisy = (
                sample["signal"].to(self.device),
                sample["labels"].to(self.device).float(),
                sample["is_noisy"].to(self.device).float()
            )

            outputs = self(signal)
            class_logits = outputs["class_logits"]

            # Per-sample losses, scaled so gradients summed over
            # `accumulation_steps` micro-batches match one full-batch step.
            loss = (
                lsep_loss(class_logits, labels, average=False)
                / self.config.train.accumulation_steps
            )
            training_losses.extend(loss.detach().cpu().numpy())

            loss = loss.mean()
            loss.backward()

            # Step once per full accumulation window. The original
            # `batch_idx % steps == 0` fired on batch 0, stepping after a
            # single micro-batch; `batch_idx + 1` waits for a full window.
            # (The unused `accumulated_loss` accumulator is dropped.)
            if (batch_idx + 1) % self.config.train.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            probs = torch.sigmoid(class_logits).detach().cpu().numpy()
            labels = labels.cpu().numpy()
            metric = lwlrap(labels, probs)
            history.append(metric)

            pb.update()
            pb.set_description(
                "Loss: {:.4f}, Metric: {:.4f}".format(
                    loss.item(), np.mean(history)))

            # Honor the `write_summary` flag, which the original ignored.
            if write_summary and batch_idx % log_interval == 0:
                self.add_scalar_summaries(
                    loss.item(), metric, self.train_writer, self.global_step)

            if write_summary and batch_idx == 0:
                self.add_image_summaries(
                    signal, self.global_step, self.train_writer)

    if write_summary:
        self.add_histogram_summaries(
            training_losses, self.train_writer, self.global_step)
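# `lwlrap` (label-weighted label-ranking average precision) is the metric
# from the Freesound Audio Tagging 2019 competition. A compact sketch,
# assuming the known equivalence with scikit-learn's
# label_ranking_average_precision_score when each sample is weighted by its
# number of positive labels; the repo may ship its own implementation.
def lwlrap(truth, scores):
    from sklearn.metrics import label_ranking_average_precision_score

    # Weight each sample by its count of positive labels so every *label*
    # contributes equally, then drop samples with no positives.
    sample_weight = np.sum(truth > 0, axis=1)
    nonzero = sample_weight > 0
    return label_ranking_average_precision_score(
        truth[nonzero] > 0,
        scores[nonzero],
        sample_weight=sample_weight[nonzero],
    )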