def validation(self, epoch=None, validation_loader=None, epoch_run=None):
    """Run one validation pass over *validation_loader* and checkpoint on improvement.

    Delegates the actual scoring to ``self.evaluate`` (no image generation),
    then saves the model if the resulting F1 beat the previous best.
    The ``epoch`` and ``epoch_run`` parameters are accepted for interface
    compatibility but not used here.
    """
    accumulator = ScoreAccumulator()
    self.evaluate(data_loaders=validation_loader, logger=self.val_logger,
                  gen_images=False, score_acc=accumulator)
    precision, recall, f1, accuracy = accumulator.get_prfa()
    print('>>> PRF1: ', [precision, recall, f1, accuracy])
    # F1 is the model-selection criterion.
    self._save_if_better(score=f1)
def train(self, data_loader=None, validation_loader=None, epoch_run=None):
    """Run the full training loop.

    :param data_loader: loader for training batches, forwarded to *epoch_run*.
    :param validation_loader: per-image loaders used every
        ``self.validation_frequency`` epochs.
    :param epoch_run: callable executing one training epoch; invoked as
        ``epoch_run(epoch=..., data_loader=...)``.

    Validation runs every ``self.validation_frequency`` epochs; training stops
    early when ``self.early_stop`` reports no improvement for
    ``self.patience`` checks. Loggers are closed on exit (including early stop).
    """
    print('Training...')
    for epoch in range(1, self.epochs + 1):
        self._adjust_learning_rate(epoch=epoch)
        self.checkpoint['total_epochs'] = epoch

        # Run one epoch
        epoch_run(epoch=epoch, data_loader=data_loader)
        self._on_epoch_end(data_loader=data_loader, log_file=self.train_logger.name)

        # Validation_frequency is the number of epochs between validations.
        if epoch % self.validation_frequency == 0:
            print('Running validation..')
            # NOTE(review): model is never explicitly switched back to train()
            # here — presumably epoch_run does that; confirm.
            self.model.eval()
            val_score = ScoreAccumulator()
            self._eval(data_loaders=validation_loader, gen_images=False,
                       score_acc=val_score, logger=self.val_logger)
            self._on_validation_end(data_loader=validation_loader,
                                    log_file=self.val_logger.name)
            if self.early_stop(patience=self.patience):
                # BUG FIX: was `return`, which skipped closing the loggers below.
                break

    # BUG FIX: the original guards read `if not self.<logger> and not
    # self.<logger>.closed`, which short-circuits for any live logger (never
    # closing it) and would raise AttributeError for a None logger.
    if self.train_logger and not self.train_logger.closed:
        self.train_logger.close()
    if self.val_logger and not self.val_logger.closed:
        self.val_logger.close()
def evaluate(self, data_loaders=None, logger=None, gen_images=False, score_acc=None):
    """Stitch patch-wise predictions back into full images and score each one.

    :param data_loaders: iterable of loaders, one per image — each loader's
        dataset is assumed to hold exactly one image object (index 0).
    :param logger: open log handle passed to ``self.flush`` per image.
    :param gen_images: True = test mode (save PNGs, accumulate into the global
        ``self.conf['acc']``); False = validation mode (accumulate into
        *score_acc* against ``extra['gt_mid']``).
    :param score_acc: ScoreAccumulator collecting validation-mode scores.
    """
    assert isinstance(score_acc, ScoreAccumulator)
    for loader in data_loaders:
        # Each loader serves patches of a single image.
        img_obj = loader.dataset.image_objects[0]
        x, y = img_obj.working_arr.shape[0], img_obj.working_arr.shape[1]
        # Canvas onto which per-patch predictions are stitched.
        predicted_img = torch.FloatTensor(x, y).fill_(0).to(self.device)
        for i, data in enumerate(loader, 1):
            inputs, labels = data['inputs'].to(
                self.device).float(), data['labels'].to(
                self.device).float()
            # clip_ix holds the (row_start, row_end, col_start, col_end)
            # placement of each patch in the full image.
            clip_ix = data['clip_ix'].to(self.device).int()
            outputs = F.softmax(self.model(inputs), 1)
            _, predicted = torch.max(outputs, 1)
            # Foreground-class probability map (channel 1) — computed but the
            # stitched canvas below uses the argmax labels, not this map.
            predicted_map = outputs[:, 1, :, :]
            for j in range(predicted_map.shape[0]):
                p, q, r, s = clip_ix[j]
                predicted_img[p:q, r:s] = predicted[j]
            print('Batch: ', i, end='\r')
        img_score = ScoreAccumulator()
        if gen_images:  #### Test mode
            # Scale class indices {0,1} to an 8-bit image.
            predicted_img = predicted_img.cpu().numpy() * 255
            # Force pixels flagged by the fill-in mask to foreground.
            predicted_img[img_obj.extra['fill_in'] == 1] = 255
            img_score.reset().add_array(predicted_img, img_obj.ground_truth)
            ### Only save scores for test images############################
            self.conf['acc'].accumulate(img_score)  # Global score
            prf1a = img_score.get_prfa()
            print(img_obj.file_name, ' PRF1A', prf1a)
            self.flush(
                logger,
                ','.join(str(x) for x in [img_obj.file_name] + prf1a))
            #################################################################
            IMG.fromarray(np.array(predicted_img, dtype=np.uint8)).save(
                os.path.join(self.log_dir, img_obj.file_name.split('.')[0] + '.png'))
        else:  #### Validation mode
            # Validation scores against the mid-slice ground truth
            # (extra['gt_mid']) rather than the full ground_truth array.
            img_score.reset().add_tensor(
                predicted_img,
                torch.FloatTensor(img_obj.extra['gt_mid']).to(self.device))
            score_acc.accumulate(img_score)
            prf1a = img_score.get_prfa()
            print(img_obj.file_name, ' PRF1A', prf1a)
            self.flush(
                logger,
                ','.join(str(x) for x in [img_obj.file_name] + prf1a))
