Exemplo n.º 1
0
    def train(self):
        """Run the training loop for ``self.num_epochs`` epochs.

        Epochs are numbered from ``self.epoch + 1`` and processed in groups of
        ``self.evaluate_every_epoch``. After each group the model is validated,
        a ``ReduceOnPlateauScheduler`` (if one is used) is stepped with the
        latest validation metric, and the loop stops early once
        ``self.should_terminate`` is true. ``self.final_save()`` is called when
        the loop finishes (normally or via early stopping).
        """
        if self.reset_early_stopping and self.early_stopping:
            self.early_stopping.reset()

        self.model = self.model.to(self.device)

        epoch_range = range(self.epoch + 1, self.epoch + self.num_epochs + 1)
        # The context manager closes the bar even if an epoch raises
        # (e.g. KeyboardInterrupt), so no stale progress bar is left behind.
        with tqdm(range(self.num_epochs),
                  desc=f"epoch {self.epoch}: loss=???") as pbar:
            for epochs in batch_items(epoch_range, self.evaluate_every_epoch):
                # Re-enter training mode for every group; validate() presumably
                # switches the model to eval mode — TODO confirm.
                self.model.train()
                for epoch in epochs:
                    self.epoch = epoch
                    _, losses = self.train_epoch()
                    pbar.set_description(
                        f"epoch {self.epoch}: loss={losses['loss']:.5f}")
                    pbar.refresh()
                    pbar.update(1)
                self.validate()
                if isinstance(self.lr_scheduler, ReduceOnPlateauScheduler):
                    # Plateau schedulers are driven by the validation metric.
                    self.lr_scheduler.step(self.val_metric_tracker.last_value)
                if self.should_terminate:
                    logger_console.info(
                        "Reached an early stop threshold, stopping early.")
                    break
        self.final_save()
Exemplo n.º 2
0
    def train_epoch(self):
        """Train for one pass over ``self.train_loader``.

        Batches are grouped by ``self.num_accumulation_steps``. ``train_step``
        is called on every batch of a group. Then the optimizer is stepped
        once, gradients are cleared for the next group, and any
        non-plateau LR scheduler is stepped. After each group, running
        metrics are sent to every logger, and a train checkpoint may be saved.

        Returns:
            ``(metrics, losses)`` for the whole epoch, as returned by
            ``self.get_metrics_and_losses(..., is_train=True, reset=True)``.
        """
        train_loss = 0.0
        train_reg_loss = 0.0
        iterations_this_epoch = 0

        self.optimizer.zero_grad()

        # Context manager guarantees the bar is closed if a step raises.
        with tqdm(desc=f"iter={self.iteration}: loss=???",
                  leave=False) as pbar:
            for batches in batch_items(self.train_loader,
                                       self.num_accumulation_steps):
                self.iteration += 1
                iterations_this_epoch += 1
                batch = None
                result = None
                for batch in batches:
                    # presumably train_step runs forward + backward so that
                    # gradients accumulate over the group — TODO confirm.
                    result = self.train_step(batch)
                    train_loss += result["loss"].item()
                    train_reg_loss += result["reg_loss"].item()

                # Optimizer + scheduler
                self.optimizer.step()
                # Clear gradients so the next accumulation group starts from
                # zero; otherwise every earlier group's gradients would be
                # added into each subsequent optimizer step.
                self.optimizer.zero_grad()
                if not isinstance(self.lr_scheduler, ReduceOnPlateauScheduler):
                    self.lr_scheduler.step()

                # Logging
                metrics, losses = self.get_metrics_and_losses(
                    train_loss,
                    train_reg_loss,
                    iterations_this_epoch,
                    is_train=True)
                pbar.set_description(
                    f"iter {self.iteration}: loss={losses['loss']:.5f}")
                pbar.refresh()
                pbar.update(1)

                for logger in self.loggers:
                    logger.log(
                        self.iteration,
                        metrics,
                        losses,
                        batch,
                        result["logits"],
                        self.lr_scheduler,
                        self.optimizer,
                        prefix="train",
                    )

                # Checkpoint
                if self.train_checkpoint:
                    self.train_checkpoint.maybe_save(self.state_dict())

        metrics, losses = self.get_metrics_and_losses(train_loss,
                                                      train_reg_loss,
                                                      iterations_this_epoch,
                                                      is_train=True,
                                                      reset=True)

        return metrics, losses
