Exemplo n.º 1
0
 def validation(self, epoch=None, validation_loader=None, epoch_run=None):
     """Score the model on ``validation_loader`` and checkpoint if F1 improved.

     The model is switched to eval mode and run without autograd for the
     evaluation; its previous train/eval mode is restored afterwards so the
     surrounding training loop is unaffected.

     :param epoch: unused; kept for interface compatibility.
     :param validation_loader: iterable of per-image loaders passed to
         ``self.evaluate`` in validation mode (``gen_images=False``).
     :param epoch_run: unused; kept for interface compatibility.
     """
     score_acc = ScoreAccumulator()
     was_training = self.model.training
     # Eval mode disables dropout / BatchNorm statistic updates on validation data.
     self.model.eval()
     try:
         with torch.no_grad():
             self.evaluate(data_loaders=validation_loader, logger=self.val_logger,
                           gen_images=False, score_acc=score_acc)
     finally:
         self.model.train(was_training)
     p, r, f1, a = score_acc.get_prfa()
     print('>>> PRF1: ', [p, r, f1, a])
     self._save_if_better(score=f1)
Exemplo n.º 2
0
    def train(self, data_loader=None, validation_loader=None, epoch_run=None):
        """Run the training loop with periodic validation and early stopping.

        For each epoch in ``1..self.epochs``: adjust the learning rate, record
        the epoch in ``self.checkpoint``, put the model in train mode and call
        ``epoch_run`` for one epoch. Every ``self.validation_frequency`` epochs
        the model is scored on ``validation_loader`` via ``self._eval``; the
        loop stops early when ``self.early_stop(patience=self.patience)`` is
        true. The train/validation loggers are closed on exit, including early
        stop and errors.

        :param data_loader: training data, forwarded to ``epoch_run`` and
            ``self._on_epoch_end``.
        :param validation_loader: iterable of per-image loaders for ``self._eval``.
        :param epoch_run: callable running one epoch, invoked as
            ``epoch_run(epoch=..., data_loader=...)``.
        """
        print('Training...')
        try:
            for epoch in range(1, self.epochs + 1):
                self._adjust_learning_rate(epoch=epoch)
                self.checkpoint['total_epochs'] = epoch

                # Validation below leaves the model in eval mode; training
                # epochs must run in train mode again.
                self.model.train()
                epoch_run(epoch=epoch, data_loader=data_loader)

                self._on_epoch_end(data_loader=data_loader,
                                   log_file=self.train_logger.name)

                # validation_frequency is the number of epochs between validations.
                if epoch % self.validation_frequency == 0:
                    print('Running validation..')
                    self.model.eval()
                    val_score = ScoreAccumulator()
                    self._eval(data_loaders=validation_loader,
                               gen_images=False,
                               score_acc=val_score,
                               logger=self.val_logger)
                    self._on_validation_end(data_loader=validation_loader,
                                            log_file=self.val_logger.name)
                    if self.early_stop(patience=self.patience):
                        return
        finally:
            # Runs on normal completion, early stop and exceptions alike.
            for log in (self.train_logger, self.val_logger):
                if log and not log.closed:
                    log.close()