def test(self, data_loaders=None):
    """Run the test pass: evaluate with image generation enabled.

    :param data_loaders: iterable of per-image loaders forwarded to
        ``self.evaluate``.

    Runs under ``torch.no_grad`` with the model in eval mode, fires the
    ``_on_test_end`` hook, then closes the test logger.
    """
    print('Running test')
    score = ScoreAccumulator()
    self.model.eval()
    with torch.no_grad():
        self.evaluate(data_loaders=data_loaders, gen_images=True,
                      score_acc=score, logger=self.test_logger)
    self._on_test_end(log_file=self.test_logger.name)
    # BUG FIX: the original guard read `if not self.test_logger and not
    # self.test_logger.closed`, which never closes a live logger (and would
    # raise AttributeError for a None logger).
    if self.test_logger and not self.test_logger.closed:
        self.test_logger.close()
def run(runs, transforms):
    """Train and/or test a UNet for every run configuration and every split file.

    :param runs: iterable of run-configuration dicts with 'Dirs' and 'Params'
        keys; each dict gets an 'acc' ScoreAccumulator and a
        'checkpoint_file' entry added as a side effect.
    :param transforms: torchvision-style transforms forwarded to the loaders.
    """
    for R in runs:
        for k, folder in R['Dirs'].items():
            os.makedirs(folder, exist_ok=True)
        R['acc'] = ScoreAccumulator()
        for split in os.listdir(R['Dirs']['splits_json']):
            splits = asp.load_split_json(
                os.path.join(R['Dirs']['splits_json'], split))
            R['checkpoint_file'] = split + '.tar'
            model = UNet(R['Params']['num_channels'], R['Params']['num_classes'])
            optimizer = optim.Adam(model.parameters(),
                                   lr=R['Params']['learning_rate'])
            if R['Params']['distribute']:
                model = torch.nn.DataParallel(model)
                model.float()
                # Rebuild the optimizer over the wrapped module's parameters.
                optimizer = optim.Adam(model.module.parameters(),
                                       lr=R['Params']['learning_rate'])
            try:
                drive_trainer = UNetBee(model=model, conf=R, optimizer=optimizer)
                if R.get('Params').get('mode') == 'train':
                    train_loader = PatchesGenerator.get_loader(
                        conf=R, images=splits['train'],
                        transforms=transforms, mode='train')
                    val_loader = PatchesGenerator.get_loader_per_img(
                        conf=R, images=splits['validation'],
                        mode='validation', transforms=transforms)
                    drive_trainer.train(data_loader=train_loader,
                                        validation_loader=val_loader,
                                        epoch_run=drive_trainer.epoch_ce_loss)
                drive_trainer.resume_from_checkpoint(
                    parallel_trained=R.get('Params').get('parallel_trained'))
                # Test over every split so all images get scored/generated.
                all_images = splits['test'] + splits['train'] + splits['validation']
                test_loader = PatchesGenerator.get_loader_per_img(
                    conf=R, images=all_images, mode='test', transforms=transforms)
                drive_trainer.test(test_loader)
            except Exception:
                # Deliberate best-effort boundary: one failed split should not
                # abort the remaining splits/runs; log and continue.
                traceback.print_exc()
        print(R['acc'].get_prfa())
        # BUG FIX: use a context manager (and os.path.join instead of string
        # concatenation with os.sep) so the file is closed even if write fails.
        with open(os.path.join(R['Dirs']['logs'], 'score.txt'), "w") as f:
            f.write(', '.join(str(s) for s in R['acc'].get_prfa()))
