Example #1
    def test(self, loader, epoch_num):
        # Evaluate the model for one epoch: run every batch through
        # __test_one_batch__ under no_grad, persist predictions, compute
        # metrics, and emit periodic loss summaries.

        Utility.cleanup()
        log.info(f"Finished cleanup for epoch {epoch_num}")

        self.model.eval()

        pbar = tqdm(loader, ncols=1000)
        total_loss = 0
        summary_loss = 0

        metrics = []
        num_batches = len(loader)
        log.info(f"Tester starting the testing for epoch: {epoch_num}")

        with torch.no_grad():
            for idx, data in enumerate(pbar):

                x = torch.cat((data['bg'], data['fg_bg']), dim=1).to(device=self.device)
                data['fg_bg_mask'] = data['fg_bg_mask'].to(self.device)
                data['fg_bg_depth'] = data['fg_bg_depth'].to(self.device)

                log.info(f"Starting the testing for batch:{idx}")
                (loss, mask, depth) = self.__test_one_batch__(x, data['fg_bg_mask'], data['fg_bg_depth'])
                log.info(f"End of the testing for batch:{idx}")

                # running totals: total_loss spans the whole epoch,
                # summary_loss only the current summary window
                total_loss += loss
                summary_loss += loss

                if self.persister is not None:
                    self.persister(data, mask, epoch_num, "mask")
                    self.persister(data, depth, epoch_num, "depth")
                    log.info(f"Persisted the prediction for batch:{idx}")

                if self.metric_fn is not None:
                    metric = self.metric_fn(data, mask)
                    metrics.append(metric)
                    log.info(f"Computed the metric for batch:{idx}")

                if (idx + 1) % 500 == 0 or idx == num_batches - 1:
                    self.writer.write_pred_summary(data, mask, depth)
                    # Average over the batches accumulated since the last
                    # summary; the final window may hold fewer than 500
                    # batches, and `or 500` avoids dividing by zero when
                    # the epoch length is an exact multiple of 500.
                    window = (idx + 1) % 500 or 500
                    avg_loss = summary_loss / window
                    self.writer.write_scalar_summary('test loss', avg_loss, epoch_num * num_batches + idx)
                    summary_loss = 0

                pbar.set_description(desc=f'Loss={loss}\t id={idx}\t')
                log.info(f"For test batch {idx} loss is {loss}")
                # free batch tensors before the next iteration
                del loss, mask, depth, data
                log.info(f"Completed the testing for batch:{idx}")

        metric = None
        if self.metric_fn is not None:
            metric = self.metric_fn.aggregate(metrics)
        return PredictionResult(total_loss / len(loader.dataset), metric)
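
A minimal sketch of how this test method might be driven from an outer loop. The Tester class name, its construction, and the loader are assumptions for illustration; only the method itself appears in the example, and PredictionResult's fields are not shown, so the result is logged as a whole:

    # Hypothetical wiring; Tester, test_loader, and num_epochs are not
    # part of the example above.
    tester = Tester()
    for epoch in range(num_epochs):
        result = tester.test(test_loader, epoch)
        log.info(f"epoch {epoch} test result: {result}")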
Example #2
    def train_one_epoch(self, loader, epoch_num):
        # Train the model for one epoch: forward/backward each batch via
        # __train_one_batch__, step the scheduler per batch, persist
        # predictions, compute metrics, and emit periodic summaries.

        Utility.cleanup()
        log.info(f"Finished cleanup for epoch {epoch_num}")

        self.model.train()
        pbar = tqdm(loader, ncols=1000)

        total_loss = 0
        summary_loss = 0
        metrics = []

        num_batches = len(loader)
        log.info(f"Trainer starting the training for epoch: {epoch_num}")
        for idx, data in enumerate(pbar):

            log.info(f"Obtained the data for batch:{idx}")

            x = torch.cat((data['bg'], data['fg_bg']), dim=1).to(self.device)
            data['fg_bg_mask'] = data['fg_bg_mask'].to(self.device)
            data['fg_bg_depth'] = data['fg_bg_depth'].to(self.device)

            log.info(f"Starting the training for batch:{idx}")
            (loss, mask, depth) = self.__train_one_batch__(x, data['fg_bg_mask'], data['fg_bg_depth'])
            log.info(f"End of the training for batch:{idx}")

            # running totals: total_loss spans the whole epoch, summary_loss
            # only the current summary window; this assumes
            # __train_one_batch__ returns a detached scalar loss, otherwise
            # accumulate loss.item() to avoid retaining the autograd graph
            total_loss += loss
            summary_loss += loss

            self.scheduler.step()
            log.info(f"Scheduler step for the batch:{idx}")

            if self.persister is not None:
                self.persister(data, mask, epoch_num, "mask")
                self.persister(data, depth, epoch_num, "depth")
                log.info(f"Persisted the prediction for batch:{idx}")

            if self.metric_fn is not None:
                metric = self.metric_fn(data, mask)
                metrics.append(metric)
                log.info(f"Computed the metric for batch:{idx}")

            lr = self.optimizer.param_groups[0]['lr']
            pbar.set_description(desc=f'id={idx}\t Loss={loss}\t LR={lr}\t')
            log.info(f"For train batch {idx} loss is {loss} and lr is {lr}")

            if (idx + 1) % 500 == 0 or idx == num_batches - 1:
                self.writer.write_pred_summary(data, mask.detach(), depth.detach())
                # Average over the batches accumulated since the last
                # summary; the final window may hold fewer than 500
                # batches, and `or 500` avoids dividing by zero when
                # the epoch length is an exact multiple of 500.
                window = (idx + 1) % 500 or 500
                avg_loss = summary_loss / window
                self.writer.write_scalar_summary('train loss', avg_loss, epoch_num * num_batches + idx)
                self.writer.write_scalar_summary('lr', lr, epoch_num * num_batches + idx)
                summary_loss = 0

            # free batch tensors before the next iteration
            del loss, mask, depth, data
            log.info(f"Completed the training for batch:{idx}")

        metric = None
        if self.metric_fn is not None:
            metric = self.metric_fn.aggregate(metrics)
        return PredictionResult(total_loss / len(loader.dataset), metric)
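
The two methods pair naturally into an epoch loop. A minimal sketch, assuming a Trainer class that exposes both train_one_epoch and test along with pre-built train/test loaders (none of this setup is shown in the examples):

    # Hypothetical epoch loop; trainer, train_loader, test_loader, and
    # num_epochs are assumptions, not part of the examples above.
    for epoch in range(num_epochs):
        train_result = trainer.train_one_epoch(train_loader, epoch)
        test_result = trainer.test(test_loader, epoch)
        log.info(f"epoch {epoch}: train={train_result} test={test_result}")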