def exec_model(model):
    """Test the model on the ImageVU B dataset instead of DeepLesion.

    Iterates over every nifti volume in the (hard-coded) ImageVU B cropped
    image directory and, for each volume, runs the model slice by slice.
    Per slice it writes four CSV logs (resized shape, contour locations,
    RECIST locations, mask) plus an overlay PNG; per volume it writes a
    ``results.txt`` summary. All output goes under
    ``cfg.RESULTS_DIR/ImageVU_B_result/<flattened image path>/``.

    Args:
        model: a torch model already loaded with weights; set to eval mode
            here and run under ``torch.no_grad()``.
    """
    import_tag_data()
    model.eval()
    device = torch.device(cfg.MODEL.DEVICE)

    # ImageVU B cropped CT volumes (hard-coded dataset location).
    data_dir = '/nfs/masi/tangy5/ImageVU_B_bpr_pipeline/INPUTS/cropped/images'
    count = 0
    for item in os.listdir(data_dir):
        img_dir = os.path.join(data_dir, item)

        print('reading image ...')
        nifti_data = nib.load(img_dir)
        count += 1
        print('Number of Datasets: %d' % count)
        print('Load Image: %s' % img_dir)

        # Display window fixed to soft tissue. Other presets kept for
        # reference: lung [-1500, 500], bone [-500, 1300].
        win_show = [-175, 275]

        vol, spacing, slice_intv = load_preprocess_nifti(nifti_data)

        # Number of consecutive slices covered by one inference run, derived
        # from the configured test interval in mm (rounded, at least 1).
        slice_num_per_run = max(1, int(float(cfg.TEST.TEST_SLICE_INTV_MM) / slice_intv + .5))
        num_total_slice = vol.shape[2]

        total_time = 0
        imageVU_dir = 'ImageVU_B_result'
        output_dir = os.path.join(cfg.RESULTS_DIR, imageVU_dir, img_dir.replace(os.sep, '_'))
        # makedirs (not mkdir): the intermediate ImageVU_B_result directory
        # may not exist yet on the first run.
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        slices_to_process = range(int(slice_num_per_run / 2), num_total_slice, slice_num_per_run)
        msgs_all = []
        print('predicting ...')
        for slice_idx in tqdm(slices_to_process):
            log_file_s = os.path.join(output_dir, 'slice_' + str(slice_idx) + '_resize_shape.csv')
            log_file_c = os.path.join(output_dir, 'slice_' + str(slice_idx) + '_contour_location.csv')
            log_file_r = os.path.join(output_dir, 'slice_' + str(slice_idx) + '_recist_location.csv')
            log_file_mask = os.path.join(output_dir, 'slice_' + str(slice_idx) + '_mask_c.csv')
            mask_list = []
            ims, im_np, im_scale, crop, mask_list = get_ims(slice_idx, vol, spacing, slice_intv, mask_list)
            im_list = to_image_list(ims, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)
            start_time = time()
            with torch.no_grad():
                result = model(im_list)
            result = [o.to("cpu") for o in result]

            df_resize = pd.DataFrame()
            df_contours = pd.DataFrame()
            df_recists = pd.DataFrame()
            df_mask = pd.DataFrame()
            shape_0, shape_1 = [], []
            cour_list1, cour_list2 = [], []
            recist_list1, recist_list2 = [], []
            info = {'spacing': spacing, 'im_scale': im_scale}
            post_process_results(result[0], info)
            total_time += time() - start_time
            output_fn = os.path.join(output_dir, '%d.png' % (slice_idx + 1))
            shape_0.append(im_np.shape[0])
            shape_1.append(im_np.shape[1])
            # gen_output fills the contour/RECIST lists in place while
            # rendering the overlay image.
            overlay, msgs = gen_output(im_np, result[0], info, win_show,
                                       cour_list1, cour_list2,
                                       recist_list1, recist_list2)
            df_resize['Shape_0'] = shape_0
            df_resize['Shape_1'] = shape_1
            df_contours['list1'] = cour_list1
            df_contours['list2'] = cour_list2
            df_recists['list1'] = recist_list1
            # FIX: recist_list2 was collected but never persisted, losing the
            # second half of the RECIST coordinates (contours save both lists).
            df_recists['list2'] = recist_list2
            df_mask['c'] = mask_list
            df_resize.to_csv(log_file_s, index=False)
            df_contours.to_csv(log_file_c, index=False)
            df_recists.to_csv(log_file_r, index=False)
            df_mask.to_csv(log_file_mask, index=False)
            cv2.imwrite(output_fn, overlay)
            msgs_all.append('slice %d\r\n' % (slice_idx + 1))
            for msg in msgs:
                msgs_all.append(msg + '\r\n')
            msgs_all.append('\r\n')

        with open(os.path.join(output_dir, 'results.txt'), 'w') as f:
            f.writelines(msgs_all)
        print('result images and text saved to', output_dir)
        # max(1, ...) guards against ZeroDivisionError for empty volumes.
        print('processing time: %d ms per slice'
              % int(1000. * total_time / max(1, len(slices_to_process))))