def epoch_ce_loss(self, **kw):
    """
    One epoch implementation of binary cross-entropy loss.

    Expected keyword arguments:
        epoch: current epoch number (console logging only).
        data_loader: loader yielding dicts with 'inputs' and 'labels'.
        logger: open log handle passed to ``self.flush`` every batch.
        score_acc: ScoreAccumulator — required when the model is in eval
            mode; ignored during training (a fresh one is created).
    """
    running_loss = 0.0
    # Training uses a throwaway accumulator reset per batch; eval mode reuses
    # the caller's accumulator so scores aggregate across the whole epoch.
    score_acc = ScoreAccumulator() if self.model.training else kw.get('score_acc')
    assert isinstance(score_acc, ScoreAccumulator)
    for i, data in enumerate(kw['data_loader'], 1):
        inputs, labels = data['inputs'].to(self.device).float(), data['labels'].to(self.device).long()

        if self.model.training:
            self.optimizer.zero_grad()
        outputs = F.log_softmax(self.model(inputs), 1)
        _, predicted = torch.max(outputs, 1)
        # NLL over log-softmax == cross-entropy; per-class weights come from
        # self.dparm(self.conf) (class-balance weights — TODO confirm).
        loss = F.nll_loss(outputs, labels,
                          weight=torch.FloatTensor(self.dparm(self.conf)).to(self.device))
        if self.model.training:
            loss.backward()
            self.optimizer.step()
        current_loss = loss.item()
        running_loss += current_loss
        if self.model.training:
            # Reset so the printed P/R/F1/A reflect only the current batch.
            score_acc.reset()
        p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()
        if i % self.log_frequency == 0:
            print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' % (
                kw['epoch'], self.epochs, i, kw['data_loader'].__len__(),
                running_loss / self.log_frequency, p, r, f1, a))
            running_loss = 0.0
        self.flush(kw['logger'], ','.join(str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))
def epoch_dice_loss(self, **kw):
    """Run one training epoch with dice loss (training-only variant).

    NOTE(review): a second ``epoch_dice_loss`` is defined later in this file;
    if both live on the same class the later definition shadows this one,
    making this variant dead code — confirm which is intended.

    Expected keyword arguments:
        epoch: current epoch number (console logging only).
        data_loader: loader yielding dicts with 'inputs' and 'labels'.
    """
    score_acc = ScoreAccumulator()
    running_loss = 0.0
    for i, data in enumerate(kw['data_loader'], 1):
        inputs, labels = data['inputs'].to(
            self.device).float(), data['labels'].to(self.device).long()
        # weights = data['weights'].to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Balancing imbalanced class as per computed weights from the dataset
        # w = torch.FloatTensor(2).random_(1, 100).to(self.device)
        # wd = torch.FloatTensor(*labels.shape).uniform_(0.1, 2).to(self.device)

        # beta is sampled uniformly from [1.0, 1.9] each batch — randomizes
        # the precision/recall trade-off of the dice loss per step.
        loss = dice_loss(outputs[:, 1, :, :], labels,
                         beta=rd.choice(np.arange(1, 2, 0.1).tolist()))
        loss.backward()
        self.optimizer.step()

        current_loss = loss.item()
        running_loss += current_loss
        # reset() each batch: printed P/R/F1/A reflect only the current batch.
        p, r, f1, a = score_acc.reset().add_tensor(predicted, labels).get_prfa()
        if i % self.log_frequency == 0:
            print(
                'Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' %
                (kw['epoch'], self.epochs, i, kw['data_loader'].__len__(),
                 running_loss / self.log_frequency, p, r, f1, a))
            running_loss = 0.0

        self.flush(
            self.train_logger,
            ','.join(
                str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))
