def load_models(model_folder, gpu_id=0):
    """ load segmentation model from folder
    :param model_folder:    the folder containing the segmentation model
    :param gpu_id:          the gpu device id to run the segmentation model
    :return: a dictionary containing the model and inference parameters
    """
    assert os.path.isdir(model_folder), 'Model folder does not exist: {}'.format(model_folder)

    # load inference config file
    infer_cfg = load_config(os.path.join(model_folder, 'infer_config.py'))
    models = edict()
    models.infer_cfg = infer_cfg

    # load a single model (coarse or fine) or both, depending on general.single_scale
    if models.infer_cfg.general.single_scale == 'coarse':
        coarse_model_folder = os.path.join(model_folder, models.infer_cfg.coarse.model_name)
        coarse_model = load_single_model(coarse_model_folder, gpu_id)
        models.coarse_model = coarse_model
        models.fine_model = None

    elif models.infer_cfg.general.single_scale == 'fine':
        fine_model_folder = os.path.join(model_folder, models.infer_cfg.fine.model_name)
        fine_model = load_single_model(fine_model_folder, gpu_id)
        models.fine_model = fine_model
        models.coarse_model = None

    elif models.infer_cfg.general.single_scale == 'DISABLE':
        coarse_model_folder = os.path.join(model_folder, models.infer_cfg.coarse.model_name)
        coarse_model = load_single_model(coarse_model_folder, gpu_id)
        models.coarse_model = coarse_model

        fine_model_folder = os.path.join(model_folder, models.infer_cfg.fine.model_name)
        fine_model = load_single_model(fine_model_folder, gpu_id)
        models.fine_model = fine_model

    else:
        raise ValueError('Unsupported single scale type: {}'.format(models.infer_cfg.general.single_scale))

    return models
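A minimal usage sketch follows; the folder path is a hypothetical placeholder, and the config fields referenced are exactly the ones load_models reads above.

# usage sketch: the model folder is assumed to contain infer_config.py plus the coarse/fine sub-folders
seg_models = load_models('/path/to/seg_model_folder', gpu_id=0)
print('single scale mode:', seg_models.infer_cfg.general.single_scale)
# coarse_model / fine_model may be None depending on general.single_scale
coarse, fine = seg_models.coarse_model, seg_models.fine_model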
def load_seg_model(model_folder, gpu_id=(0, 1)):
  """ load segmentation model from folder
  :param model_folder:    the folder containing the segmentation model
  :param gpu_id:          a sequence of two gpu device ids to run the model on;
                          an empty sequence keeps the model on the CPU
  :return: a dictionary containing the model and inference parameters
  """
  assert os.path.isdir(model_folder), 'Model folder does not exist: {}'.format(model_folder)

  model = edict()

  # load inference config file
  latest_checkpoint_dir = get_checkpoint_folder(os.path.join(model_folder, 'checkpoints'), -1)
  infer_cfg = load_config(os.path.join(model_folder, 'config_infer.py'))
  model.infer_cfg = infer_cfg

  if len(gpu_id) > 0:
    # expose the two requested devices before moving the network to the GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = '{},{}'.format(int(gpu_id[0]), int(gpu_id[1]))

  # load model state (map to CPU when no GPU is requested)
  chk_file = os.path.join(latest_checkpoint_dir, 'params.pth')
  state = torch.load(chk_file) if len(gpu_id) > 0 else torch.load(chk_file, map_location='cpu')
  # load network module
  net_module = importlib.import_module('segmentation3d.network.' + state['net'])
  net = net_module.SegmentationNet(state['in_channels'], state['out_channels'])
  net = nn.parallel.DataParallel(net)
  net.load_state_dict(state['state_dict'])
  net.eval()

  if len(gpu_id) > 0:
    net = net.cuda()
    del os.environ['CUDA_VISIBLE_DEVICES']
    
  model.net = net
  model.spacing = state['spacing']
  model.max_stride = state['max_stride']
  model.interpolation = state['interpolation']
  return model
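The keys read from params.pth above imply the checkpoint layout below; the concrete values are illustrative assumptions only, not taken from the training code.

# assumed checkpoint layout, inferred from the keys load_seg_model reads above (values illustrative)
example_state = {
    'net': 'vnet',                  # module name under segmentation3d.network
    'in_channels': 1,               # number of input modalities
    'out_channels': 2,              # number of segmentation classes
    'state_dict': None,             # DataParallel-wrapped network weights
    'spacing': [1.0, 1.0, 1.0],     # resampling spacing used at training time
    'max_stride': 16,               # largest downsampling factor of the network
    'interpolation': 'LINEAR',      # interpolation method used for resampling
}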
Example #3
def train(config_file):
    """ Medical image segmentation training engine
    :param config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(config_file), 'Config not found: {}'.format(
        config_file)

    # load config file
    cfg = load_config(config_file)
    # clean the existing folder if training from scratch, then make sure it exists for logging
    if cfg.general.resume_epoch < 0 and os.path.isdir(cfg.general.save_dir):
        shutil.rmtree(cfg.general.save_dir)
    if not os.path.isdir(cfg.general.save_dir):
        os.makedirs(cfg.general.save_dir)

    # enable logging
    log_file = os.path.join(cfg.general.save_dir, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)

    # dataset
    dataset = SegmentationDataset(
        imlist_file=cfg.general.imseg_list,
        num_classes=cfg.dataset.num_classes,
        spacing=cfg.dataset.spacing,
        crop_size=cfg.dataset.crop_size,
        default_values=cfg.dataset.default_values,
        sampling_method=cfg.dataset.sampling_method,
        random_translation=cfg.dataset.random_translation,
        interpolation=cfg.dataset.interpolation,
        crop_normalizers=cfg.dataset.crop_normalizers)

    sampler = EpochConcateSampler(dataset, cfg.train.epochs)
    #print('total index for training',len(sampler))
    data_loader = DataLoader(dataset,
                             sampler=sampler,
                             batch_size=cfg.train.batchsize,
                             num_workers=cfg.train.num_threads,
                             pin_memory=True)
    net_module = importlib.import_module('segmentation3d.network.' +
                                         cfg.net.name)
    # input channels = number of modalities, output channels = number of classes
    net = net_module.SegmentationNet(dataset.num_modality(),
                                     cfg.dataset.num_classes)
    max_stride = net.max_stride()  # largest downsampling factor of the network
    net_module.parameters_kaiming_init(net)  # Kaiming initialization of the network weights

    if cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net,
                                       device_ids=list(
                                           range(cfg.general.num_gpus)))
        net = net.cuda()

    assert np.all(
        np.array(cfg.dataset.crop_size) % max_stride == 0
    ), 'crop size not divisible by max stride'  #adjust crop size for down conv

    # training optimizer
    opt = optim.Adam(net.parameters(), lr=cfg.train.lr,
                     betas=cfg.train.betas)  # e.g. lr=1e-4, betas=(0.9, 0.999)
    # load the checkpoint if resume_epoch >= 0 to continue a previous run
    if cfg.general.resume_epoch >= 0:
        last_save_epoch, batch_start = load_checkpoint(
            cfg.general.resume_epoch, net, opt, cfg.general.save_dir)
    else:
        last_save_epoch, batch_start = 0, 0

    batch_idx = batch_start
    data_iter = iter(data_loader)
    if cfg.loss.name == 'Focal':
        # reuse focal loss if exists
        loss_func = FocalLoss(class_num=cfg.dataset.num_classes,
                              alpha=cfg.loss.obj_weight,
                              gamma=cfg.loss.focal_gamma)
    elif cfg.loss.name == 'Dice':
        loss_func = MultiDiceLoss(weights=cfg.loss.obj_weight,
                                  num_class=cfg.dataset.num_classes,
                                  use_gpu=cfg.general.num_gpus > 0)
    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(cfg.general.save_dir, 'tensorboard'))

    # loop over batches
    for i in range(len(data_loader)):  # loop over batches; the sampler already spans all epochs
        begin_t = time.time()
        crops, masks, frames, filenames = next(data_iter)
        print('training ', filenames)
        #print('crops',crops.shape)
        #print('masks',masks.shape)

        if cfg.general.num_gpus > 0:
            crops, masks = crops.cuda(), masks.cuda()

        # clear previous gradients
        opt.zero_grad()

        # network forward and backward
        outputs = net(crops)
        #print('outputs',outputs.shape)
        train_loss = loss_func(outputs,
                               masks)  # per-class losses are averaged into a single loss
        train_loss.backward()

        # update weights
        opt.step()

        # save training crops for visualization
        if cfg.debug.save_inputs:
            batch_size = crops.size(0)
            save_intermediate_results(
                list(range(batch_size)), crops, masks, None, frames, filenames,
                os.path.join(cfg.general.save_dir, 'batch_{}'.format(i)))
        epoch_idx = batch_idx * cfg.train.batchsize // len(dataset)
        batch_idx += 1
        batch_duration = time.time() - begin_t
        sample_duration = batch_duration * 1.0 / cfg.train.batchsize
        if (batch_idx + 1) % 1 == 0:  # a modulus of 1 means results are visualized every batch
            # sample every r-th slice between `begin` and `end` along one spatial axis
            begin = 2
            end = 31
            r = 3
            image_num = int((end - begin - 1) / r)
            # show the input image slices
            image = crops[0, 0:1, :, begin:end:r, :].permute(2, 0, 1, 3)
            grid_image = make_grid(image, image_num, normalize=True)
            writer.add_image('/train/image', grid_image, batch_idx)
            # show the ground-truth label slices
            label = masks[0, 0:1, :, begin:end:r, :].permute(2, 0, 1, 3)
            grid_label = make_grid(label, image_num)
            writer.add_image('/train/label', grid_label, batch_idx)
            # show the thresholded prediction for channel 0
            pred1 = outputs[0, 0:1, :, begin:end:r, :].permute(2, 0, 1, 3)
            pred1[pred1 <= 0.8] = 0
            pred1[pred1 > 0.8] = 1
            grid_pred1 = make_grid(pred1, image_num)
            writer.add_image('/train/pred0:1', grid_pred1, batch_idx)
            # show the thresholded prediction for channel 1
            pred2 = outputs[0, 1:2, :, begin:end:r, :].permute(2, 0, 1, 3)
            pred2[pred2 > 0.8] = 1
            pred2[pred2 <= 0.8] = 0
            grid_pred2 = make_grid(pred2, image_num)
            writer.add_image('/train/pred1:2', grid_pred2, batch_idx)
            # convert the grids to HWC numpy arrays for overlay rendering
            grid_image = grid_image.cpu().detach().numpy().transpose((1, 2, 0))
            grid_label = grid_label.cpu().detach().numpy().transpose((1, 2, 0))
            grid_pred1 = grid_pred1.cpu().detach().numpy().transpose((1, 2, 0))
            grid_pred2 = grid_pred2.cpu().detach().numpy().transpose((1, 2, 0))
            pred1 = label2rgb(grid_pred1[:, :, 0],
                              grid_image[:, :, 0],
                              bg_label=0)
            pred2 = label2rgb(grid_pred2[:, :, 0],
                              grid_image[:, :, 0],
                              bg_label=0)
            gt = label2rgb(grid_label[:, :, 0],
                           grid_image[:, :, 0],
                           bg_label=0)
            fig = plt.figure()
            ax = fig.add_subplot(311)
            ax.imshow(gt)
            ax.set_title('label on image')
            ax = fig.add_subplot(312)
            ax.imshow(pred1)
            ax.set_title('pred0:1 on image')
            ax = fig.add_subplot(313)
            ax.imshow(pred2)
            ax.set_title('pred1:2 on image')
            fig.tight_layout()
            writer.add_figure('/train/results', fig, batch_idx)
            fig.clear()
        # print training loss per batch
        msg = 'epoch: {}, batch: {}, train_loss: {:.4f}, time: {:.4f} s/vol'
        msg = msg.format(epoch_idx, batch_idx, train_loss.item(),
                         sample_duration)
        logger.info(msg)

        # save checkpoint
        if epoch_idx != 0 and (epoch_idx % cfg.train.save_epochs == 0):
            if last_save_epoch != epoch_idx:
                save_checkpoint(net, opt, epoch_idx, batch_idx,
                                cfg, config_file, max_stride,
                                dataset.num_modality())
                last_save_epoch = epoch_idx
            # evaluate the testing set here if needed
        writer.add_scalar('Train/Loss', train_loss.item(), batch_idx)

    writer.close()
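The single loop above covers every epoch because EpochConcateSampler is expected to repeat the dataset indices cfg.train.epochs times; the sketch below illustrates that idea and is an assumption, not the project's actual implementation.

import random
from torch.utils.data.sampler import Sampler

class EpochConcateSamplerSketch(Sampler):
    """Repeat (and shuffle) the dataset indices once per epoch so that a single
    DataLoader pass spans the whole training run."""

    def __init__(self, dataset, epochs):
        self.num_samples = len(dataset)
        self.epochs = epochs

    def __iter__(self):
        indices = []
        for _ in range(self.epochs):
            epoch_indices = list(range(self.num_samples))
            random.shuffle(epoch_indices)
            indices.extend(epoch_indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples * self.epochs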
def test(config_file):
    '''Medical image segmentation testing engine
    :param config_file: the input configuration file
    :return: None
    '''
    assert os.path.isfile(config_file), 'Config not found: {}'.format(config_file)
    total_metric = 0.0
    metric_dict = OrderedDict()
    metric_dict['name'] = []
    metric_dict['dice'] = []
    metric_dict['jaccard'] = []
    cfg = load_config(config_file)
    #print('cfg',cfg)
    
    #log_file = os.path.join(cfg.general.save_dir,'test_log.txt')
    #logger = setup_logger(log_file,'seg3d_test')

    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)
    
    #dataset = SegmentationTestDataset(cfg.test.imseg_list)
    dataset = SegmentationDataset(
            imlist_file=cfg.test.imseg_list,
            num_classes=cfg.dataset.num_classes,
            spacing=cfg.dataset.spacing,
            crop_size=cfg.dataset.crop_size,
            default_values=cfg.dataset.default_values,
            sampling_method=cfg.dataset.sampling_method,
            random_translation=cfg.dataset.random_translation,
            interpolation=cfg.dataset.interpolation,
            crop_normalizers=cfg.dataset.crop_normalizers
            )
    testloader = DataLoader(dataset, batch_size=cfg.test.batch_size, shuffle=False,
                            num_workers=cfg.test.num_threads, pin_memory=True)
    print('number of test batches', len(testloader))
    net_model = importlib.import_module('segmentation3d.network.' + cfg.net.name)
    net = net_model.SegmentationNet(dataset.num_modality(), cfg.dataset.num_classes)
    net = nn.parallel.DataParallel(net, device_ids=list(range(cfg.general.num_gpus)))
    net = net.cuda()
    epoch_idx = cfg.test.test_epoch
    save_dir = cfg.general.save_dir
    state = load_testmodel(epoch_idx,net,save_dir)
    
    net.load_state_dict(state['state_dict'])
    net.eval()
    for ii, (image, label, frame, name) in enumerate(testloader):
        name = name[0][6:-4]  # strip a fixed filename prefix and the file extension
        print('testing patient', name)
        print('image', image.shape)
        #print('label',label.shape)
        # stridex/stridey/stridez and patch_size are assumed to be defined elsewhere
        # (e.g. at module level or from the config); they control the sliding-window inference
        prediction, score_map = test_single_case(net, image[0][0], label[0][0],
                                                 stridex, stridey, stridez, patch_size,
                                                 num_classes=cfg.dataset.num_classes)
        image = image[0][0].numpy()
        label = label[0][0].numpy()
        print('prediction',prediction.shape)
        print('label',label.shape)
        if np.sum(prediction) == 0:
            single_metric = (0,0)
        else:
            single_metric = calculate_metric_percase(prediction,label)
        print('single_metric',single_metric)
        metric_dict['name'].append(name)
        metric_dict['dice'].append(single_metric[0])
        metric_dict['jaccard'].append(single_metric[1])
        total_metric += np.asarray(single_metric)
        if cfg.test.save:
            test_save_path_temp = os.path.join(cfg.general.save_dir + cfg.test.save_filename + '/', name)
            if not os.path.exists(test_save_path_temp):
                os.makedirs(test_save_path_temp)
            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/pred.nii.gz')
            nib.save(nib.Nifti1Image(image.astype(np.float32), np.eye(4)), test_save_path_temp + '/img.nii.gz')
            nib.save(nib.Nifti1Image(label.astype(np.float32), np.eye(4)), test_save_path_temp + '/gt.nii.gz')
    avg_metric = total_metric / len(testloader)
    metric_csv = pd.DataFrame(metric_dict)
    metric_csv.to_csv(cfg.general.save_dir+cfg.test.save_filename+'/metric.csv',index=False)
    print('average metric is {}'.format(avg_metric))
    with open(cfg.general.save_dir + cfg.test.save_filename + '/metric.csv', 'a+') as f:
        f.write('%s,%.4f,%.4f\n' % ('average', avg_metric[0], avg_metric[1]))
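calculate_metric_percase is called above but defined elsewhere; the following is a minimal sketch of the assumed behaviour (binary Dice and Jaccard on numpy masks), not the project's actual implementation.

import numpy as np

def calculate_metric_percase_sketch(pred, gt, eps=1e-8):
    """Assumed behaviour: return (dice, jaccard) for two binary masks."""
    pred = np.asarray(pred).astype(bool)
    gt = np.asarray(gt).astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    dice = 2.0 * intersection / (pred.sum() + gt.sum() + eps)
    jaccard = intersection / (union + eps)
    return dice, jaccard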
Example #5
def train(train_config_file):
    """ Medical image segmentation training engine
    :param train_config_file: the input configuration file
    :return: None
    """
    assert os.path.isfile(train_config_file), 'Config not found: {}'.format(train_config_file)

    # load config file
    train_cfg = load_config(train_config_file)

    # clean the existing folder if training from scratch
    model_folder = os.path.join(train_cfg.general.save_dir, train_cfg.general.model_scale)
    if os.path.isdir(model_folder):
        if train_cfg.general.resume_epoch < 0:
            shutil.rmtree(model_folder)
            os.makedirs(model_folder)
    else:
        os.makedirs(model_folder)

    # copy training and inference config files to the model folder
    shutil.copy(train_config_file, os.path.join(model_folder, 'train_config.py'))
    infer_config_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'config', 'infer_config.py')
    shutil.copy(infer_config_file, os.path.join(train_cfg.general.save_dir, 'infer_config.py'))

    # enable logging
    log_file = os.path.join(model_folder, 'train_log.txt')
    logger = setup_logger(log_file, 'seg3d')

    # control randomness during training
    np.random.seed(train_cfg.general.seed)
    torch.manual_seed(train_cfg.general.seed)
    if train_cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(train_cfg.general.seed)

    # dataset
    dataset = SegmentationDataset(
                imlist_file=train_cfg.general.imseg_list,
                num_classes=train_cfg.dataset.num_classes,
                spacing=train_cfg.dataset.spacing,
                crop_size=train_cfg.dataset.crop_size,
                sampling_method=train_cfg.dataset.sampling_method,
                random_translation=train_cfg.dataset.random_translation,
                random_scale=train_cfg.dataset.random_scale,
                interpolation=train_cfg.dataset.interpolation,
                crop_normalizers=train_cfg.dataset.crop_normalizers)

    sampler = EpochConcateSampler(dataset, train_cfg.train.epochs)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=train_cfg.train.batchsize,
                             num_workers=train_cfg.train.num_threads, pin_memory=True)

    net_module = importlib.import_module('segmentation3d.network.' + train_cfg.net.name)
    net = net_module.SegmentationNet(dataset.num_modality(), train_cfg.dataset.num_classes)
    max_stride = net.max_stride()
    net_module.parameters_kaiming_init(net)
    if train_cfg.general.num_gpus > 0:
        net = nn.parallel.DataParallel(net, device_ids=list(range(train_cfg.general.num_gpus)))
        net = net.cuda()

    assert np.all(np.array(train_cfg.dataset.crop_size) % max_stride == 0), 'crop size not divisible by max stride'

    # training optimizer
    opt = optim.Adam(net.parameters(), lr=train_cfg.train.lr, betas=train_cfg.train.betas)

    # load the checkpoint if resume_epoch >= 0 to continue a previous run
    if train_cfg.general.resume_epoch >= 0:
        last_save_epoch, batch_start = load_checkpoint(train_cfg.general.resume_epoch, net, opt, model_folder)
    else:
        last_save_epoch, batch_start = 0, 0

    if train_cfg.loss.name == 'Focal':
        # reuse focal loss if exists
        loss_func = FocalLoss(class_num=train_cfg.dataset.num_classes, alpha=train_cfg.loss.obj_weight, gamma=train_cfg.loss.focal_gamma,
                              use_gpu=train_cfg.general.num_gpus > 0)
    elif train_cfg.loss.name == 'Dice':
        loss_func = MultiDiceLoss(weights=train_cfg.loss.obj_weight, num_class=train_cfg.dataset.num_classes,
                                  use_gpu=train_cfg.general.num_gpus > 0)
    elif train_cfg.loss.name == 'CE':
        loss_func = CrossEntropyLoss()

    else:
        raise ValueError('Unknown loss function')

    writer = SummaryWriter(os.path.join(model_folder, 'tensorboard'))

    batch_idx = batch_start
    data_iter = iter(data_loader)

    # loop over batches
    for i in range(len(data_loader)):
        begin_t = time.time()

        crops, masks, frames, filenames = next(data_iter)

        if train_cfg.general.num_gpus > 0:
            crops, masks = crops.cuda(), masks.cuda()

        # clear previous gradients
        opt.zero_grad()

        # network forward and backward
        outputs = net(crops)
        train_loss = loss_func(outputs, masks)
        train_loss.backward()

        # update weights
        opt.step()

        # save training crops for visualization
        if train_cfg.debug.save_inputs:
            batch_size = crops.size(0)
            save_intermediate_results(list(range(batch_size)), crops, masks, outputs, frames, filenames,
                                      os.path.join(model_folder, 'batch_{}'.format(i)))

        epoch_idx = batch_idx * train_cfg.train.batchsize // len(dataset)
        batch_idx += 1
        batch_duration = time.time() - begin_t
        sample_duration = batch_duration * 1.0 / train_cfg.train.batchsize

        # print training loss per batch
        msg = 'epoch: {}, batch: {}, train_loss: {:.4f}, time: {:.4f} s/vol'
        msg = msg.format(epoch_idx, batch_idx, train_loss.item(), sample_duration)
        logger.info(msg)

        # save checkpoint
        if epoch_idx != 0 and (epoch_idx % train_cfg.train.save_epochs == 0):
            if last_save_epoch != epoch_idx:
                save_checkpoint(net, opt, epoch_idx, batch_idx, train_cfg, max_stride, dataset.num_modality())
                last_save_epoch = epoch_idx

        writer.add_scalar('Train/Loss', train_loss.item(), batch_idx)

    writer.close()
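A note on the checkpoint contract implied by this loop; it is inferred from load_checkpoint and from load_seg_model in Example #7, not stated verbatim in the source.

# Inferred checkpoint contract (see load_seg_model in Example #7):
# - save_checkpoint() is expected to write a folder under <model_folder>/checkpoints
#   containing params.pth with the DataParallel state_dict plus metadata such as
#   'net', 'in_channels', 'out_channels', 'spacing', 'max_stride' and 'interpolation'.
# - load_checkpoint() returns (last_save_epoch, batch_start) so that training can
#   resume from cfg.general.resume_epoch without repeating finished batches.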
Example #6
def test(config_file):
    '''Medical image segmentation testing engine
    :param config_file: the input configuration file
    :return: None
    '''
    assert os.path.isfile(config_file), 'Config not found: {}'.format(
        config_file)
    total_metric = 0.0
    metric_dict = OrderedDict()
    metric_dict['name'] = []
    metric_dict['dice'] = []
    metric_dict['jaccard'] = []
    cfg = load_config(config_file)
    if cfg.general.num_gpus > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.general.gpu

    np.random.seed(cfg.general.seed)
    torch.manual_seed(cfg.general.seed)
    if cfg.general.num_gpus > 0:
        torch.cuda.manual_seed(cfg.general.seed)

    dataset = ABUS(base_dir=cfg.test.imseg_list,
                   transform=transforms.Compose([ToTensor()]))
    testloader = DataLoader(dataset,
                            batch_size=cfg.test.batch_size,
                            shuffle=False,
                            num_workers=cfg.test.num_threads,
                            pin_memory=True)
    print('dataset length', len(testloader))
    net_model = importlib.import_module('segmentation3d.network.' +
                                        cfg.net.name)
    net = net_model.SegmentationNet(1, cfg.dataset.num_classes)
    net = nn.parallel.DataParallel(net,
                                   device_ids=list(range(
                                       cfg.general.num_gpus)))
    net = net.cuda()
    epoch_idx = cfg.test.test_epoch
    save_dir = cfg.test.model_dir
    state = load_testmodel(epoch_idx, net, save_dir)
    net.load_state_dict(state['state_dict'])
    net.eval()
    for ii, sample in enumerate(testloader):
        name = sample['name']
        print('testing patient', name)
        crops, masks = sample['image'], sample['label']
        crops, masks = crops.cuda(), masks.cuda()
        outputs = net(crops)
        #print('outputs',outputs.shape)

        #outputs_softmax = F.softmax(outputs,dim=1)
        # index 0 selects the batch element, channel 1 is the foreground probability map
        output_numpy = outputs.cpu().data.numpy()[0, 1, :, :]
        #print('output_numpy',output_numpy.shape)
        #print('output_numpy',output_numpy.max())
        pred = output_numpy
        pred[pred > 0.5] = 1
        pred[pred != 1] = 0
        #print('pred',pred.shape)
        #print('pred',pred.max())
        masks = masks.cpu().detach().numpy()[0, :, :, :]
        crops = crops.cpu().detach().numpy()[0, 0, :, :, :]
        #print('crops',crops.shape)
        #print('masks',masks.shape)
        #print('np.sum(pred)',np.sum(pred))
        if np.sum(pred) == 0:
            single_metric = (0, 0)
        else:
            single_metric = calculate_metric_percase(pred, masks)
        print('single_metric', single_metric)
        metric_dict['name'].append(name)
        metric_dict['dice'].append(single_metric[0])
        metric_dict['jaccard'].append(single_metric[1])
        total_metric += np.asarray(single_metric)
        if cfg.test.save:
            test_save_path_temp = os.path.join(
                cfg.test.model_dir + cfg.test.save_filename + '/', name[0])
            if not os.path.exists(test_save_path_temp):
                os.makedirs(test_save_path_temp)
            print('test_save_path_temp', test_save_path_temp)
            nib.save(nib.Nifti1Image(pred.astype(np.float32), np.eye(4)),
                     test_save_path_temp + '/' + 'pred.nii.gz')
            nib.save(nib.Nifti1Image(crops.astype(np.float32), np.eye(4)),
                     test_save_path_temp + '/' + 'img.nii.gz')
            nib.save(nib.Nifti1Image(masks.astype(np.float32), np.eye(4)),
                     test_save_path_temp + '/' + 'gt.nii.gz')
    avg_metric = total_metric / len(testloader)
    metric_csv = pd.DataFrame(metric_dict)
    metric_csv.to_csv(cfg.test.model_dir + cfg.test.save_filename +
                      '/metric.csv',
                      index=False)
    print('average metric is {}'.format(avg_metric))
    with open(cfg.test.model_dir + cfg.test.save_filename + '/metric.csv', 'a+') as f:
        f.write('%s,%.4f,%.4f\n' % ('average', avg_metric[0], avg_metric[1]))
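The commented-out softmax inside the loop above suggests the raw network output may be unnormalized; the sketch below shows the thresholding step with an explicit softmax. It is an assumption about the network output, not confirmed by the source.

import numpy as np
import torch.nn.functional as F

def binarize_foreground(outputs, threshold=0.5):
    """Sketch: softmax the raw network output and threshold the foreground channel."""
    probs = F.softmax(outputs, dim=1)  # normalize per-voxel class scores
    return (probs[0, 1].detach().cpu().numpy() > threshold).astype(np.float32)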
Example #7
def load_seg_model(model_folder, gpu_id=0):
    """ load segmentation model from folder
    :param model_folder:    the folder containing the segmentation model
    :param gpu_id:          the gpu device id to run the segmentation model
    :return: a dictionary containing the model and inference parameters
    """
    assert os.path.isdir(
        model_folder), 'Model folder does not exist: {}'.format(model_folder)

    model = edict()

    # load inference config file
    latest_checkpoint_dir = get_checkpoint_folder(
        os.path.join(model_folder, 'checkpoints'), -1)
    infer_cfg = load_config(
        os.path.join(latest_checkpoint_dir, 'infer_config.py'))
    model.infer_cfg = infer_cfg

    # load model state
    chk_file = os.path.join(latest_checkpoint_dir, 'params.pth')

    if gpu_id >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(int(gpu_id))

        # load network module
        state = torch.load(chk_file)
        net_module = importlib.import_module('segmentation3d.network.' +
                                             state['net'])
        net = net_module.SegmentationNet(state['in_channels'],
                                         state['out_channels'],
                                         state['dropout'])
        net = nn.parallel.DataParallel(net)
        net.load_state_dict(state['state_dict'])
        net.eval()
        net = net.cuda()

        del os.environ['CUDA_VISIBLE_DEVICES']

    else:
        state = torch.load(chk_file, map_location='cpu')
        net_module = importlib.import_module('segmentation3d.network.' +
                                             state['net'])
        net = net_module.SegmentationNet(state['in_channels'],
                                         state['out_channels'],
                                         state['dropout'])
        net = nn.parallel.DataParallel(net)
        net.load_state_dict(state['state_dict'])
        net.eval()

    model.net = net
    model.spacing, model.max_stride, model.interpolation = state[
        'spacing'], state['max_stride'], state['interpolation']
    model.in_channels, model.out_channels = state['in_channels'], state[
        'out_channels']

    model.crop_normalizers = []
    for crop_normalizer in state['crop_normalizers']:
        if crop_normalizer['type'] == 0:
            mean, stddev, clip = crop_normalizer['mean'], crop_normalizer[
                'stddev'], crop_normalizer['clip']
            model.crop_normalizers.append(FixedNormalizer(mean, stddev, clip))

        elif crop_normalizer['type'] == 1:
            min_p, max_p, clip = crop_normalizer['min_p'], crop_normalizer[
                'max_p'], crop_normalizer['clip']
            model.crop_normalizers.append(
                AdaptiveNormalizer(min_p, max_p, clip))

        else:
            raise ValueError('Unsupported normalization type.')

    return model
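A usage sketch for this loader; the folder path is a hypothetical placeholder, and gpu_id selects GPU versus CPU loading exactly as handled above.

# usage sketch (hypothetical folder path)
seg_model = load_seg_model('/path/to/model_folder', gpu_id=0)     # load onto GPU 0
# seg_model = load_seg_model('/path/to/model_folder', gpu_id=-1)  # CPU-only loading
print(seg_model.in_channels, seg_model.out_channels, seg_model.max_stride)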