Exemplo n.º 3
0
    def evaluate(self,
                 data_loaders=None,
                 logger=None,
                 gen_images=False,
                 score_acc=None):
        """Reassemble per-image predictions from patches and score them.

        Each loader in ``data_loaders`` yields the patches of one image
        (``loader.dataset.image_objects[0]``); every patch's argmax prediction
        is pasted into a full-size map at its ``clip_ix`` coordinates.

        :param data_loaders: iterable of per-image loaders.
        :param logger: log target; one comma-joined line
            ``file_name,<prfa values>`` is passed to ``self.flush`` per image.
        :param gen_images: test mode when true: the map is scaled to 0/255,
            pixels with ``extra['fill_in'] == 1`` are forced to 255, it is
            scored against ``img_obj.ground_truth``, accumulated into
            ``self.conf['acc']`` and saved as a PNG in ``self.log_dir``.
            Otherwise validation mode: scored against ``extra['gt_mid']`` and
            accumulated into ``score_acc``.
        :param score_acc: ScoreAccumulator for validation-mode scores.
        """
        assert isinstance(score_acc, ScoreAccumulator)
        # Inference only: no autograd graph is needed (same as ``_eval``).
        with torch.no_grad():
            for loader in data_loaders:
                img_obj = loader.dataset.image_objects[0]
                height, width = img_obj.working_arr.shape[0], img_obj.working_arr.shape[1]
                predicted_img = torch.zeros(height, width, dtype=torch.float32,
                                            device=self.device)

                for i, data in enumerate(loader, 1):
                    inputs = data['inputs'].to(self.device).float()
                    clip_ix = data['clip_ix'].to(self.device).int()

                    outputs = F.softmax(self.model(inputs), 1)
                    _, predicted = torch.max(outputs, 1)

                    # Paste each patch prediction back at its source window.
                    for j in range(predicted.shape[0]):
                        p, q, r, s = clip_ix[j]
                        predicted_img[p:q, r:s] = predicted[j]

                    print('Batch: ', i, end='\r')

                img_score = ScoreAccumulator()
                if gen_images:  # Test mode
                    predicted_img = predicted_img.cpu().numpy() * 255
                    predicted_img[img_obj.extra['fill_in'] == 1] = 255
                    img_score.reset().add_array(predicted_img, img_obj.ground_truth)
                    self.conf['acc'].accumulate(img_score)  # Global score
                else:  # Validation mode
                    gt_mid = torch.FloatTensor(img_obj.extra['gt_mid']).to(self.device)
                    img_score.reset().add_tensor(predicted_img, gt_mid)
                    score_acc.accumulate(img_score)

                prf1a = img_score.get_prfa()
                print(img_obj.file_name, ' PRF1A', prf1a)
                # Unpacking works whether get_prfa returns a list or a tuple.
                self.flush(logger,
                           ','.join(str(v) for v in [img_obj.file_name, *prf1a]))

                if gen_images:
                    IMG.fromarray(np.array(predicted_img, dtype=np.uint8)).save(
                        os.path.join(self.log_dir,
                                     img_obj.file_name.split('.')[0] + '.png'))
Exemplo n.º 4
0
 def test(self, data_loaders=None):
     """Evaluate the model in test mode on ``data_loaders``.

     Puts the model in eval mode and calls ``self.evaluate`` with
     ``gen_images=True`` under ``torch.no_grad()``, logging to
     ``self.test_logger``; then calls ``self._on_test_end`` with the log file
     name. The test logger is closed afterwards, also when an error occurs.

     :param data_loaders: iterable of per-image loaders.
     """
     print('Running test')
     score = ScoreAccumulator()
     self.model.eval()
     try:
         with torch.no_grad():
             self.evaluate(data_loaders=data_loaders, gen_images=True, score_acc=score, logger=self.test_logger)
         self._on_test_end(log_file=self.test_logger.name)
     finally:
         # Close only a present, still-open logger.
         if self.test_logger and not self.test_logger.closed:
             self.test_logger.close()
