Пример #1
0
def run(img_folder,
        img_size=[288, 224],
        do_featurewise_norm=True,
        featurewise_mean=485.9,
        featurewise_std=765.2,
        batch_size=16,
        img_tsv='./metadata/images_crosswalk_prediction.tsv',
        exam_tsv=None,
        dl_state=None,
        enet_state=None,
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC1 inference

    Scores the left and right breast of every exam from its CC and MLO image
    lists and writes one tab-separated row per breast to `out_pred`.

    Args:
        img_folder: image folder handed to DMMetaManager.
        img_size: two-element sequence; items 0 and 1 are passed to the
                generator as target_size. The default list is only read,
                never mutated.
        do_featurewise_norm (bool): if True, the generator centers and
                scales with `featurewise_mean`/`featurewise_std`; otherwise
                it uses samplewise centering and std normalization.
        featurewise_mean, featurewise_std ([float]): they are estimated from
                1152 x 896 images. Using different sized images give very
                close results. For png, mean=7772, std=12187.
        batch_size: batch size passed to flow_from_exam_list.
        img_tsv, exam_tsv: metadata files handed to DMMetaManager.
        dl_state: model file passed to load_model; used only when
                `enet_state` is None.
        enet_state: sequence unpacked as the positional arguments of
                MultiViewDLElasticNet; takes precedence over `dl_state`.
        validation_mode (bool): if True, every exam of the flattened exam
                list is scored and each row also carries the exam and the
                label from the generator; otherwise only the last exam list
                is scored and rows hold subject, laterality and score.
        use_mean: forwarded to dminfer.pred_2view_img_list.
        out_pred: path of the prediction file (overwritten).

    Raises:
        Exception: if both `enet_state` and `dl_state` are None. This is
                checked before any metadata is read, so a misconfigured call
                fails immediately and never touches `out_pred`.
    '''
    # Fail fast: without a model there is nothing to do, so do not pay for
    # reading the metadata first.
    if enet_state is None and dl_state is None:
        raise Exception('At least one model state must be specified.')

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list()
        class_mode = 'binary'
    else:
        exam_list = meta_man.get_last_exam_list()
        class_mode = None
    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)
    datgen_exam = img_gen.flow_from_exam_list(exam_list,
                                              target_size=(img_size[0],
                                                           img_size[1]),
                                              class_mode=class_mode,
                                              prediction_mode=True,
                                              batch_size=batch_size)

    # The elastic net state wins when both states are given.
    if enet_state is not None:
        model = MultiViewDLElasticNet(*enet_state)
    else:
        model = load_model(dl_state)

    # The context manager closes the file even if a prediction raises.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        exams_seen = 0
        while exams_seen < len(exam_list):
            ebat = next(datgen_exam)
            if class_mode is not None:
                # With a class mode the batch is indexed as (inputs, labels).
                bat_x = ebat[0]
                bat_y = ebat[1]
            else:
                bat_x = ebat
            subj_batch = bat_x[0]
            exam_batch = bat_x[1]
            cc_batch = bat_x[2]
            mlo_batch = bat_x[3]
            for i, subj in enumerate(subj_batch):
                exam = exam_batch[i]
                # Per-breast entries are interleaved: exam i has its left
                # breast at 2*i and its right breast at 2*i + 1.
                li = i * 2
                ri = i * 2 + 1
                left_pred = dminfer.pred_2view_img_list(cc_batch[li],
                                                        mlo_batch[li], model,
                                                        use_mean)
                right_pred = dminfer.pred_2view_img_list(cc_batch[ri],
                                                         mlo_batch[ri], model,
                                                         use_mean)
                if validation_mode:
                    fout.write("%s\t%s\tL\t%f\t%f\n" %
                               (str(subj), str(exam), left_pred, bat_y[li]))
                    fout.write("%s\t%s\tR\t%f\t%f\n" %
                               (str(subj), str(exam), right_pred, bat_y[ri]))
                else:
                    fout.write("%s\tL\t%f\n" % (str(subj), left_pred))
                    fout.write("%s\tR\t%f\n" % (str(subj), right_pred))

            exams_seen += len(subj_batch)
def run(img_folder,
        img_height=1024,
        img_scale=4095,
        roi_per_img=32,
        roi_size=(256, 256),
        low_int_threshold=.05,
        blob_min_area=3,
        blob_min_int=.5,
        blob_max_int=.85,
        blob_th_step=10,
        roi_state=None,
        roi_bs=32,
        do_featurewise_norm=True,
        featurewise_mean=884.7,
        featurewise_std=745.3,
        img_tsv='./metadata/images_crosswalk_prediction.tsv',
        exam_tsv=None,
        dl_state=None,
        dl_bs=32,
        nb_top_avg=1,
        validation_mode=False,
        val_size=None,
        img_voting=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC1 inference using the candidate ROI approach

    For every exam, candidate ROIs are drawn from the images of each breast,
    scored by the model, and reduced to one score per breast that is written
    as a tab-separated row to `out_pred`.

    Args:
        img_folder: image folder handed to DMMetaManager.
        img_height, img_scale: passed to flow_from_candid_roi as
                target_height and target_scale.
        roi_per_img, roi_size, low_int_threshold, blob_min_area,
        blob_min_int, blob_max_int, blob_th_step: forwarded unchanged to
                flow_from_candid_roi.
        roi_state: optional ROI classifier file for load_model. When given,
                the classifier is passed to the ROI generator and the breast
                score is the mean prediction of the `nb_top_avg` ROIs with
                the largest sample weights.
        roi_bs: passed to flow_from_candid_roi as clf_bs.
        do_featurewise_norm (bool): if True, the generator centers and
                scales with `featurewise_mean`/`featurewise_std`; otherwise
                it uses samplewise centering and std normalization.
        featurewise_mean, featurewise_std: see Notes.
        img_tsv, exam_tsv: metadata files handed to DMMetaManager.
        dl_state: model file passed to load_model (required).
        dl_bs: batch size for model.predict.
        nb_top_avg: number of top entries averaged into a score.
        validation_mode (bool): if True, every exam of the flattened exam
                list is scored and each row also carries the exam index and
                the breast label; otherwise only the last exam list is used.
        val_size: if not None, subjects are split with train_test_split
                (stratified by subject label, seeded by RANDOM_SEED) and
                only the test part is scored.
        img_voting (bool): only consulted when `roi_state` is None. If
                True, the top `nb_top_avg` predictions are averaged within
                each row of `roi_per_img` ROIs and those averages are
                averaged; otherwise the top `nb_top_avg` predictions over
                all ROIs are averaged.
        out_pred: path of the prediction file (overwritten).

    Raises:
        Exception: if `dl_state` is None. This is checked before any
                metadata or classifier is loaded.

    Notes:
        "mean=884.7, std=745.3" are estimated from 20 subjects on the
        training data.
    '''
    # Fail fast: the model is mandatory, so reject the call before reading
    # metadata or loading the ROI classifier.
    if dl_state is None:
        raise Exception('At least one model state must be specified.')

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    if val_size is not None:  # Use a subset for validation.
        subj_list, subj_labs = meta_man.get_subj_labs()
        _, subj_test = train_test_split(subj_list,
                                        test_size=val_size,
                                        random_state=random_seed,
                                        stratify=subj_labs)
    else:
        subj_test = None

    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list(subj_list=subj_test,
                                                   flatten_img_list=True)
        class_mode = 'categorical'
    else:
        exam_list = meta_man.get_last_exam_list(subj_list=subj_test,
                                                flatten_img_list=True)
        class_mode = None

    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)

    # Load ROI classifier.
    if roi_state is not None:
        roi_clf = load_model(roi_state,
                             custom_objects={
                                 'sensitivity': DMMetrics.sensitivity,
                                 'specificity': DMMetrics.specificity
                             })
        if gpu_count > 1:
            roi_clf = make_parallel(roi_clf, gpu_count)
    else:
        roi_clf = None

    # Load model.
    model = load_model(dl_state)
    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    # A function to make predictions on image patches from an image list.
    def pred_img_list(img_list):
        roi_generator = img_gen.flow_from_candid_roi(
            img_list,
            target_height=img_height,
            target_scale=img_scale,
            class_mode=class_mode,
            validation_mode=True,
            img_per_batch=len(img_list),
            roi_per_img=roi_per_img,
            roi_size=roi_size,
            low_int_threshold=low_int_threshold,
            blob_min_area=blob_min_area,
            blob_min_int=blob_min_int,
            blob_max_int=blob_max_int,
            blob_th_step=blob_th_step,
            roi_clf=roi_clf,
            clf_bs=roi_bs,
            return_sample_weight=True,
            seed=random_seed)
        # The builtin works on Python 2 and 3 and matches how the other
        # inference entry point pulls batches from its generator.
        roi_dat, roi_w = next(roi_generator)
        pred = model.predict(roi_dat, batch_size=dl_bs)
        pred = pred[:, 1]  # cancer class predictions.
        if roi_clf is not None:
            # Average the ROIs that received the largest sample weights.
            return pred[np.argsort(roi_w)[-nb_top_avg:]].mean()
        elif img_voting:
            pred = pred.reshape((-1, roi_per_img))
            img_preds = [np.sort(row)[-nb_top_avg:].mean() for row in pred]
            return np.mean(img_preds)
        else:
            return np.sort(pred)[-nb_top_avg:].mean()

    # Score one breast; a KeyError (e.g. the exam has no entry for that
    # side) yields a score of 0 instead of aborting the run.
    def pred_breast(exam, side):
        try:
            return pred_img_list(exam[side]['img'])
        except KeyError:
            return 0.

    # Read one breast label; a missing side or a value int() rejects counts
    # as negative, mirroring the fallback used for the prediction.
    def breast_label(exam, side):
        try:
            return int(exam[side]['cancer'])
        except (KeyError, ValueError):
            return 0

    # The context manager closes the file even if a prediction raises.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        for subj, exidx, exam in exam_list:
            predL = pred_breast(exam, 'L')
            predR = pred_breast(exam, 'R')

            if validation_mode:
                # Labels are only needed, and only read, in validation mode.
                fout.write("%s\t%s\tL\t%f\t%d\n" %
                           (str(subj), str(exidx), predL,
                            breast_label(exam, 'L')))
                fout.write("%s\t%s\tR\t%f\t%d\n" %
                           (str(subj), str(exidx), predR,
                            breast_label(exam, 'R')))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj), predL))
                fout.write("%s\tR\t%f\n" % (str(subj), predR))
