Example #1
0
 def inference(self, model, loader):
     """Run inference with the best checkpoint and save per-sample outputs.

     Loads ``self.best_checkpoint_path`` into ``model``, runs every batch
     from ``loader`` through it, takes the class-1 softmax probability map,
     and writes one ``<index>.tif`` per sample into
     ``self.checkpoints_history_folder``.

     Args:
         model: the network to load the checkpoint into and evaluate.
         loader: iterable of batches; ``batch[0]`` is the image tensor.
     """
     empty_cuda_cache()
     state_dict = torch.load(self.best_checkpoint_path)
     model.load_state_dict(state_dict)
     # Ensure dropout/batch-norm are in inference mode before predicting.
     model.eval()
     tqdm_loader = tqdm(loader)
     # Global sample counter: using the within-batch index alone would make
     # every batch overwrite the previous batch's 0.tif, 1.tif, ... files.
     sample_idx = 0
     with torch.no_grad():
         for batch in tqdm_loader:
             batch_imgs = batch[0].to(device=self.device, non_blocking=True)
             batch_pred = model(batch_imgs).cpu()
             # Keep only the foreground (class-1) probability channel.
             batch_pred = F.softmax(batch_pred, dim=1)[:, 1, ...]
             for pred in batch_pred:
                 save_image(
                     pred,
                     self.checkpoints_history_folder / f'{sample_idx}.tif')
                 sample_idx += 1
Example #2
0
    def run_train(self, model1, model2, train_dataloader, valid_dataloader):
        """Train two models jointly for up to ``self.n_epoches`` epochs.

        Each epoch: train both models, log mean losses (rank 0 only),
        validate, summarize scores, and delegate checkpointing / best-score
        tracking to ``self.post_processing``. Stops early when no new best
        epoch has been seen for more than ``self.early_stopping`` epochs.

        Args:
            model1, model2: the two networks, placed on ``self.device1`` /
                ``self.device2`` respectively.
            train_dataloader, valid_dataloader: epoch data sources.

        Returns:
            tuple: ``(best_epoch, best_score)`` as tracked by
            ``post_processing``.
        """
        model1.to(self.device1)
        model2.to(self.device2)
        for self.epoch in range(self.n_epoches):
            if self.distrib_config['LOCAL_RANK'] == 0:
                self.logger.info(f'Epoch {self.epoch}: \t start training....')
                # NOTE(review): resets run only on rank 0 — if the evaluators
                # also accumulate on other ranks, these should move outside
                # the guard; confirm against valid_epoch/process_summary.
                self.evaluator1.reset()
                self.evaluator2.reset()
            model1.train()
            model2.train()
            train_loss1_mean, train_loss2_mean = self.train_epoch(
                model1, model2, train_dataloader)
            if self.distrib_config['LOCAL_RANK'] == 0:
                self.logger.info(
                    f'Epoch {self.epoch}: \t Calculated train loss: {train_loss1_mean:.5f},'
                    f' {train_loss2_mean:.5f}')
                # Pass the epoch as global_step: without it every epoch was
                # written at the default step and the loss curves were
                # unusable in TensorBoard.
                self.tb_logger.add_scalar(
                    'Train/Loss1', train_loss1_mean, self.epoch)
                self.tb_logger.add_scalar(
                    'Train/Loss2', train_loss2_mean, self.epoch)

            if self.distrib_config['LOCAL_RANK'] == 0:
                self.logger.info(
                    f'Epoch {self.epoch}: \t start validation....')
            model1.eval()
            model2.eval()
            self.valid_epoch(model1, model2, valid_dataloader)
            selected_score = self.process_summary()

            # Updates self.best_epoch / self.best_score and saves checkpoints.
            self.post_processing(selected_score, model1, model2)

            # Early stopping: no improvement for more than
            # ``self.early_stopping`` epochs since the best one.
            if self.epoch - self.best_epoch > self.early_stopping:
                if self.distrib_config['LOCAL_RANK'] == 0:
                    self.logger.info('EARLY STOPPING')
                break

        if self.distrib_config['LOCAL_RANK'] == 0:
            self.tb_logger.close()

        empty_cuda_cache()
        return self.best_epoch, self.best_score