Exemplo n.º 5
0
def run(runs, transforms):
    """Train and/or test a UNet on every data split of each run configuration.

    For each run dict ``R``:
      * every directory in ``R['Dirs']`` is created;
      * ``R['acc']`` is reset to a fresh ScoreAccumulator;
      * for each file in ``R['Dirs']['splits_json']`` a new UNet and Adam
        optimizer are built (wrapped in ``DataParallel`` when
        ``R['Params']['distribute']``), the model is trained when
        ``R['Params']['mode'] == 'train'``, reloaded from its checkpoint and
        tested on all images of the split;
      * ``R['acc'].get_prfa()`` is printed and written to ``<logs>/score.txt``.

    An exception inside one split's train/test is reported with a traceback
    and the remaining splits still run.

    :param runs: iterable of run configuration dicts.
    :param transforms: transforms forwarded to the patch loaders.
    """
    for R in runs:
        for folder in R['Dirs'].values():
            os.makedirs(folder, exist_ok=True)
        R['acc'] = ScoreAccumulator()
        splits_dir = R['Dirs']['splits_json']
        for split in os.listdir(splits_dir):
            splits = asp.load_split_json(os.path.join(splits_dir, split))
            params = R['Params']

            R['checkpoint_file'] = split + '.tar'
            unet = UNet(params['num_channels'], params['num_classes'])
            model = unet
            if params['distribute']:
                model = torch.nn.DataParallel(unet)
                model.float()
            # Single optimizer over the UNet's own parameters, built after any
            # DataParallel wrapping (``model.module`` is ``unet``).
            optimizer = optim.Adam(unet.parameters(), lr=params['learning_rate'])

            try:
                drive_trainer = UNetBee(model=model,
                                        conf=R,
                                        optimizer=optimizer)
                if params.get('mode') == 'train':
                    train_loader = PatchesGenerator.get_loader(
                        conf=R,
                        images=splits['train'],
                        transforms=transforms,
                        mode='train')
                    val_loader = PatchesGenerator.get_loader_per_img(
                        conf=R,
                        images=splits['validation'],
                        mode='validation',
                        transforms=transforms)
                    drive_trainer.train(data_loader=train_loader,
                                        validation_loader=val_loader,
                                        epoch_run=drive_trainer.epoch_ce_loss)

                # The checkpoint is reloaded before testing whether or not
                # training ran above.
                drive_trainer.resume_from_checkpoint(
                    parallel_trained=params.get('parallel_trained'))

                all_images = splits['test'] + splits['train'] + splits['validation']
                test_loader = PatchesGenerator.get_loader_per_img(
                    conf=R,
                    images=all_images,
                    mode='test',
                    transforms=transforms)
                drive_trainer.test(test_loader)
            except Exception:
                # Best effort per split: report and continue with the next one.
                traceback.print_exc()

        prfa = R['acc'].get_prfa()
        print(prfa)
        with open(os.path.join(R['Dirs']['logs'], 'score.txt'), "w") as f:
            f.write(', '.join(str(s) for s in prfa))
Exemplo n.º 6
0
    def epoch_ce_loss(self, **kw):
        """
        Run one epoch with class-weighted cross-entropy loss.

        The loss is ``F.nll_loss`` applied to ``F.log_softmax`` of the model
        output over the class dimension (softmax cross-entropy), weighted by
        ``self.dparm(self.conf)``.

        When ``self.model.training`` is true, gradients are computed and the
        optimizer is stepped every batch, and the reported scores are per batch.
        Otherwise the pass is forward-only and scores accumulate into
        ``kw['score_acc']``.

        :keyword data_loader: iterable of batches with ``'inputs'`` and ``'labels'``.
        :keyword epoch: current epoch number (logging only).
        :keyword score_acc: ScoreAccumulator to accumulate into when not training.
        :keyword logger: log target for ``self.flush``; defaults to
            ``self.train_logger`` when omitted.
        :return: None
        """
        training = self.model.training
        score_acc = ScoreAccumulator() if training else kw.get('score_acc')
        assert isinstance(score_acc, ScoreAccumulator)

        data_loader = kw['data_loader']
        # ``train`` invokes epoch_run with only ``epoch`` and ``data_loader``.
        logger = kw['logger'] if 'logger' in kw else self.train_logger

        running_loss = 0.0
        for i, data in enumerate(data_loader, 1):
            inputs = data['inputs'].to(self.device).float()
            labels = data['labels'].to(self.device).long()

            if training:
                self.optimizer.zero_grad()

            outputs = F.log_softmax(self.model(inputs), 1)
            _, predicted = torch.max(outputs, 1)

            # Weights are fetched every batch: self.dparm is not shown to be
            # deterministic, so this is deliberately not hoisted.
            weights = torch.FloatTensor(self.dparm(self.conf)).to(self.device)
            loss = F.nll_loss(outputs, labels, weight=weights)

            if training:
                loss.backward()
                self.optimizer.step()

            current_loss = loss.item()
            running_loss += current_loss

            if training:
                score_acc.reset()  # per-batch scores while training

            p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()

            if i % self.log_frequency == 0:
                print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' %
                      (
                          kw['epoch'], self.epochs, i, len(data_loader),
                          running_loss / self.log_frequency, p, r, f1,
                          a))
                running_loss = 0.0
            self.flush(logger,
                       ','.join(str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))
