예제 #1
0
def train_epoch(segmentation_module, loader, optimizers, history, epoch, cfg,
                writer, epoch_iters, channels, patch_size, disp_iter, lr_encoder, lr_decoder):
    """Run ``epoch_iters`` optimisation steps and record loss/accuracy.

    Each batch drawn from ``loader`` is a sequence of sample dicts holding
    ``'img_data'`` and ``'seg_label'``. The samples are copied into
    pre-allocated tensors, moved to the GPU and passed to
    ``segmentation_module``, which must return ``(loss, acc)``.

    Args:
        segmentation_module: Model called as ``module(batch_dict)``.
        loader: Iterable of batches; exactly ``epoch_iters`` batches are
            pulled with ``next``.
        optimizers: Iterable of optimizers; ``step()`` is called on each one
            after the backward pass.
        history: Per-iteration values are appended to
            ``history['train']['epoch' | 'loss' | 'acc']``.
        epoch: Epoch number; the fractional epoch is computed as
            ``epoch - 1 + i / epoch_iters`` (presumably 1-based -- confirm
            with the caller).
        cfg: Reads ``TRAIN.fix_bn``, ``DATASET.segm_downsampling_rate``,
            ``TRAIN.running_lr_encoder`` and ``TRAIN.running_lr_decoder``.
        writer: Object with ``add_scalar(tag, value, step)``; receives the
            epoch-average loss and accuracy.
        epoch_iters: Number of iterations to run.
        channels: Only its length is used (number of image channels).
        patch_size: Images are allocated as ``3 * patch_size`` squares.
        disp_iter: Progress is printed whenever ``i % disp_iter == 0``.
        lr_encoder: Unused while the learning-rate schedule below is disabled.
        lr_decoder: Unused while the learning-rate schedule below is disabled.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()
    meter_loss = AverageMeter()
    meter_acc = AverageMeter()

    batches = iter(loader)

    # Training mode is switched on unless cfg asks for fixed batch-norm.
    segmentation_module.train(not cfg['TRAIN']['fix_bn'])

    started = time.time()
    for step in range(epoch_iters):
        samples = next(batches)
        timer_data.update(time.time() - started)
        segmentation_module.zero_grad()

        # NOTE: the per-iteration learning-rate schedule is currently disabled.
        # To re-enable it:
        #     cur_iter = step + (epoch - 1) * cfg['TRAIN']['epoch_iters']
        #     adjust_learning_rate(optimizers, cur_iter, cfg, lr_encoder, lr_decoder)

        # Pre-allocate the batch tensors; labels shrink by the configured
        # downsampling rate (a rate of 0 means "no downsampling").
        n_samples = len(samples)
        side = patch_size * 3
        rate = cfg['DATASET']['segm_downsampling_rate']
        label_side = side if rate == 0 else side // rate

        images = torch.zeros(n_samples, len(channels), side, side)
        labels = torch.zeros(n_samples, label_side, label_side, dtype=torch.long)
        for idx, sample in enumerate(samples):
            images[idx] = sample['img_data']
            labels[idx] = sample['seg_label']

        feed = {'img_data': images.cuda(), 'seg_label': labels.cuda()}

        # Forward pass. Alternative calls that restrict loss/accuracy to the
        # inner patch (kept for reference):
        #     segmentation_module(feed, patch_size=int(patch_size / 4))  # HRNet
        #     segmentation_module(feed, patch_size=int(patch_size))      # no downsampling
        loss, acc = segmentation_module(feed)
        loss = loss.mean()
        acc = acc.mean()

        # Backward pass and parameter update.
        loss.backward()
        for opt in optimizers:
            opt.step()

        # Timing covers data loading plus the optimisation step.
        timer_batch.update(time.time() - started)
        started = time.time()

        loss_value = loss.data.item()
        acc_value = acc.data.item()
        meter_loss.update(loss_value)
        meter_acc.update(acc_value * 100)

        if i_should_report(step, disp_iter) if False else step % disp_iter == 0:
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'
                  .format(epoch, step, epoch_iters,
                          timer_batch.average(), timer_data.average(),
                          cfg['TRAIN']['running_lr_encoder'],
                          cfg['TRAIN']['running_lr_decoder'],
                          meter_acc.average(), meter_loss.average()))

        train_log = history['train']
        train_log['epoch'].append(epoch - 1 + 1. * step / epoch_iters)
        train_log['loss'].append(loss_value)
        train_log['acc'].append(acc_value)

    writer.add_scalar('Train/Loss', meter_loss.average(), epoch)
    writer.add_scalar('Train/Acc', meter_acc.average(), epoch)
예제 #2
0
def evaluate(segmentation_module, loader, cfg, gpu, activations, num_class,
             patch_size, patch_size_padded, class_names, channels, index_test,
             visualize, results_dir, arch_encoder):
    """Evaluate the model on ``loader`` and write metrics to ``results_dir``.

    For every sample the class scores of all images in
    ``batch_data['img_data']`` are summed and the arg-max is taken as the
    prediction. Accuracy, IoU and confusion matrices are accumulated twice:
    over the full prediction and over the inner patch
    ``[patch_size:2*patch_size, patch_size:2*patch_size]``.

    In addition, an 8x8 window of the hooked activations (channel-wise mean
    and max) is stored per sample for later UMAP analysis.

    Files written to ``results_dir``: ``confmatrix.npy``,
    ``confmatrix_patch.npy``, ``activations_mean.npy``,
    ``activations_max.npy``, ``activations_labels.npy``,
    ``activations_loc.npy``, ``summary_results.csv`` and, when ``visualize``
    is true, one prediction array per sample.

    Args:
        segmentation_module: Model called as
            ``module(feed_dict, segSize=(h, w))``; must return class scores.
        loader: Iterable of batches; only element ``[0]`` of each batch is
            used. That element is a dict with ``'seg_label'``, ``'img_data'``,
            ``'img_ori'`` and ``'info'``.
        cfg: Unused here.
        gpu: Device argument forwarded to ``async_copy_to``.
        activations: Object exposing ``.features`` (presumably a forward hook
            filled by the last model call -- confirm with the caller).
        num_class: Number of classes.
        patch_size: Side length of the inner patch used for the "patch"
            metrics.
        patch_size_padded: Forwarded to ``find_constant_area``.
        class_names: Labels passed to ``plot_confusion_matrix``.
        channels: Only recorded in the summary CSV.
        index_test: Its length fixes the number of rows of the activation
            arrays; the loader must not yield more samples than that.
        visualize: When true, save each prediction under ``results_dir``.
        results_dir: Output directory.
        arch_encoder: Only recorded in the summary CSV.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    acc_meter_patch = AverageMeter()
    intersection_meter_patch = AverageMeter()
    union_meter_patch = AverageMeter()
    time_meter = AverageMeter()

    conf_matrix = np.zeros((num_class, num_class))
    conf_matrix_patch = np.zeros((num_class, num_class))

    # Activation bookkeeping for UMAP. A constant-label area of `area_size`
    # pixels maps to a `window` x `window` region of the activation map
    # (presumably the activations are `scale` times smaller than the label --
    # TODO confirm against the encoder).
    area_size = 32
    scale = 4
    window = area_size // scale
    n_features = window * window
    n_samples = len(index_test)
    area_activations_mean = np.zeros((n_samples, n_features))
    area_activations_max = np.zeros((n_samples, n_features))
    # `np.int` was removed in NumPy 1.24; it was an alias for the builtin int.
    area_cl = np.zeros((n_samples, ), dtype=int)
    area_loc = np.zeros((n_samples, 3), dtype=int)
    j = 0

    # Index selecting the inner (unpadded) patch of a 2-D prediction/label.
    inner = (slice(patch_size, 2 * patch_size),
             slice(patch_size, 2 * patch_size))

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))
    for batch_data in loader:

        # process data
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        # Synchronise so the timer measures only the inference below.
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            scores = torch.zeros(1, num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                # These entries are not model inputs.
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + scores_tmp

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        pred_patch = pred[inner]
        seg_label_patch = seg_label[inner]

        # Metrics on the full prediction and on the inner patch.
        acc, pix = accuracy(pred, seg_label)
        acc_patch, pix_patch = accuracy(pred_patch, seg_label_patch)

        intersection, union = intersectionAndUnion(pred, seg_label, num_class)
        intersection_patch, union_patch = intersectionAndUnion(
            pred_patch, seg_label_patch, num_class)

        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_meter_patch.update(acc_patch, pix_patch)
        intersection_meter_patch.update(intersection_patch)
        union_meter_patch.update(union_patch)

        conf_matrix = updateConfusionMatrix(conf_matrix, pred, seg_label)
        conf_matrix_patch = updateConfusionMatrix(
            conf_matrix_patch, pred_patch, seg_label_patch)

        # Save the raw prediction under the sample's file name.
        if visualize:
            info = batch_data['info']
            img_name = info.split('/')[-1]
            np.save(os.path.join(results_dir, img_name), pred)

        pbar.update(1)

        # Activation summary for UMAP.
        # TODO patch_size_padded must be patch_size if only inner patch is checked.
        row, col, cl = find_constant_area(seg_label, area_size,
                                          patch_size_padded)
        if row != 999999:
            # Convert the activations once and reuse them for mean and max.
            features = as_numpy(activations.features.squeeze(0).cpu())
            r0 = row // scale
            c0 = col // scale
            area_activations_mean[j] = np.mean(
                features, axis=0,
                keepdims=True)[:, r0:r0 + window,
                               c0:c0 + window].reshape(1, n_features)
            area_activations_max[j] = np.max(
                features, axis=0,
                keepdims=True)[:, r0:r0 + window,
                               c0:c0 + window].reshape(1, n_features)
            area_cl[j] = cl
        else:
            # row == 999999 is treated as "no constant area"; store NaNs and
            # the sentinel class so the row can be filtered out later.
            area_activations_mean[j] = np.nan
            area_activations_max[j] = np.nan
            area_cl[j] = 999999

        area_loc[j, 0] = row
        area_loc[j, 1] = col
        # Third column: numeric prefix of the sample's info string.
        area_loc[j, 2] = int(batch_data['info'].split('.')[0])
        j += 1

    # Release the progress bar so later output is not garbled.
    pbar.close()

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))
    iou_patch = intersection_meter_patch.sum / (union_meter_patch.sum + 1e-10)
    for i, _iou_patch in enumerate(iou_patch):
        print('class [{}], patch IoU: {:.4f}'.format(i, _iou_patch))

    print('[Eval Summary]:')
    print(
        'Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.format(
            iou.mean(),
            acc_meter.average() * 100, time_meter.average()))
    print(
        'Patch: Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.
        format(iou_patch.mean(),
               acc_meter_patch.average() * 100, time_meter.average()))

    print('Confusion matrix:')
    plot_confusion_matrix(conf_matrix,
                          class_names,
                          normalize=True,
                          title='confusion matrix patch+padding',
                          cmap=plt.cm.Blues)
    plot_confusion_matrix(conf_matrix_patch,
                          class_names,
                          normalize=True,
                          title='confusion matrix patch',
                          cmap=plt.cm.Blues)

    np.save(os.path.join(results_dir, 'confmatrix.npy'), conf_matrix)
    np.save(os.path.join(results_dir, 'confmatrix_patch.npy'),
            conf_matrix_patch)
    # Activation arrays for UMAP.
    np.save(os.path.join(results_dir, 'activations_mean.npy'),
            area_activations_mean)
    np.save(os.path.join(results_dir, 'activations_max.npy'),
            area_activations_max)
    np.save(os.path.join(results_dir, 'activations_labels.npy'), area_cl)
    np.save(os.path.join(results_dir, 'activations_loc.npy'), area_loc)

    mcc = compute_mcc(conf_matrix)
    mcc_patch = compute_mcc(conf_matrix_patch)
    # Save a one-row summary of the results as CSV.
    summary = pd.DataFrame([[
        arch_encoder, patch_size, channels,
        acc_meter.average(),
        acc_meter_patch.average(),
        iou.mean(),
        iou_patch.mean(), mcc, mcc_patch
    ]],
                           columns=[
                               'model', 'patch_size', 'channels',
                               'test_accuracy', 'test_accuracy_patch',
                               'meanIoU', 'meanIoU_patch', 'mcc', 'mcc_patch'
                           ])
    summary.to_csv(os.path.join(results_dir, 'summary_results.csv'))
