# Imports used in this section. Project-local helpers (accuracy,
# intersectionAndUnion, as_numpy, async_copy_to, updateConfusionMatrix,
# plot_confusion_matrix, compute_mcc, find_constant_area, and the commented-out
# adjust_learning_rate) are assumed to be defined or imported elsewhere in the
# project.
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


def train_epoch(segmentation_module, loader, optimizers, history, epoch, cfg,
                writer, epoch_iters, channels, patch_size, disp_iter,
                lr_encoder, lr_decoder):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    iterator = iter(loader)

    # keep BatchNorm in train mode unless fix_bn is set, i.e. train(True) here
    segmentation_module.train(not cfg['TRAIN']['fix_bn'])

    # main loop
    tic = time.time()
    for i in range(epoch_iters):
        # load a batch of data
        batch_data = next(iterator)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # =====================================================================
        # # adjust learning rate. TODO: turn off if you want a stable lr.
        # cur_iter = i + (epoch - 1) * cfg['TRAIN']['epoch_iters']
        # adjust_learning_rate(optimizers, cur_iter, cfg, lr_encoder, lr_decoder)
        # =====================================================================

        # stack the samples into batched image/label tensors
        batch_images = torch.zeros(
            len(batch_data), len(channels), patch_size * 3, patch_size * 3)
        if cfg['DATASET']['segm_downsampling_rate'] == 0:
            batch_segms = torch.zeros(
                len(batch_data), patch_size * 3, patch_size * 3).long()
        else:
            batch_segms = torch.zeros(
                len(batch_data),
                patch_size * 3 // cfg['DATASET']['segm_downsampling_rate'],
                patch_size * 3 // cfg['DATASET']['segm_downsampling_rate']).long()
        for j, bd in enumerate(batch_data):
            batch_images[j] = bd['img_data']
            batch_segms[j] = bd['seg_label']
        batch_data = {'img_data': batch_images.cuda(),
                      'seg_label': batch_segms.cuda()}

        # forward pass
        # TODO: the first commented call is for the HRNet model with acc/loss
        # computed on the inner patch only, the second for the small model
        # (or any model without downsampling).
        # loss, acc = segmentation_module(batch_data, patch_size=int(patch_size / 4))
        # loss, acc = segmentation_module(batch_data, patch_size=int(patch_size))
        loss, acc = segmentation_module(batch_data)
        loss = loss.mean()
        acc = acc.mean()

        # backward pass
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.item())
        ave_acc.update(acc.item() * 100)

        # display running averages and log to the history dict
        if i % disp_iter == 0:
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'
                  .format(epoch, i, epoch_iters,
                          batch_time.average(), data_time.average(),
                          cfg['TRAIN']['running_lr_encoder'],
                          cfg['TRAIN']['running_lr_decoder'],
                          ave_acc.average(), ave_total_loss.average()))

            fractional_epoch = epoch - 1 + 1. * i / epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(loss.item())
            history['train']['acc'].append(acc.item())

    writer.add_scalar('Train/Loss', ave_total_loss.average(), epoch)
    writer.add_scalar('Train/Acc', ave_acc.average(), epoch)
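# The meters above rely on an AverageMeter helper that is not shown in this
# file. Below is a minimal sketch consistent with its call sites: update()
# takes an optional weight (used for pixel-weighted accuracy), average()
# returns the running mean, and the .sum attribute is read directly for IoU.
# The project's own implementation may differ in detail.
class AverageMeter(object):
    """Running (weighted) average; sketch of the assumed interface."""

    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def update(self, val, weight=1):
        if not self.initialized:
            # first value initialises the running sums
            self.val = val
            self.avg = val
            self.sum = val * weight
            self.count = weight
            self.initialized = True
        else:
            self.val = val
            self.sum = self.sum + val * weight
            self.count = self.count + weight
            self.avg = self.sum / self.count

    def average(self):
        return self.avg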
def evaluate(segmentation_module, loader, cfg, gpu, activations, num_class,
             patch_size, patch_size_padded, class_names, channels, index_test,
             visualize, results_dir, arch_encoder):
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    acc_meter_patch = AverageMeter()
    intersection_meter_patch = AverageMeter()
    union_meter_patch = AverageMeter()
    time_meter = AverageMeter()

    # initiate confusion matrices (full padded patch and inner patch)
    conf_matrix = np.zeros((num_class, num_class))
    conf_matrix_patch = np.zeros((num_class, num_class))

    # turn on for UMAP: initialise arrays that collect, per test patch, the
    # activations of one 32x32 constant-class area (8x8 in feature space)
    area_activations_mean = np.zeros((len(index_test), 32 // 4 * 32 // 4))
    area_activations_max = np.zeros((len(index_test), 32 // 4 * 32 // 4))
    area_cl = np.zeros((len(index_test),), dtype=int)
    area_loc = np.zeros((len(index_test), 3), dtype=int)
    j = 0

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, gpu)

            # accumulate scores over all (resized) versions of the image
            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + scores_tmp

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy and IoU, on the full padded patch and on the
        # inner patch only
        acc, pix = accuracy(pred, seg_label)
        acc_patch, pix_patch = accuracy(
            pred[patch_size:2 * patch_size, patch_size:2 * patch_size],
            seg_label[patch_size:2 * patch_size, patch_size:2 * patch_size])
        intersection, union = intersectionAndUnion(pred, seg_label, num_class)
        intersection_patch, union_patch = intersectionAndUnion(
            pred[patch_size:2 * patch_size, patch_size:2 * patch_size],
            seg_label[patch_size:2 * patch_size, patch_size:2 * patch_size],
            num_class)

        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_meter_patch.update(acc_patch, pix_patch)
        intersection_meter_patch.update(intersection_patch)
        union_meter_patch.update(union_patch)

        conf_matrix = updateConfusionMatrix(conf_matrix, pred, seg_label)
        # update conf matrix of the inner patch
        conf_matrix_patch = updateConfusionMatrix(
            conf_matrix_patch,
            pred[patch_size:2 * patch_size, patch_size:2 * patch_size],
            seg_label[patch_size:2 * patch_size, patch_size:2 * patch_size])

        # visualization: save the raw prediction as .npy
        if visualize:
            info = batch_data['info']
            img_name = info.split('/')[-1]
            # np.save(os.path.join(test_dir, 'result', img_name), pred)
            np.save(os.path.join(results_dir, img_name), pred)
        # =====================================================================
        # if visualize:
        #     visualize_result(
        #         (batch_data['img_ori'], seg_label, batch_data['info']),
        #         pred,
        #         os.path.join(test_dir, 'result'))
        # =====================================================================

        pbar.update(1)

        # turn on for UMAP: find a 32x32 constant-class area in the label patch
        # TODO: patch_size_padded must be patch_size if only the inner patch
        # is checked.
        row, col, cl = find_constant_area(seg_label, 32, patch_size_padded)
        if row != 999999:
            # mean/max over channels of the hooked feature map, cropped to the
            # 8x8 feature-space window corresponding to the 32x32 image area
            activ_mean = np.mean(
                as_numpy(activations.features.squeeze(0).cpu()),
                axis=0, keepdims=True)[:, row // 4:row // 4 + 8,
                                       col // 4:col // 4 + 8].reshape(1, 8 * 8)
            activ_max = np.max(
                as_numpy(activations.features.squeeze(0).cpu()),
                axis=0, keepdims=True)[:, row // 4:row // 4 + 8,
                                       col // 4:col // 4 + 8].reshape(1, 8 * 8)
            area_activations_mean[j] = activ_mean
            area_activations_max[j] = activ_max
            area_cl[j] = cl
        else:
            # no constant-class area found: store NaNs and a sentinel label
            area_activations_mean[j] = np.full((1, 64), np.nan, dtype=np.float32)
            area_activations_max[j] = np.full((1, 64), np.nan, dtype=np.float32)
            area_cl[j] = 999999
        area_loc[j, 0] = row
        area_loc[j, 1] = col
        area_loc[j, 2] = int(batch_data['info'].split('.')[0])
        j += 1
        # activ = np.mean(as_numpy(activations.features.squeeze(0).cpu()), axis=0)[row//4:row//4+8, col//4:col//4+8]
        # activ = as_numpy(activations.features.squeeze(0).cpu())

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    iou_patch = intersection_meter_patch.sum / (union_meter_patch.sum + 1e-10)
    for i, _iou_patch in enumerate(iou_patch):
        print('class [{}], patch IoU: {:.4f}'.format(i, _iou_patch))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou.mean(), acc_meter.average() * 100, time_meter.average()))
    print('Patch: Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou_patch.mean(), acc_meter_patch.average() * 100,
                  time_meter.average()))

    print('Confusion matrix:')
    plot_confusion_matrix(conf_matrix, class_names, normalize=True,
                          title='confusion matrix patch+padding',
                          cmap=plt.cm.Blues)
    plot_confusion_matrix(conf_matrix_patch, class_names, normalize=True,
                          title='confusion matrix patch', cmap=plt.cm.Blues)
    np.save(os.path.join(results_dir, 'confmatrix.npy'), conf_matrix)
    np.save(os.path.join(results_dir, 'confmatrix_patch.npy'), conf_matrix_patch)

    # turn on for UMAP: save the collected activations
    np.save(os.path.join(results_dir, 'activations_mean.npy'),
            area_activations_mean)
    np.save(os.path.join(results_dir, 'activations_max.npy'),
            area_activations_max)
    np.save(os.path.join(results_dir, 'activations_labels.npy'), area_cl)
    np.save(os.path.join(results_dir, 'activations_loc.npy'), area_loc)

    mcc = compute_mcc(conf_matrix)
    mcc_patch = compute_mcc(conf_matrix_patch)

    # save a summary of the results as csv
    summary = pd.DataFrame(
        [[arch_encoder, patch_size, channels,
          acc_meter.average(), acc_meter_patch.average(),
          iou.mean(), iou_patch.mean(), mcc, mcc_patch]],
        columns=['model', 'patch_size', 'channels',
                 'test_accuracy', 'test_accuracy_patch',
                 'meanIoU', 'meanIoU_patch', 'mcc', 'mcc_patch'])
    summary.to_csv(os.path.join(results_dir, 'summary_results.csv'))
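# The "turn on for UMAP" blocks above only export activations to .npy files.
# Below is a hedged sketch of how those files might then be projected with the
# umap-learn package (an assumption; the actual projection script is not part
# of this module). Rows where no constant-class area was found are dropped via
# their NaN entries.
def plot_activation_umap(results_dir):
    import umap  # umap-learn, assumed installed

    feats = np.load(os.path.join(results_dir, 'activations_mean.npy'))
    labels = np.load(os.path.join(results_dir, 'activations_labels.npy'))
    keep = ~np.isnan(feats).any(axis=1)  # drop sentinel rows
    embedding = umap.UMAP(n_components=2).fit_transform(feats[keep])
    plt.scatter(embedding[:, 0], embedding[:, 1], c=labels[keep], s=3)
    plt.title('UMAP of constant-area activations')
    plt.show()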
def validate(segmentation_module, loader, optimizers, history, epoch, cfg,
             writer, val_epoch_iters, channels, patch_size):
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    iterator = iter(loader)

    # main loop
    tic = time.time()
    for i in range(val_epoch_iters):
        # load a batch of data
        batch_data = next(iterator)

        # stack the samples into batched image/label tensors
        batch_images = torch.zeros(
            len(batch_data), len(channels), patch_size * 3, patch_size * 3)
        if cfg['DATASET']['segm_downsampling_rate'] == 0:
            batch_segms = torch.zeros(
                len(batch_data), patch_size * 3, patch_size * 3).long()
        else:
            batch_segms = torch.zeros(
                len(batch_data),
                patch_size * 3 // cfg['DATASET']['segm_downsampling_rate'],
                patch_size * 3 // cfg['DATASET']['segm_downsampling_rate']).long()
        for j, bd in enumerate(batch_data):
            batch_images[j] = bd['img_data']
            batch_segms[j] = bd['seg_label']
        batch_data = {'img_data': batch_images.cuda(),
                      'seg_label': batch_segms.cuda()}

        with torch.no_grad():
            # forward pass; acc/loss on the inner patch only (cf. the TODO in
            # train_epoch)
            loss, acc = segmentation_module(batch_data,
                                            patch_size=int(patch_size / 4))
            loss = loss.mean()
            acc = acc.mean()

        # update average loss and acc
        ave_total_loss.update(loss.item())
        ave_acc.update(acc.item() * 100)

        # measure elapsed time
        time_meter.update(time.time() - tic)
        tic = time.time()

        # log to the history dict
        fractional_epoch = epoch - 1 + 1. * i / val_epoch_iters
        history['val']['epoch'].append(fractional_epoch)
        history['val']['loss'].append(loss.item())
        history['val']['acc'].append(acc.item())

    print('Epoch: [{}], Time: {:.2f}, '
          'Val_Accuracy: {:4.2f}, Val_Loss: {:.6f}'
          .format(epoch, time_meter.average(),
                  ave_acc.average(), ave_total_loss.average()))

    writer.add_scalar('Val/Loss', ave_total_loss.average(), epoch)
    writer.add_scalar('Val/Acc', ave_acc.average(), epoch)

    return ave_total_loss.average()
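# A minimal sketch of how train_epoch and validate might be wired together,
# assuming the cfg keys used above. num_epochs and checkpoint_path are
# hypothetical placeholders; the project's actual driver script is not shown
# here and may differ.
def run_training(segmentation_module, loader_train, loader_val, optimizers,
                 cfg, writer, channels, patch_size, num_epochs, epoch_iters,
                 val_epoch_iters, disp_iter, lr_encoder, lr_decoder,
                 checkpoint_path='best_model.pth'):
    history = {'train': {'epoch': [], 'loss': [], 'acc': []},
               'val': {'epoch': [], 'loss': [], 'acc': []}}
    best_val_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        train_epoch(segmentation_module, loader_train, optimizers, history,
                    epoch, cfg, writer, epoch_iters, channels, patch_size,
                    disp_iter, lr_encoder, lr_decoder)
        val_loss = validate(segmentation_module, loader_val, optimizers,
                            history, epoch, cfg, writer, val_epoch_iters,
                            channels, patch_size)
        # keep the checkpoint with the lowest validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(segmentation_module.state_dict(), checkpoint_path)
    return history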