def train_epoch(self,
                    train_loader,
                    epoch,
                    log_interval,
                    write_summary=True):
        """Run one training epoch with gradient accumulation.

        Args:
            train_loader: iterable of sample dicts with "signal" and "labels".
            epoch: epoch index, used only for the banner printout.
            log_interval: write scalar summaries every this many batches.
            write_summary: kept for API compatibility; not used in the body.
        """
        self.train()

        print("\n" + " " * 10 +
              "****** Epoch {epoch} ******\n".format(epoch=epoch))

        # Rolling window of recent per-batch losses for the progress bar.
        history = deque(maxlen=30)

        accumulation_steps = self.config.train.accumulation_steps

        self.optimizer.zero_grad()
        # Plain float accumulator: accumulating the loss *tensor* (as the
        # original did) retains the autograd graph of every batch in the
        # window and leaks memory.
        accumulated_loss = 0.0

        with tqdm(total=len(train_loader), ncols=80) as pb:

            for batch_idx, sample in enumerate(train_loader):

                self.global_step += 1

                make_step(self.scheduler, step=self.global_step)

                signal, labels = (sample["signal"].to(self.device),
                                  sample["labels"].to(self.device))

                outputs = self(signal)

                losses = outputs["losses"]

                # Scale so gradients summed over the window match one
                # full-size batch.
                loss = sum(losses) / accumulation_steps

                loss.backward()
                accumulated_loss += loss.item()

                # BUG FIX: step after every `accumulation_steps` batches.
                # The original `batch_idx % n == 0` stepped on the very
                # first batch, after a single under-scaled backward pass.
                if (batch_idx + 1) % accumulation_steps == 0:
                    self.optimizer.step()
                    accumulated_loss = 0.0
                    self.optimizer.zero_grad()

                history.append(loss.item())

                pb.update()
                pb.set_description("Loss: {:.4f}".format(np.mean(history)))

                if batch_idx % log_interval == 0:
                    self.add_scalar_summaries([l.item() for l in losses],
                                              self.train_writer,
                                              self.global_step)

                # Image summaries only for the first batch of the epoch.
                if batch_idx == 0:
                    self.add_image_summaries(signal,
                                             outputs["c"].permute(0, 2, 1),
                                             outputs["z"].permute(0, 2, 1),
                                             self.global_step,
                                             self.train_writer)

            # Flush gradients left over from a trailing partial window so
            # the final batches are not silently dropped.
            if len(train_loader) % accumulation_steps != 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
# ---- Beispiel #2 (scraped-example separator; stray non-code lines commented out) ----
    def fit_validate(self,
                     train_loader,
                     valid_loader,
                     epochs,
                     fold,
                     log_interval=25):
        """Train on one cross-validation fold, validating after every epoch.

        Writes tensorboard summaries under summaries/fold_<fold>/{train,valid},
        saves a periodic checkpoint every `config.train._save_every` epochs and
        keeps the highest-scoring weights as best_model.pth.

        Returns:
            List of per-epoch validation scores.
        """
        fold_tag = "fold_{}".format(fold)

        self.experiment.register_directory("summaries")
        summary_root = self.experiment.summaries
        self.train_writer = SummaryWriter(
            logdir=os.path.join(summary_root, fold_tag, "train"))
        self.valid_writer = SummaryWriter(
            logdir=os.path.join(summary_root, fold_tag, "valid"))

        checkpoint_dir = os.path.join(self.experiment.checkpoints, fold_tag)
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.global_step = 0
        self.make_optimizer(max_steps=len(train_loader) * epochs)

        scores = []
        best_score = 0

        for epoch in range(epochs):

            make_step(self.scheduler, epoch=epoch)

            # Optionally switch off data augmentation from this epoch on.
            if epoch == self.config.train.switch_off_augmentations_on:
                train_loader.dataset.transform.switch_off_augmentations()

            self.train_epoch(train_loader,
                             epoch,
                             log_interval,
                             write_summary=True)

            validation_score = self.validation(valid_loader, epoch)
            scores.append(validation_score)

            # Periodic snapshot.
            if epoch % self.config.train._save_every == 0:
                print("\nSaving model on epoch", epoch)
                torch.save(
                    self.state_dict(),
                    os.path.join(checkpoint_dir,
                                 "model_on_epoch_{}.pth".format(epoch)))

            # Keep the best model seen so far.
            if validation_score > best_score:
                best_score = validation_score
                torch.save(
                    self.state_dict(),
                    os.path.join(checkpoint_dir, "best_model.pth"))

        return scores