Exemplo n.º 7
0
    def epoch_dice_loss(self, **kw):
        """Train for one epoch with a dice loss on the class-1 output channel.

        Per batch: zero gradients, forward pass, dice loss of
        ``outputs[:, 1]`` against the labels with ``beta`` drawn at random
        from ``np.arange(1, 2, 0.1)``, backward pass and optimizer step.
        Per-batch precision/recall/F1/accuracy and the loss are passed to
        ``self.flush`` on ``self.train_logger``. Every ``self.log_frequency``
        batches a summary with the mean loss is printed.

        :keyword data_loader: iterable of batches with ``'inputs'`` and ``'labels'``.
        :keyword epoch: current epoch number (logging only).
        """
        accumulator = ScoreAccumulator()
        loss_since_print = 0.0
        loader = kw['data_loader']
        beta_choices = np.arange(1, 2, 0.1).tolist()  # candidate betas, one drawn per batch

        for batch_no, batch in enumerate(loader, 1):
            inputs = batch['inputs'].to(self.device).float()
            labels = batch['labels'].to(self.device).long()

            self.optimizer.zero_grad()
            logits = self.model(inputs)
            predicted = torch.max(logits, 1)[1]

            loss = dice_loss(logits[:, 1, :, :], labels, beta=rd.choice(beta_choices))
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            loss_since_print += batch_loss
            scores = accumulator.reset().add_tensor(predicted, labels).get_prfa()
            p, r, f1, a = scores

            if batch_no % self.log_frequency == 0:
                mean_loss = loss_since_print / self.log_frequency
                print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f'
                      % (kw['epoch'], self.epochs, batch_no, len(loader),
                         mean_loss, p, r, f1, a))
                loss_since_print = 0.0

            row = [0, kw['epoch'], batch_no, p, r, f1, a, batch_loss]
            self.flush(self.train_logger, ','.join(map(str, row)))
Exemplo n.º 8
0
    def epoch_dice_loss(self, **kw):
        """Run one epoch with a dice loss on the softmax class-1 channel.

        The dice loss is computed on ``softmax(model(inputs))[:, 1]`` against
        the labels, with ``beta`` drawn at random every batch from
        ``np.arange(1, 2, 0.1)``.

        When ``self.model.training`` is true, gradients are computed and the
        optimizer is stepped every batch, and the reported scores are per batch.
        Otherwise the pass is forward-only and scores accumulate into
        ``kw['score_acc']``.

        :keyword data_loader: iterable of batches with ``'inputs'`` and ``'labels'``.
        :keyword epoch: current epoch number (logging only).
        :keyword score_acc: ScoreAccumulator to accumulate into when not training.
        :keyword logger: log target for ``self.flush``; defaults to
            ``self.train_logger`` when omitted.
        """
        training = self.model.training
        score_acc = ScoreAccumulator() if training else kw.get('score_acc')
        assert isinstance(score_acc, ScoreAccumulator)

        data_loader = kw['data_loader']
        # ``train`` invokes epoch_run with only ``epoch`` and ``data_loader``.
        logger = kw['logger'] if 'logger' in kw else self.train_logger
        # Invariant candidate list; one beta is still drawn per batch.
        beta_choices = np.arange(1, 2, 0.1).tolist()

        running_loss = 0.0
        for i, data in enumerate(data_loader, 1):
            inputs = data['inputs'].to(self.device).float()
            labels = data['labels'].to(self.device).long()

            if training:
                self.optimizer.zero_grad()

            outputs = F.softmax(self.model(inputs), 1)
            _, predicted = torch.max(outputs, 1)

            loss = dice_loss(outputs[:, 1, :, :], labels, beta=rd.choice(beta_choices))
            if training:
                loss.backward()
                self.optimizer.step()

            current_loss = loss.item()
            running_loss += current_loss

            if training:
                score_acc.reset()  # per-batch scores while training
            p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()

            if i % self.log_frequency == 0:
                print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' %
                      (
                          kw['epoch'], self.epochs, i, len(data_loader), running_loss / self.log_frequency,
                          p, r, f1,
                          a))
                running_loss = 0.0

            self.flush(logger, ','.join(str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))