Пример #3
0
def run(img_folder,
        dl_state,
        clf_info_state,
        img_extension='dcm',
        img_height=4096,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        net='resnet50',
        batch_size=64,
        patch_size=256,
        stride=64,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv',
        progress='./progress.txt'):
    '''Run SC1 inference using prob heatmaps

    For every exam, probability heatmaps are computed for the CC and MLO
    views of each breast and turned into one score per breast, which is
    written as a tab-separated row to `out_pred`.

    Args:
        img_folder, img_extension: image location and extension handed to
                DMMetaManager.
        dl_state: patch classifier file passed to load_model.
        clf_info_state: pickle file holding the 5-tuple
                (feature_name, nb_phm, cutoff_list, k, clf_list) that is
                forwarded to dminfer.make_pred_case.
        img_height, img_scale, patch_size, stride, batch_size,
        equalize_hist, featurewise_mean: forwarded to get_prob_heatmap.
        featurewise_center (bool): forwarded to get_prob_heatmap. When
                True no network-specific preprocess function is loaded.
        net: selects the keras.applications preprocess function; one of
                'resnet50', 'vgg16', 'vgg19', 'xception', 'inception'.
                Ignored when `featurewise_center` is True.
        exam_tsv, img_tsv: metadata files handed to DMMetaManager.
        validation_mode (bool): if True, every CC/MLO exam of the flattened
                exam list is scored and each row also carries the exam index
                and the breast label; otherwise only the last exam list is
                used.
        use_mean: forwarded to dminfer.make_pred_case.
        out_pred: path of the prediction file (overwritten, flushed after
                every exam).
        progress: path of a file rewritten after every exam with the
                fraction of exams processed so far.

    Raises:
        Exception: if `featurewise_center` is False and `net` is not one of
                the supported networks. This is checked before any metadata
                or model is loaded.
    '''
    # Load preprocess function. Done first so that an unsupported `net` is
    # rejected before the metadata and the model are loaded.
    if featurewise_center:
        preprocess_input = None
    else:
        print("Load preprocess function for net: %s" % (net,))
        if net == 'resnet50':
            from keras.applications.resnet50 import preprocess_input
        elif net == 'vgg16':
            from keras.applications.vgg16 import preprocess_input
        elif net == 'vgg19':
            from keras.applications.vgg19 import preprocess_input
        elif net == 'xception':
            from keras.applications.xception import preprocess_input
        elif net == 'inception':
            from keras.applications.inception_v3 import preprocess_input
        else:
            raise Exception("Pretrained model is not available: " + net)

    # Read some env variables.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list(cc_mlo_only=True)
        exam_labs = np.array(meta_man.exam_labs(exam_list))
        print("positive exams=%d, negative exams=%d"
              % ((exam_labs == 1).sum(), (exam_labs == 0).sum()))
        sys.stdout.flush()
    else:
        exam_list = meta_man.get_last_exam_list(cc_mlo_only=True)

    # Load DL model and classifiers.
    print("Load patch classifier: %s" % (dl_state,))
    sys.stdout.flush()
    dl_model = load_model(dl_state)
    if gpu_count > 1:
        print("Make the model parallel on %d GPUs" % (gpu_count))
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    # NOTE(review): unpickling executes arbitrary code; only point
    # `clf_info_state` at trusted files. Pickles are binary data, hence 'rb'.
    with open(clf_info_state, 'rb') as fclf:
        feature_name, nb_phm, cutoff_list, k, clf_list = pickle.load(fclf)

    # Heatmaps for one view of one breast. Any failure (e.g. the view is
    # missing from the exam) yields the [None] placeholder instead of
    # aborting the run; KeyboardInterrupt/SystemExit still propagate.
    def view_heatmaps(exam, side, view):
        try:
            return get_prob_heatmap(
                exam[side][view],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except Exception:
            return [None]

    # Score for one breast from its CC and MLO heatmaps. A failed prediction
    # is reported on stdout and scored as 0 so the remaining exams still run.
    def breast_pred(cc_phms, mlo_phms, side_name, subj, exam_idx):
        try:
            return dminfer.make_pred_case(cc_phms,
                                          mlo_phms,
                                          feature_name,
                                          cutoff_list,
                                          clf_list,
                                          k=k,
                                          nb_phm=nb_phm,
                                          use_mean=use_mean)
        except Exception:
            print("Exception in predicting %s breast for subj: %s exam: %s"
                  % (side_name, subj, exam_idx))
            sys.stdout.flush()
            return 0.

    # The context manager closes the file even if an exam raises.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        print("Start inference for exam list")
        sys.stdout.flush()
        for i, e in enumerate(exam_list):
            subj = e[0]
            exam_idx = e[1]
            exam = e[2]
            if validation_mode:
                # A NaN label is treated as negative.
                left_cancer = exam['L']['cancer']
                right_cancer = exam['R']['cancer']
                left_cancer = 0 if np.isnan(left_cancer) else left_cancer
                right_cancer = 0 if np.isnan(right_cancer) else right_cancer

            # All four heatmap sets are computed before any prediction.
            left_cc_phms = view_heatmaps(exam, 'L', 'CC')
            left_mlo_phms = view_heatmaps(exam, 'L', 'MLO')
            right_cc_phms = view_heatmaps(exam, 'R', 'CC')
            right_mlo_phms = view_heatmaps(exam, 'R', 'MLO')
            left_pred = breast_pred(left_cc_phms, left_mlo_phms,
                                    'left', subj, exam_idx)
            right_pred = breast_pred(right_cc_phms, right_mlo_phms,
                                     'right', subj, exam_idx)

            if validation_mode:
                fout.write("%s\t%s\tL\t%f\t%f\n" %
                           (str(subj), str(exam_idx), left_pred, left_cancer))
                fout.write("%s\t%s\tR\t%f\t%f\n" %
                           (str(subj), str(exam_idx), right_pred,
                            right_cancer))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj), left_pred))
                fout.write("%s\tR\t%f\n" % (str(subj), right_pred))
            # Flush per exam so partial results survive an interrupted run.
            fout.flush()

            print("processed %d/%d exams" % (i + 1, len(exam_list)))
            sys.stdout.flush()
            with open(progress, 'w') as fpro:
                fpro.write("%f\n" % ((i + 1.) / len(exam_list)))
        print("Done.")