# ---- Example #2 ----
def exec_model(model):
    """Test the model on user-provided data, instead of the preset
    DeepLesion dataset.

    Interactively prompts for a nifti CT volume path and a display window
    (soft tissue / lung / bone), runs the model slice by slice, and writes
    per-slice overlay PNGs plus a ``results.txt`` summary under a directory
    derived from the input path inside ``cfg.RESULTS_DIR``. Loops forever
    so multiple volumes can be processed in one session.

    Args:
        model: a torch model already loaded with weights; set to eval mode
            here and run under ``torch.no_grad()``.
    """
    import_tag_data()
    model.eval()
    device = torch.device(cfg.MODEL.DEVICE)

    while True:
        info = "Please input the path of a nifti CT volume >> "
        while True:
            path = input(info)
            if not os.path.exists(path):
                print('file does not exist!')
                continue
            try:
                print('reading image ...')
                nifti_data = nib.load(path)
                break
            # FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and made Ctrl-C unusable here.
            except Exception:
                print('load nifti file error!')

        while True:
            win_sel = input(
                'Window to show, 1:soft tissue, 2:lung, 3: bone >> ')
            if win_sel not in ['1', '2', '3']:
                continue
            win_show = [[-175, 275], [-1500, 500], [-500, 1300]]
            win_show = win_show[int(win_sel) - 1]
            break

        vol, spacing, slice_intv = load_preprocess_nifti(nifti_data)

        # Number of consecutive slices covered by one inference run, derived
        # from the configured test interval in mm (rounded, at least 1).
        slice_num_per_run = max(
            1, int(float(cfg.TEST.TEST_SLICE_INTV_MM) / slice_intv + .5))
        num_total_slice = vol.shape[2]

        total_time = 0
        # Flatten the input path into a single directory name.
        output_dir = os.path.join(cfg.RESULTS_DIR, path.replace(os.sep, '_'))
        # FIX: makedirs (not mkdir) so a missing RESULTS_DIR is created too.
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        slices_to_process = range(int(slice_num_per_run / 2), num_total_slice,
                                  slice_num_per_run)
        msgs_all = []
        print('predicting ...')
        for slice_idx in tqdm(slices_to_process):
            ims, im_np, im_scale, crop = get_ims(slice_idx, vol, spacing,
                                                 slice_intv)
            im_list = to_image_list(
                ims, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)
            start_time = time()
            with torch.no_grad():
                result = model(im_list)
            result = [o.to("cpu") for o in result]

            info = {'spacing': spacing, 'im_scale': im_scale}
            post_process_results(result[0], info)
            total_time += time() - start_time
            output_fn = os.path.join(output_dir, '%d.png' % (slice_idx + 1))
            overlay, msgs = gen_output(im_np, result[0], info, win_show)

            cv2.imwrite(output_fn, overlay)
            msgs_all.append('slice %d\r\n' % (slice_idx + 1))
            for msg in msgs:
                msgs_all.append(msg + '\r\n')
            msgs_all.append('\r\n')

        with open(os.path.join(output_dir, 'results.txt'), 'w') as f:
            f.writelines(msgs_all)

        print('result images and text saved to', output_dir)
        # max(1, ...) guards against ZeroDivisionError for empty volumes.
        print('processing time: %d ms per slice' %
              int(1000. * total_time / max(1, len(slices_to_process))))