Exemplo n.º 9
0
    def _eval(self,
              data_loaders=None,
              logger=None,
              gen_images=False,
              score_acc=None):
        """Score the model image-by-image; optionally save prediction images.

        Each loader in ``data_loaders`` yields the patches of one image
        (``loader.dataset.image_objects[0]``). Per patch, the argmax class and
        the raw class-1 output channel are pasted into full-size maps at the
        ``clip_ix`` coordinates. Runs under ``torch.no_grad()``.

        :param data_loaders: iterable of per-image loaders.
        :param logger: log target; one comma-joined line
            ``file_name,<prfa values>`` is passed to ``self.flush`` per image.
        :param gen_images: when true, score against ``img_obj.ground_truth``,
            accumulate into ``self.conf['acc']`` and save ``pred_<name>.png``
            (class map) and ``<name>.png`` (class-1 map) in ``self.log_dir``.
            When false (validation), accumulate into ``score_acc`` and call
            ``self._save_if_better`` with its F1.
        :param score_acc: ScoreAccumulator for validation scores.
        """
        assert isinstance(score_acc, ScoreAccumulator)
        with torch.no_grad():
            for loader in data_loaders:
                img_obj = loader.dataset.image_objects[0]
                height, width = img_obj.working_arr.shape[0], img_obj.working_arr.shape[1]
                predicted_img = torch.zeros(height, width, dtype=torch.float32,
                                            device=self.device)
                map_img = torch.zeros(height, width, dtype=torch.float32,
                                      device=self.device)

                for i, data in enumerate(loader, 1):
                    inputs = data['inputs'].to(self.device).float()
                    clip_ix = data['clip_ix'].to(self.device).int()

                    outputs = self.model(inputs)
                    _, predicted = torch.max(outputs, 1)
                    predicted_map = outputs[:, 1, :, :]

                    # Paste each patch back at its source window.
                    for j in range(predicted_map.shape[0]):
                        p, q, r, s = clip_ix[j]
                        predicted_img[p:q, r:s] = predicted[j]
                        map_img[p:q, r:s] = predicted_map[j]
                    print('Batch: ', i, end='\r')

                img_score = ScoreAccumulator()
                predicted_img = predicted_img * 255

                if gen_images:
                    # NOTE(review): exp() assumes the model emits log-probabilities,
                    # making this a 0..255 class-1 probability map — confirm.
                    map_img = (torch.exp(map_img) * 255).cpu().numpy()
                    predicted_img = predicted_img.cpu().numpy()
                    img_score.add_array(predicted_img, img_obj.ground_truth)
                    self.conf['acc'].accumulate(img_score)  # Global score

                    stem = img_obj.file_name.split('.')[0]
                    IMG.fromarray(np.array(predicted_img, dtype=np.uint8)).save(
                        os.path.join(self.log_dir, 'pred_' + stem + '.png'))
                    IMG.fromarray(np.array(map_img, dtype=np.uint8)).save(
                        os.path.join(self.log_dir, stem + '.png'))
                else:
                    # Ground truth is only needed for validation scoring.
                    gt = torch.FloatTensor(img_obj.ground_truth).to(self.device)
                    img_score.add_tensor(predicted_img, gt)
                    score_acc.accumulate(img_score)

                prf1a = img_score.get_prfa()
                print(img_obj.file_name, ' PRF1A', prf1a)
                # Unpacking works whether get_prfa returns a list or a tuple.
                self.flush(logger,
                           ','.join(str(v) for v in [img_obj.file_name, *prf1a]))

        # score_acc is only filled in validation mode; checkpoint selection on
        # an empty accumulator would be meaningless.
        if not gen_images:
            self._save_if_better(score=score_acc.get_prfa()[2])