예제 #3
0
def validate(segmentation_module, loader, optimizers, history, epoch, cfg, writer,val_epoch_iters, channels, patch_size):
    """Run ``val_epoch_iters`` validation batches and return the mean loss.

    Batches are assembled like in training (a sequence of sample dicts with
    ``'img_data'`` and ``'seg_label'`` copied into pre-allocated tensors and
    moved to the GPU). The model is called without gradient tracking as
    ``module(batch_dict, patch_size=int(patch_size / 4))`` and must return
    ``(loss, acc)``.

    Args:
        segmentation_module: Model to validate; put into eval mode here.
        loader: Iterable of batches; exactly ``val_epoch_iters`` are pulled.
        optimizers: Unused here.
        history: Per-iteration values are appended to
            ``history['val']['epoch' | 'loss' | 'acc']``.
        epoch: Epoch number used for logging and the fractional epoch.
        cfg: Reads ``DATASET.segm_downsampling_rate``.
        writer: Object with ``add_scalar(tag, value, step)``.
        val_epoch_iters: Number of validation iterations.
        channels: Only its length is used (number of image channels).
        patch_size: Images are allocated as ``3 * patch_size`` squares.

    Returns:
        The average validation loss reported by the loss meter.
    """
    meter_loss = AverageMeter()
    meter_acc = AverageMeter()
    meter_time = AverageMeter()

    segmentation_module.eval()

    batches = iter(loader)

    started = time.time()
    for step in range(val_epoch_iters):
        samples = next(batches)

        # Pre-allocate the batch tensors; labels shrink by the configured
        # downsampling rate (a rate of 0 means "no downsampling").
        n_samples = len(samples)
        side = patch_size * 3
        rate = cfg['DATASET']['segm_downsampling_rate']
        label_side = side if rate == 0 else side // rate

        images = torch.zeros(n_samples, len(channels), side, side)
        labels = torch.zeros(n_samples, label_side, label_side, dtype=torch.long)
        for idx, sample in enumerate(samples):
            images[idx] = sample['img_data']
            labels[idx] = sample['seg_label']

        feed = {'img_data': images.cuda(), 'seg_label': labels.cuda()}

        # Forward pass only; no gradients are needed for validation.
        with torch.no_grad():
            loss, acc = segmentation_module(feed, patch_size = int(patch_size/4))

        loss = loss.mean()
        acc = acc.mean()

        loss_value = loss.data.item()
        acc_value = acc.data.item()
        meter_loss.update(loss_value)
        meter_acc.update(acc_value * 100)

        # Elapsed time for this iteration (data loading included).
        meter_time.update(time.time() - started)
        started = time.time()

        val_log = history['val']
        val_log['epoch'].append(epoch - 1 + 1. * step / val_epoch_iters)
        val_log['loss'].append(loss_value)
        val_log['acc'].append(acc_value)

    mean_loss_msg = ('Epoch: [{}], Time: {:.2f}, '
                     'Val_Accuracy: {:4.2f}, Val_Loss: {:.6f}')
    print(mean_loss_msg.format(epoch, meter_time.average(),
                               meter_acc.average(), meter_loss.average()))
    writer.add_scalar('Val/Loss', meter_loss.average(), epoch)
    writer.add_scalar('Val/Acc', meter_acc.average(), epoch)

    return meter_loss.average()