# ---- Example #3 ----
def batch_exec_model(model):
    """Test the model on a user-provided folder of data, instead of the
    preset DeepLesion dataset.

    Prompts for a folder, recursively collects every ``.nii``/``.nii.gz``
    file under it, runs the model on each volume slice by slice, and saves
    all per-slice results in one ``<folder basename>.pth`` file under
    ``cfg.RESULTS_DIR`` via ``torch.save``.

    Args:
        model: a torch model already loaded with weights; set to eval mode
            here and run under ``torch.no_grad()``.
    """
    import_tag_data()
    model.eval()
    device = torch.device(cfg.MODEL.DEVICE)

    info = "Please input the path which contains all nifti CT volumes to predict in batch >> "
    while True:
        path = input(info)
        if not os.path.exists(path):
            print('folder does not exist!')
            continue
        else:
            break

    nifti_paths = []
    for dirName, subdirList, fileList in os.walk(path):
        print('found directory: %s' % dirName)
        for fname in fileList:
            if fname.endswith(('.nii.gz', '.nii')):
                # FIX: dirName from os.walk(path) already starts with `path`;
                # joining `path` again doubled the prefix for relative inputs
                # (e.g. data/data/sub/img.nii.gz).
                nifti_paths.append(os.path.join(dirName, fname))
    print('%d nifti files found' % len(nifti_paths))

    total_time = 0
    results = {}
    total_slices = 0
    for file_idx, nifti_path in enumerate(nifti_paths):
        print('(%d/%d) %s' % (file_idx + 1, len(nifti_paths), nifti_path))
        print('reading image ...')
        try:
            nifti_data = nib.load(nifti_path)
        # FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        except Exception:
            print('load nifti file error!')
            continue

        vol, spacing, slice_intv = load_preprocess_nifti(nifti_data)
        # Number of consecutive slices covered by one inference run, derived
        # from the configured test interval in mm (rounded, at least 1).
        slice_num_per_run = max(
            1, int(float(cfg.TEST.TEST_SLICE_INTV_MM) / slice_intv + .5))
        num_total_slice = vol.shape[2]
        results[nifti_path] = {}

        slices_to_process = range(int(slice_num_per_run / 2), num_total_slice,
                                  slice_num_per_run)
        total_slices += len(slices_to_process)
        for slice_idx in tqdm(slices_to_process):
            ims, im_np, im_scale, crop = get_ims(slice_idx, vol, spacing,
                                                 slice_intv)
            im_list = to_image_list(
                ims, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)
            start_time = time()
            with torch.no_grad():
                result = model(im_list)
            result = [o.to("cpu") for o in result]

            info = {'spacing': spacing, 'im_scale': im_scale, 'crop': crop}
            post_process_results(result[0], info)
            result = sort_results_for_batch(result[0], info)
            results[nifti_path][slice_idx] = result
            total_time += time() - start_time

    output_dir = os.path.join(cfg.RESULTS_DIR)
    # FIX: makedirs (not mkdir) so nested RESULTS_DIR paths are created.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_fn = os.path.join(output_dir, '%s.pth' % os.path.basename(path))
    torch.save(results, output_fn)
    print('result images and text saved to', output_fn)
    # max(1, ...) guards against ZeroDivisionError when no slices were run.
    print('processing time: %d ms per slice' %
          int(1000. * total_time / max(1, total_slices)))