# ---- Beispiel #3 (scraped-example separator; stray non-code lines commented out) ----
    def train_epoch(self,
                    train_loader,
                    epoch,
                    log_interval,
                    write_summary=True):
        """Run one training epoch of the binary classifier with focal loss.

        Args:
            train_loader: iterable of sample dicts with "signal" and "labels".
            epoch: epoch index, used only for the banner printout.
            log_interval: write scalar summaries every this many batches.
            write_summary: kept for API compatibility; not used in the body.
        """
        self.train()

        print("\n" + " " * 10 +
              "****** Epoch {epoch} ******\n".format(epoch=epoch))

        # Rolling window of recent batch metrics for the progress bar.
        history = deque(maxlen=30)

        accumulation_steps = self.config.train.accumulation_steps

        self.optimizer.zero_grad()
        # Plain float accumulator: accumulating the loss *tensor* (as the
        # original did) retains the autograd graph of every batch in the
        # window and leaks memory.
        accumulated_loss = 0.0

        with tqdm(total=len(train_loader), ncols=80) as pb:

            for batch_idx, sample in enumerate(train_loader):

                self.global_step += 1

                make_step(self.scheduler, step=self.global_step)

                signal, labels = (
                    sample["signal"].to(self.device),
                    sample["labels"].to(self.device).float(),
                )

                outputs = self(signal)

                class_logits = outputs["class_logits"].squeeze(-1)

                # Scale so gradients summed over the window match one
                # full-size batch.
                loss = focal_loss(
                    class_logits,
                    labels,
                ) / accumulation_steps

                loss.backward()
                accumulated_loss += loss.item()

                # BUG FIX: step after every `accumulation_steps` batches.
                # The original `batch_idx % n == 0` stepped on the very
                # first batch, after a single under-scaled backward pass.
                if (batch_idx + 1) % accumulation_steps == 0:
                    self.optimizer.step()
                    accumulated_loss = 0.0
                    self.optimizer.zero_grad()

                class_logits = take_first_column(class_logits)  # human is 1
                labels = take_first_column(labels)

                probs = torch.sigmoid(class_logits).data.cpu().numpy()
                labels = labels.data.cpu().numpy()

                metric = compute_inverse_eer(labels, probs)
                history.append(metric)

                pb.update()
                pb.set_description("Loss: {:.4f}, Metric: {:.4f}".format(
                    loss.item(), np.mean(history)))

                if batch_idx % log_interval == 0:
                    self.add_scalar_summaries(loss.item(), metric,
                                              self.train_writer,
                                              self.global_step)

                # Image summaries only for the first batch of the epoch.
                if batch_idx == 0:
                    self.add_image_summaries(signal, self.global_step,
                                             self.train_writer)

            # Flush gradients left over from a trailing partial window so
            # the final batches are not silently dropped.
            if len(train_loader) % accumulation_steps != 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
    def train_epoch(self, train_loader,
                    epoch, log_interval, write_summary=True):
        """Run one training epoch of the multi-label classifier (LSEP loss).

        Args:
            train_loader: iterable of sample dicts with "signal", "labels"
                and "is_noisy".
            epoch: epoch index, used only for the banner printout.
            log_interval: write scalar summaries every this many batches.
            write_summary: kept for API compatibility; not used in the body.
        """
        self.train()

        print(
            "\n" + " " * 10 + "****** Epoch {epoch} ******\n"
            .format(epoch=epoch)
        )

        # Per-sample losses over the whole epoch, for the histogram summary.
        training_losses = []

        # Rolling window of recent batch metrics for the progress bar.
        history = deque(maxlen=30)

        accumulation_steps = self.config.train.accumulation_steps

        self.optimizer.zero_grad()
        # Plain float accumulator: accumulating the loss *tensor* (as the
        # original did) retains the autograd graph of every batch in the
        # window and leaks memory.
        accumulated_loss = 0.0

        with tqdm(total=len(train_loader), ncols=80) as pb:

            for batch_idx, sample in enumerate(train_loader):

                self.global_step += 1

                make_step(self.scheduler, step=self.global_step)

                signal, labels, is_noisy = (
                    sample["signal"].to(self.device),
                    sample["labels"].to(self.device).float(),
                    sample["is_noisy"].to(self.device).float()
                )

                outputs = self(signal)

                class_logits = outputs["class_logits"]

                # Per-sample loss vector (average=False), scaled so gradients
                # summed over the window match one full-size batch.
                loss = lsep_loss(
                    class_logits,
                    labels,
                    average=False
                ) / accumulation_steps

                training_losses.extend(loss.data.cpu().numpy())
                loss = loss.mean()

                loss.backward()
                accumulated_loss += loss.item()

                # BUG FIX: step after every `accumulation_steps` batches.
                # The original `batch_idx % n == 0` stepped on the very
                # first batch, after a single under-scaled backward pass.
                if (batch_idx + 1) % accumulation_steps == 0:
                    self.optimizer.step()
                    accumulated_loss = 0.0
                    self.optimizer.zero_grad()

                probs = torch.sigmoid(class_logits).data.cpu().numpy()
                labels = labels.data.cpu().numpy()

                metric = lwlrap(labels, probs)
                history.append(metric)

                pb.update()
                pb.set_description(
                    "Loss: {:.4f}, Metric: {:.4f}".format(
                        loss.item(), np.mean(history)))

                if batch_idx % log_interval == 0:
                    self.add_scalar_summaries(
                        loss.item(), metric, self.train_writer, self.global_step)

                # Image summaries only for the first batch of the epoch.
                if batch_idx == 0:
                    self.add_image_summaries(
                        signal, self.global_step, self.train_writer)

            # Flush gradients left over from a trailing partial window so
            # the final batches are not silently dropped.
            if len(train_loader) % accumulation_steps != 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

        self.add_histogram_summaries(
            training_losses, self.train_writer, self.global_step)