Exemplo n.º 3
0
    def predict_patches(self,
                        images_batch: torch.tensor,
                        shapes: torch.tensor = None) -> torch.tensor:
        """Predict per-pixel class probabilities by tiling images into patches.

        Each image is processed in four stages:

        * It is cropped to its valid size and padded by
          ``self.patches_margin`` on every side.
        * It is split into overlapping patches of ``self.patch_size``.
        * The patches are sent through ``self.predict_batch`` in groups of
          ``self.patches_batch_size``.
        * Only the margin-trimmed centre of each patch prediction is
          accumulated. Pixels covered by several patches get the mean of
          those predictions.

        Args:
            images_batch: batch of images. Each padded image is unpacked as
                ``(channels, h, w)``, so every image is 3-D.
            shapes: per-image ``(height, width)`` of the valid region. When
                omitted, the batch must contain exactly one image, and its
                full spatial size is used.

        Returns:
            A stacked tensor with one ``(num_classes, height, width)``
            probability map per image, cropped back to the valid size.
            ``torch.stack`` requires these sizes to be equal across the batch.

        Raises:
            ValueError: if ``self.patch_size`` is unset, if ``shapes`` and
                ``images_batch`` differ in length, or if ``shapes`` is
                omitted for a batch that does not hold exactly one image.
        """
        if self.patch_size is None:
            raise ValueError(
                "In order to predict patches, patch_size must be set.")
        if shapes is not None:
            if len(images_batch) != len(shapes):
                raise ValueError(
                    "Images shapes and images batch should have the same number of samples."
                )
        else:
            if len(images_batch) != 1:
                raise ValueError("If no shapes given, should be a batch of 1.")
            shapes = torch.tensor([images_batch.size()[-2:]])

        patch_h, patch_w = self.patch_size[0], self.patch_size[1]
        margin = self.patches_margin

        results = []
        for image, shape in zip(images_batch, shapes):
            # .tolist() works whatever device ``shapes`` is on, unlike
            # .numpy(), which fails for CUDA tensors.
            original_h, original_w = shape.tolist()
            image = image[..., :original_h, :original_w]
            # Padding by the margin lets border pixels fall inside the
            # margin-trimmed centre of the first/last patches.
            image = F.pad(
                image.unsqueeze(0),
                [margin] * 4,
                self.padding_mode,
                self.padding_value,
            ).squeeze(0)

            h, w = image.size()[1:]

            x_step = compute_step(h, patch_h, self.margin,
                                  self.patches_overlap)
            y_step = compute_step(w, patch_w, self.margin,
                                  self.patches_overlap)

            # Evenly spaced top-left corners from 0 to (size - patch size).
            # NOTE(review): a step of 0 would divide by zero here; presumably
            # compute_step always returns >= 1 — confirm.
            x_pos = (np.round(
                np.arange(x_step + 1) / x_step *
                (h - patch_h)).astype(np.int32).tolist())
            y_pos = (np.round(
                np.arange(y_step + 1) / y_step *
                (w - patch_w)).astype(np.int32).tolist())

            # Allocate directly on the target device (no CPU copy first).
            counts = torch.zeros((h, w), dtype=torch.long, device=self.device)
            probas_sum = torch.zeros([self.num_classes, h, w],
                                     device=self.device)

            for positions in batch_items(list(product(x_pos, y_pos)),
                                         self.patches_batch_size):
                crops = torch.stack([
                    pos2crop(image, pos, self.patch_size, self.margin)
                    for pos in positions
                ])

                probas = self.predict_batch(crops)
                for idx, (x, y) in enumerate(positions):
                    # Only the patch centre (margin trimmed on every side)
                    # contributes to the accumulated prediction.
                    rows = slice(x + margin, x + patch_h - margin)
                    cols = slice(y + margin, y + patch_w - margin)
                    counts[rows, cols] += 1
                    probas_sum[:, rows, cols] += probas[idx][
                        ..., margin:patch_h - margin, margin:patch_w - margin]

            image_probas = probas_sum / counts
            # Drop the padding to get back to the original image extent.
            image_probas = image_probas[
                ..., margin:margin + original_h, margin:margin + original_w]
            results.append(image_probas)
        return torch.stack(results)
Exemplo n.º 4
0
    def __iter__(self):
        """Yield patch samples, converted with ``sample_to_tensor``.

        Rows of the worker's data are optionally shuffled, then loaded in
        groups of ``self.prefetch_shuffle``. Each loaded sample passes through
        two optional steps, ``pre_patches_compose`` and the first
        ``assign_transform`` phase, and is then cut into a grid of patches.
        All patches of the group are yielded, shuffled when
        ``self.shuffle`` is set. Before being yielded, each patch passes
        through two optional steps: ``post_patches_compose`` and the second
        ``assign_transform`` phase.
        """
        data = self.get_data_with_worker_info()
        shuffled_data = data.sample(frac=1) if self.shuffle else data

        for samples in batch_items(shuffled_data[["image", "label"]].values,
                                   self.prefetch_shuffle):
            samples = [
                load_sample({"image": image, "label": label})
                for image, label in samples
            ]

            if self.pre_patches_compose:
                samples = [
                    self.pre_patches_compose(**sample) for sample in samples
                ]

            if self.assign_transform:
                for sample in samples:
                    label = self.assign_transform.first_phase(sample["label"])
                    sample.update({"label": label})
            samples = [
                sample_to_patche_samples(
                    sample,
                    self.patch_size,
                    self.patches_overlap,
                    offset_augment=self.offsets_augment,
                ) for sample in samples
            ]

            # One (sample index, grid row, grid col) entry per patch, so that
            # patches of different samples in the group can be interleaved.
            paths = []
            for sample_idx, sample in enumerate(samples):
                rows, cols = sample[0].shape[:2]
                for row in range(rows):
                    for col in range(cols):
                        paths.append((sample_idx, row, col))

            if self.shuffle:
                order = torch.randperm(len(paths)).tolist()
            else:
                order = range(len(paths))
            for path_idx in order:
                sample_idx, row, col = paths[path_idx]
                patches = samples[sample_idx]
                sample = {
                    "image": patches[0][row, col, 0],
                    "label": patches[1][row, col],
                }
                if self.post_patches_compose:
                    sample = self.post_patches_compose(**sample)

                if self.assign_transform:
                    label = self.assign_transform.second_phase(sample["label"])
                    sample.update({"label": label})

                yield sample_to_tensor(sample)
            # Release this group's patches before the next group is loaded.
            del samples