def epoch_dice_loss(self, **kw):
    """Run one epoch with dice loss; works in both train and eval mode.

    Expected keyword arguments:
        epoch: current epoch number (console logging only).
        data_loader: loader yielding dicts with 'inputs' and 'labels'.
        logger: open log handle passed to ``self.flush`` every batch.
        score_acc: ScoreAccumulator — required in eval mode; ignored
            during training (a fresh one is created).
    """
    # Training uses a throwaway accumulator reset per batch; eval mode reuses
    # the caller's accumulator so scores aggregate across the whole epoch.
    score_acc = ScoreAccumulator() if self.model.training else kw.get('score_acc')
    assert isinstance(score_acc, ScoreAccumulator)
    running_loss = 0.0
    for i, data in enumerate(kw['data_loader'], 1):
        inputs, labels = data['inputs'].to(self.device).float(), data['labels'].to(self.device).long()

        if self.model.training:
            self.optimizer.zero_grad()
        outputs = F.softmax(self.model(inputs), 1)
        _, predicted = torch.max(outputs, 1)
        # beta sampled uniformly from [1.0, 1.9] each batch — randomizes the
        # precision/recall trade-off of the dice loss per step.
        loss = dice_loss(outputs[:, 1, :, :], labels,
                         beta=rd.choice(np.arange(1, 2, 0.1).tolist()))
        if self.model.training:
            loss.backward()
            self.optimizer.step()
        current_loss = loss.item()
        running_loss += current_loss
        if self.model.training:
            # Reset so the printed P/R/F1/A reflect only the current batch.
            score_acc.reset()
        p, r, f1, a = score_acc.add_tensor(predicted, labels).get_prfa()
        if i % self.log_frequency == 0:
            print('Epochs[%d/%d] Batch[%d/%d] loss:%.5f pre:%.3f rec:%.3f f1:%.3f acc:%.3f' % (
                kw['epoch'], self.epochs, i, kw['data_loader'].__len__(),
                running_loss / self.log_frequency, p, r, f1, a))
            running_loss = 0.0
        self.flush(kw['logger'], ','.join(str(x) for x in [0, kw['epoch'], i, p, r, f1, a, current_loss]))
def _eval(self, data_loaders=None, logger=None, gen_images=False, score_acc=None):
    """Stitch patch predictions into full images under no_grad, score, and checkpoint.

    :param data_loaders: iterable of loaders, one per image — each loader's
        dataset is assumed to hold exactly one image object (index 0).
    :param logger: open log handle passed to ``self.flush`` per image.
    :param gen_images: True = save prediction and probability-map PNGs and
        accumulate into the global ``self.conf['acc']``; False = accumulate
        into *score_acc* against the full ground truth.
    :param score_acc: ScoreAccumulator collecting non-image-mode scores;
        its F1 drives the final ``_save_if_better`` call.
    """
    assert isinstance(score_acc, ScoreAccumulator)
    with torch.no_grad():
        for loader in data_loaders:
            # Each loader serves patches of a single image.
            img_obj = loader.dataset.image_objects[0]
            x, y = img_obj.working_arr.shape[0], img_obj.working_arr.shape[1]
            # Canvases for stitched argmax labels and channel-1 raw outputs.
            predicted_img = torch.FloatTensor(x, y).fill_(0).to(self.device)
            map_img = torch.FloatTensor(x, y).fill_(0).to(self.device)
            gt = torch.FloatTensor(img_obj.ground_truth).to(self.device)
            for i, data in enumerate(loader, 1):
                inputs, labels = data['inputs'].to(
                    self.device).float(), data['labels'].to(
                    self.device).float()
                # (row_start, row_end, col_start, col_end) patch placement.
                clip_ix = data['clip_ix'].to(self.device).int()
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs, 1)
                predicted_map = outputs[:, 1, :, :]
                for j in range(predicted_map.shape[0]):
                    p, q, r, s = clip_ix[j]
                    predicted_img[p:q, r:s] = predicted[j]
                    map_img[p:q, r:s] = predicted_map[j]
                print('Batch: ', i, end='\r')
            img_score = ScoreAccumulator()
            # exp() suggests the model emits log-probabilities here (raw
            # forward, no softmax) — TODO confirm against the model's head.
            map_img = torch.exp(map_img) * 255
            # Class indices {0,1} scaled to 8-bit range.
            predicted_img = predicted_img * 255
            if gen_images:
                map_img = map_img.cpu().numpy()
                predicted_img = predicted_img.cpu().numpy()
                img_score.add_array(predicted_img, img_obj.ground_truth)
                self.conf['acc'].accumulate(img_score)  # Global score
                IMG.fromarray(np.array(
                    predicted_img, dtype=np.uint8)).save(
                    os.path.join(
                        self.log_dir,
                        'pred_' + img_obj.file_name.split('.')[0] + '.png'))
                IMG.fromarray(np.array(map_img, dtype=np.uint8)).save(
                    os.path.join(self.log_dir,
                                 img_obj.file_name.split('.')[0] + '.png'))
            else:
                img_score.add_tensor(predicted_img, gt)
                score_acc.accumulate(img_score)
                prf1a = img_score.get_prfa()
                print(img_obj.file_name, ' PRF1A', prf1a)
                self.flush(
                    logger,
                    ','.join(str(x) for x in [img_obj.file_name] + prf1a))
    # Checkpoint on the aggregate F1 (index 2 of P/R/F1/A).
    self._save_if_better(score=score_acc.get_prfa()[2])