Beispiel #1
0
def main():
    """Segment every test volume with the 3-D U-Net and write TIFF results.

    Loads ``model_last.tar`` from the module-level ``ckpts`` directory, runs the
    network over each volume of ``cell_testing_inter`` in z-slabs of 16
    slices, and writes two images per volume:

    * ``/home/tom/result/<name>``: the input volume (first ``5 * depth``
      slices) scaled by 255 and cast to uint8;
    * ``/home/tom/membrane/<name>mem.tif``: the input intensities masked by
      the predicted foreground class, cropped to ``depth x 512 x 512``.

    Reads the module-level ``args`` (``gpu``, ``batch_size``, ``workers``) and,
    as before, scales ``args.batch_size`` / ``args.workers`` in place by the
    number of visible GPUs.
    """
    start_time = time.time()
    slab = 16  # number of z-slices fed to the network per forward pass

    # setup environments
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # setup network from the last checkpoint
    model = Modified3DUNet(in_channels=1, n_classes=2, base_n_filter=16)
    model_file = os.path.join(ckpts, 'model_last.tar')
    print(model_file)
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()
    model.eval()

    num_gpus = len(args.gpu.split(','))
    args.batch_size *= num_gpus
    args.workers *= num_gpus

    def predict_slab(chunk):
        """Return the per-voxel argmax class for one (slab, x, y) chunk."""
        out = nn.parallel.data_parallel(model, chunk.float()[None, None, ...])
        return out.detach().cpu().numpy()[0].argmax(0)

    def predict_mask(volume):
        """Return a (z, x, y) array of predicted classes for one volume."""
        z = volume.shape[0]
        mask = np.zeros(tuple(volume.shape))
        n_full = z // slab
        for j in range(n_full):
            lo = j * slab
            mask[lo:lo + slab] = predict_slab(volume[lo:lo + slab])
        covered = n_full * slab
        if covered < z and z >= slab:
            # Cover the leftover slices with one final slab aligned to the
            # end of the volume; keep only the part not already predicted.
            tail = predict_slab(volume[z - slab:z])
            mask[covered:] = tail[covered - (z - slab):]
        return mask

    # create dataloader (order is irrelevant for inference)
    dset = cell_testing_inter('/home/tom/data1_match/dataset4/')
    print(len(dset))
    test_loader = DataLoader(dset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    with torch.no_grad():
        for sample in test_loader:
            volumes = sample['data']
            img_size = sample['image_size']
            names = sample['name']
            # Process every sample in the batch, not just the first one.
            for b in range(volumes.shape[0]):
                # NOTE(review): assumes 'image_size' collates to a list of
                # per-axis tensors of shape (batch,), as the original
                # img_size[0] slicing implies -- confirm.
                depth = int(img_size[0][b])
                print(depth)
                file_name = str(names[b])
                volume = volumes[b, 0]

                mask = predict_mask(volume)

                data = volume.detach().numpy()
                data = (255 * data[0:5 * depth]).astype('uint8')
                sitk.WriteImage(sitk.GetImageFromArray(data),
                                '/home/tom/result/' + file_name)

                # Keep input intensities where the predicted class is 1.
                masked = np.multiply(data, mask[0:5 * depth].astype('float32'))
                # NOTE(review): this keeps the first `depth` slices of the
                # 5*depth-slice volume instead of resampling back to `depth`
                # slices -- confirm this crop is intended.
                result = masked[0:depth, 0:512, 0:512]
                # `result` already holds 0-255 values; the former `* 255`
                # overflowed uint8.
                result = result.astype('uint8')
                sitk.WriteImage(sitk.GetImageFromArray(result),
                                '/home/tom/membrane/' + file_name + 'mem.tif')
                print("running time %s" % (time.time() - start_time))
Beispiel #2
0
def main(model_path, cell_hist_datadir, prob_map_datadir):
    """Write a masked-intensity TIFF for every volume in ``cell_hist_datadir``.

    Reads the run configuration from ``<model_path>/cfg.txt`` (keys ``gpu``,
    ``batch_size``, ``workers``) and loads ``<model_path>/model_last.tar``.
    The 3-D U-Net runs over each volume in z-slabs of 16 slices.
    ``<prob_map_datadir>/<name>-prob.tif`` is then written for each volume.
    It holds the input intensities masked by the predicted foreground class,
    resized along z to one fifth of ``5 * depth`` slices.

    Args:
        model_path: directory containing ``cfg.txt`` and ``model_last.tar``.
        cell_hist_datadir: dataset directory passed to ``cell_testing_inter``.
        prob_map_datadir: output directory for the ``-prob.tif`` files.
    """
    slab = 16  # number of z-slices fed to the network per forward pass
    props = readprops(os.path.join(model_path, 'cfg.txt'))
    # setup environments
    os.environ['CUDA_VISIBLE_DEVICES'] = props['gpu']

    # setup network from the last checkpoint
    model = Modified3DUNet(in_channels=1, n_classes=2, base_n_filter=16)
    model_file = os.path.join(model_path, 'model_last.tar')
    checkpoint = torch.load(model_file,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()
    model.eval()

    num_gpus = len(props['gpu'].split(','))
    batch_size = int(props['batch_size']) * num_gpus
    workers = int(props['workers']) * num_gpus

    def predict_slab(chunk):
        """Return the per-voxel argmax class for one (slab, x, y) chunk."""
        out = nn.parallel.data_parallel(model, chunk.float()[None, None, ...])
        return out.detach().cpu().numpy()[0].argmax(0)

    def predict_mask(volume):
        """Return a (z, x, y) array of predicted classes for one volume."""
        z = volume.shape[0]
        mask = np.zeros(tuple(volume.shape))
        n_full = z // slab
        for j in range(n_full):
            lo = j * slab
            mask[lo:lo + slab] = predict_slab(volume[lo:lo + slab])
        covered = n_full * slab
        if covered < z and z >= slab:
            # Cover the leftover slices with one final slab aligned to the
            # end of the volume; keep only the part not already predicted.
            tail = predict_slab(volume[z - slab:z])
            mask[covered:] = tail[covered - (z - slab):]
        return mask

    # create dataloader (order is irrelevant for inference)
    dset = cell_testing_inter(cell_hist_datadir)
    print(len(dset))
    test_loader = DataLoader(dset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=workers,
                             pin_memory=True)

    with torch.no_grad():
        for sample in test_loader:
            volumes = sample['data']
            img_size = sample['image_size']
            names = sample['name']
            # Process every sample in the batch, not just the first one.
            for b in range(volumes.shape[0]):
                # NOTE(review): assumes 'image_size' collates to a list of
                # per-axis tensors of shape (batch,), as the original
                # img_size[0] slicing implies -- confirm.
                depth = int(img_size[0][b])
                print(depth)
                print(names[b])
                file_name = os.path.splitext(names[b])[0]
                volume = volumes[b, 0]

                mask = predict_mask(volume)

                data = volume.detach().numpy()
                data = (255 * data[0:5 * depth]).astype('uint8')
                # Keep input intensities where the predicted class is 1.
                prob_map = np.multiply(data,
                                       mask[0:5 * depth].astype('float32'))
                # Integer division keeps the target shape integral on
                # Python 3 (matches the Python 2 `/` behaviour).
                prob_map = resize(
                    prob_map,
                    (prob_map.shape[0] // 5, prob_map.shape[1],
                     prob_map.shape[2]))
                prob_map_img = sitk.GetImageFromArray(prob_map.astype('uint8'))
                sitk.WriteImage(
                    prob_map_img,
                    os.path.join(prob_map_datadir, file_name + '-prob.tif'))