def run(img_folder, dl_state, fprop_mode=False,
        img_size=(1152, 896), img_height=None, img_scale=None, 
        rescale_factor=None,
        equalize_hist=False, featurewise_center=False, featurewise_mean=71.8,
        net='vgg19', batch_size=128, patch_size=256, stride=8,
        avg_pool_size=(7, 7), hm_strides=(1, 1),
        pat_csv='./full_img/pat.csv', pat_list=None,
        out='./output/prob_heatmap.pkl'):
    '''Sweep mammograms with trained DL model to create prob heatmaps
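
    A minimal example call (hypothetical paths; in practice these arguments
    would come from a CLI wrapper or config outside this function):
        run('./full_img', './modelState/patch_clf.h5', fprop_mode=True,
            pat_csv='./full_img/pat.csv', out='./output/prob_heatmap.pkl')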
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Create image generator.
    imgen = DMImageDataGenerator(featurewise_center=featurewise_center)
    imgen.mean = featurewise_mean

    # Get image and label lists.
    df = pd.read_csv(pat_csv, header=0)
    df = df.set_index(['patient_id', 'side'])
    df.sort_index(inplace=True)
    if pat_list is not None:
        pat_ids = pd.read_csv(pat_list, header=0).values.ravel()
        pat_ids = pat_ids.tolist()
        print ("Read %d patient IDs" % (len(pat_ids)))
        df = df.loc[pat_ids]

    # Load DL model, preprocess.
    print ("Load patch classifier:", dl_state)
    sys.stdout.flush()
    dl_model, preprocess_input, _ = get_dl_model(net, resume_from=dl_state)
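    # In fprop mode, heatmap layers are stacked on top of the patch classifier
    # so each whole image is scored in a single forward pass; otherwise images
    # are swept patch-by-patch via get_prob_heatmap below. Note that
    # read_img_for_pred also expects a `data_format` value ('channels_first'
    # or 'channels_last'), assumed to be defined at module level, e.g. from
    # keras.backend.image_data_format().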
    if fprop_mode:
        dl_model = add_top_layers(dl_model, img_size, patch_net=net, 
                                  avg_pool_size=avg_pool_size, 
                                  return_heatmap=True, hm_strides=hm_strides)
    if gpu_count > 1:
        print ("Make the model parallel on %d GPUs" % (gpu_count))
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    if featurewise_center:
        preprocess_input = None

    # Sweep the whole images and classify patches.
    def const_filename(pat, side, view):
        basename = '_'.join([pat, side, view]) + '.png'
        return os.path.join(img_folder, basename)

    print ("Generate prob heatmaps")
    sys.stdout.flush()
    heatmaps = []
    cases_seen = 0
    nb_cases = len(df.index.unique())
    for i, (pat,side) in enumerate(df.index.unique()):
        ## DEBUG ##
        #if i >= 10:
        #    break
        ## DEBUG ##
        cancer = df.loc[pat].loc[side]['cancer']
        cc_fn = const_filename(pat, side, 'CC')
        if os.path.isfile(cc_fn):
            if fprop_mode:
                cc_x = read_img_for_pred(
                    cc_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                cc_x = cc_x.reshape((1,) + cc_x.shape)
                cc_hm = dl_model.predict_on_batch(cc_x)[0]
                # import pdb; pdb.set_trace()
            else:
                cc_hm = get_prob_heatmap(
                    cc_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            cc_hm = None
        mlo_fn = const_filename(pat, side, 'MLO')
        if os.path.isfile(mlo_fn):
            if fprop_mode:
                mlo_x = read_img_for_pred(
                    mlo_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                mlo_x = mlo_x.reshape((1,) + mlo_x.shape)
                mlo_hm = dl_model.predict_on_batch(mlo_x)[0]
            else:
                mlo_hm = get_prob_heatmap(
                    mlo_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            mlo_hm = None
        heatmaps.append({'patient_id':pat, 'side':side, 'cancer':cancer, 
                         'cc':cc_hm, 'mlo':mlo_hm})
        print ("scored %d/%d cases" % (i + 1, nb_cases))
        sys.stdout.flush()
    print ("Done.")

    # Save the result.
    print ("Saving result to external files.",)
    sys.stdout.flush()
    pickle.dump(heatmaps, open(out, 'w'))
    print ("Done.")
def run(img_folder,
        dl_state,
        img_extension='dcm',
        img_height=1024,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        neg_vs_pos_ratio=1.,
        net='vgg19',
        batch_size=128,
        patch_size=256,
        stride=8,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        out='./modelState/prob_heatmap.pkl',
        predicted_subj_file=None,
        add_subjs=500):
    '''Sweep mammograms with trained DL model to create prob heatmaps
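
    A minimal example call (hypothetical paths; dl_state should point to a
    trained Keras patch-classifier checkpoint):
        run('./trainingData', './modelState/patch_clf.h5',
            img_tsv='./metadata/images_crosswalk.tsv',
            exam_tsv='./metadata/exams_metadata.tsv',
            out='./modelState/prob_heatmap.pkl')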
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Load and split image and label lists.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_labs = np.array(subj_labs)
    print "Found %d subjests" % (len(subj_list))
    print "cancer patients=%d, normal patients=%d" \
            % ((subj_labs==1).sum(), (subj_labs==0).sum())
    if predicted_subj_file is not None:
        predicted_subjs = np.load(predicted_subj_file)
        subj_list = np.setdiff1d(subj_list, predicted_subjs)
        subj_list = subj_list[:add_subjs]
        print "Will predict additional %d subjects" % (len(subj_list))
    elif neg_vs_pos_ratio is not None:
        subj_list, subj_labs = DMMetaManager.subset_subj_list(
            subj_list, subj_labs, neg_vs_pos_ratio, random_seed)
        subj_labs = np.array(subj_labs)
        print "After subsetting, there are %d subjects" % (len(subj_list))
        print "cancer patients=%d, normal patients=%d" \
                % ((subj_labs==1).sum(), (subj_labs==0).sum())

    # Get exam lists.
    # >>>> Debug <<<< #
    # subj_list = subj_list[:2]
    # >>>> Debug <<<< #
    print "Get flattened exam list"
    exam_list = meta_man.get_flatten_exam_list(subj_list, cc_mlo_only=True)
    exam_labs = meta_man.exam_labs(exam_list)
    exam_labs = np.array(exam_labs)
    print "positive exams=%d, negative exams=%d" \
            % ((exam_labs==1).sum(), (exam_labs==0).sum())
    sys.stdout.flush()

    # Load DL model.
    print "Load patch classifier:", dl_state
    sys.stdout.flush()
    dl_model = load_model(dl_state,
                          custom_objects={
                              'sensitivity': dmm.sensitivity,
                              'specificity': dmm.specificity
                          })

    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False

    # Load preprocess function.
    if featurewise_center:
        preprocess_input = None
    else:
        print "Load preprocess function for net:", net
        if net == 'resnet50':
            from keras.applications.resnet50 import preprocess_input
        elif net == 'vgg16':
            from keras.applications.vgg16 import preprocess_input
        elif net == 'vgg19':
            from keras.applications.vgg19 import preprocess_input
        elif net == 'xception':
            from keras.applications.xception import preprocess_input
        elif net == 'inception':
            from keras.applications.inception_v3 import preprocess_input
        else:
            raise Exception("Pretrained model is not available: " + net)

    # Sweep the whole images and classify patches.
    print "Generate prob heatmaps for exam list"
    sys.stdout.flush()
    heatmap_dat_list = []
    for i, e in enumerate(exam_list):
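        # Each exam entry e is (subject_id, exam_index, per-breast dict); keep
        # only the cancer labels here and attach one heatmap per breast/view.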
        dat = (e[0], e[1], {
            'L': {
                'cancer': e[2]['L']['cancer']
            },
            'R': {
                'cancer': e[2]['R']['cancer']
            }
        })
        dat[2]['L']['CC'] = get_prob_heatmap(
            e[2]['L']['CC'],
            img_height,
            img_scale,
            patch_size,
            stride,
            dl_model,
            batch_size,
            featurewise_center=featurewise_center,
            featurewise_mean=featurewise_mean,
            preprocess=preprocess_input,
            parallelized=parallelized,
            equalize_hist=equalize_hist)
        dat[2]['L']['MLO'] = get_prob_heatmap(
            e[2]['L']['MLO'],
            img_height,
            img_scale,
            patch_size,
            stride,
            dl_model,
            batch_size,
            featurewise_center=featurewise_center,
            featurewise_mean=featurewise_mean,
            preprocess=preprocess_input,
            parallelized=parallelized,
            equalize_hist=equalize_hist)
        dat[2]['R']['CC'] = get_prob_heatmap(
            e[2]['R']['CC'],
            img_height,
            img_scale,
            patch_size,
            stride,
            dl_model,
            batch_size,
            featurewise_center=featurewise_center,
            featurewise_mean=featurewise_mean,
            preprocess=preprocess_input,
            parallelized=parallelized,
            equalize_hist=equalize_hist)
        dat[2]['R']['MLO'] = get_prob_heatmap(
            e[2]['R']['MLO'],
            img_height,
            img_scale,
            patch_size,
            stride,
            dl_model,
            batch_size,
            featurewise_center=featurewise_center,
            featurewise_mean=featurewise_mean,
            preprocess=preprocess_input,
            parallelized=parallelized,
            equalize_hist=equalize_hist)
        heatmap_dat_list.append(dat)
        print "processed %d/%d exams" % (i + 1, len(exam_list))
        sys.stdout.flush()
        ### DEBUG ###
        # if i >= 1:
        #    break
        ### DEBUG ###
    print "Done."

    # Save the result.
    print "Saving result to external files.",
    sys.stdout.flush()
    pickle.dump(heatmap_dat_list, open(out, 'w'))
    print "Done."
def run(img_folder,
        dl_state,
        clf_info_state,
        meta_clf_state,
        img_extension='dcm',
        img_height=4096,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        net='resnet50',
        batch_size=64,
        patch_size=256,
        stride=64,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv',
        progress='./progress.txt'):
    '''Run SC2 inference based on prob heatmap
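
    A minimal example call (hypothetical file names for the saved patch
    classifier and pickled classifier states):
        run('./inferenceData', './modelState/patch_clf.h5',
            './modelState/clf_info.pkl', './modelState/meta_clf.pkl',
            validation_mode=False, out_pred='./output/predictions.tsv')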
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = np.random.RandomState(random_seed)  # RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    last2_exgen = meta_man.last_2_exam_generator()
    last2_exam_list = list(last2_exgen)
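    # Each element pairs a subject's current exam with the prior one (if any),
    # so current, prior, and change-over-time scores can be computed below.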

    # Load DL model and classifiers.
    print "Load patch classifier:", dl_state
    sys.stdout.flush()
    dl_model = load_model(dl_state)
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    feature_name, nb_phm, cutoff_list, k, clf_list = \
            pickle.load(open(clf_info_state, 'rb'))
    meta_model = pickle.load(open(meta_clf_state, 'rb'))

    # Load preprocess function.
    if featurewise_center:
        preprocess_input = None
    else:
        print "Load preprocess function for net:", net
        if net == 'resnet50':
            from keras.applications.resnet50 import preprocess_input
        elif net == 'vgg16':
            from keras.applications.vgg16 import preprocess_input
        elif net == 'vgg19':
            from keras.applications.vgg19 import preprocess_input
        elif net == 'xception':
            from keras.applications.xception import preprocess_input
        elif net == 'inception':
            from keras.applications.inception_v3 import preprocess_input
        else:
            raise Exception("Pretrained model is not available: " + net)

    # Print header.
    fout = open(out_pred, 'w')
    if validation_mode:
        fout.write(dminfer.INFER_HEADER_VAL)
    else:
        fout.write(dminfer.INFER_HEADER)

    # Loop through all last 2 exam pairs.
    for i, (subj_id, curr_idx, curr_dat, prior_idx, prior_dat) in \
            enumerate(last2_exam_list):
        # DEBUG
        #if i < 23:
        #    continue
        # DEBUG
        # Get meta info for both breasts.
        left_record, right_record = meta_man.get_info_exam_pair(
            curr_dat, prior_dat)
        nb_days = left_record['daysSincePreviousExam']
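        # Days between the current and prior exams; used below to annualize
        # the change in predicted scores (diff / nb_days * 365).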

        # Get image data and make predictions.
        current_exam = meta_man.get_info_per_exam(curr_dat, cc_mlo_only=True)
        if prior_idx is not None:
            prior_exam = meta_man.get_info_per_exam(prior_dat,
                                                    cc_mlo_only=True)

        if validation_mode:
            left_cancer = current_exam['L']['cancer']
            right_cancer = current_exam['R']['cancer']
            left_cancer = 0 if np.isnan(left_cancer) else left_cancer
            right_cancer = 0 if np.isnan(right_cancer) else right_cancer

        # Get prob heatmaps.
        try:
            left_cc_phms = get_prob_heatmap(
                current_exam['L']['CC'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            left_cc_phms = [None]
        try:
            left_mlo_phms = get_prob_heatmap(
                current_exam['L']['MLO'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            left_mlo_phms = [None]
        try:
            right_cc_phms = get_prob_heatmap(
                current_exam['R']['CC'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            right_cc_phms = [None]
        try:
            right_mlo_phms = get_prob_heatmap(
                current_exam['R']['MLO'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            right_mlo_phms = [None]
        #import pdb; pdb.set_trace()
        try:
            curr_left_pred = dminfer.make_pred_case(left_cc_phms,
                                                    left_mlo_phms,
                                                    feature_name,
                                                    cutoff_list,
                                                    clf_list,
                                                    k=k,
                                                    nb_phm=nb_phm,
                                                    use_mean=use_mean)
        except:
            curr_left_pred = np.nan
        try:
            curr_right_pred = dminfer.make_pred_case(right_cc_phms,
                                                     right_mlo_phms,
                                                     feature_name,
                                                     cutoff_list,
                                                     clf_list,
                                                     k=k,
                                                     nb_phm=nb_phm,
                                                     use_mean=use_mean)
        except:
            curr_right_pred = np.nan

        if prior_idx is not None:
            try:
                left_cc_phms = get_prob_heatmap(
                    prior_exam['L']['CC'],
                    img_height,
                    img_scale,
                    patch_size,
                    stride,
                    dl_model,
                    batch_size,
                    featurewise_center=featurewise_center,
                    featurewise_mean=featurewise_mean,
                    preprocess=preprocess_input,
                    parallelized=parallelized,
                    equalize_hist=equalize_hist)
            except:
                left_cc_phms = [None]
            try:
                left_mlo_phms = get_prob_heatmap(
                    prior_exam['L']['MLO'],
                    img_height,
                    img_scale,
                    patch_size,
                    stride,
                    dl_model,
                    batch_size,
                    featurewise_center=featurewise_center,
                    featurewise_mean=featurewise_mean,
                    preprocess=preprocess_input,
                    parallelized=parallelized,
                    equalize_hist=equalize_hist)
            except:
                left_mlo_phms = [None]
            try:
                right_cc_phms = get_prob_heatmap(
                    prior_exam['R']['CC'],
                    img_height,
                    img_scale,
                    patch_size,
                    stride,
                    dl_model,
                    batch_size,
                    featurewise_center=featurewise_center,
                    featurewise_mean=featurewise_mean,
                    preprocess=preprocess_input,
                    parallelized=parallelized,
                    equalize_hist=equalize_hist)
            except:
                right_cc_phms = [None]
            try:
                right_mlo_phms = get_prob_heatmap(
                    prior_exam['R']['MLO'],
                    img_height,
                    img_scale,
                    patch_size,
                    stride,
                    dl_model,
                    batch_size,
                    featurewise_center=featurewise_center,
                    featurewise_mean=featurewise_mean,
                    preprocess=preprocess_input,
                    parallelized=parallelized,
                    equalize_hist=equalize_hist)
            except:
                right_mlo_phms = [None]
            try:
                prior_left_pred = dminfer.make_pred_case(left_cc_phms,
                                                         left_mlo_phms,
                                                         feature_name,
                                                         cutoff_list,
                                                         clf_list,
                                                         k=k,
                                                         nb_phm=nb_phm,
                                                         use_mean=use_mean)
            except:
                prior_left_pred = np.nan
            try:
                prior_right_pred = dminfer.make_pred_case(right_cc_phms,
                                                          right_mlo_phms,
                                                          feature_name,
                                                          cutoff_list,
                                                          clf_list,
                                                          k=k,
                                                          nb_phm=nb_phm,
                                                          use_mean=use_mean)
            except:
                prior_right_pred = np.nan
            try:
                diff_left_pred = (curr_left_pred -
                                  prior_left_pred) / nb_days * 365
            except:
                diff_left_pred = np.nan
            try:
                diff_right_pred = (curr_right_pred -
                                   prior_right_pred) / nb_days * 365
            except:
                diff_right_pred = np.nan
        else:
            prior_left_pred = np.nan
            prior_right_pred = np.nan
            diff_left_pred = np.nan
            diff_right_pred = np.nan

        try:
            # Merge image scores into meta info.
            left_record = left_record\
                    .assign(curr_score=curr_left_pred)\
                    .assign(prior_score=prior_left_pred)\
                    .assign(diff_score=diff_left_pred)
            right_record = right_record\
                    .assign(curr_score=curr_right_pred)\
                    .assign(prior_score=prior_right_pred)\
                    .assign(diff_score=diff_right_pred)
            dsubj = pd.concat([left_record, right_record], ignore_index=True)
            # Predict using meta classifier.
            pred = meta_model.predict_proba(dsubj)[:, 1]
        except:
            pred = [0., 0.]

        # Output.
        if validation_mode:
            fout.write("%s\t%s\tL\t%f\t%f\n" % \
                       (str(subj_id), str(curr_idx), pred[0], left_cancer))
            fout.write("%s\t%s\tR\t%f\t%f\n" % \
                       (str(subj_id), str(curr_idx), pred[1], right_cancer))
            fout.flush()
        else:
            fout.write("%s\tL\t%f\n" % (str(subj_id), pred[0]))
            fout.write("%s\tR\t%f\n" % (str(subj_id), pred[1]))
            fout.flush()

        print "processed %d/%d exams" % (i + 1, len(last2_exam_list))
        sys.stdout.flush()
        with open(progress, 'w') as fpro:
            fpro.write("%f\n" % ((i + 1.) / len(last2_exam_list)))

    print "Done."
    fout.close()
def run(img_folder,
        dl_state,
        clf_info_state,
        img_extension='dcm',
        img_height=4096,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        net='resnet50',
        batch_size=64,
        patch_size=256,
        stride=64,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv',
        progress='./progress.txt'):
    '''Run SC1 inference using prob heatmaps
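
    A minimal example call (hypothetical file names):
        run('./inferenceData', './modelState/patch_clf.h5',
            './modelState/clf_info.pkl', validation_mode=True,
            out_pred='./output/predictions.tsv')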
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = np.random.RandomState(random_seed)  # RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list(cc_mlo_only=True)
        exam_labs = meta_man.exam_labs(exam_list)
        exam_labs = np.array(exam_labs)
        print "positive exams=%d, negative exams=%d" \
                % ((exam_labs==1).sum(), (exam_labs==0).sum())
        sys.stdout.flush()
    else:
        exam_list = meta_man.get_last_exam_list(cc_mlo_only=True)
        exam_labs = None
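    # In validation mode all exams are scored against known labels; otherwise
    # only each subject's most recent exam is scored.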

    # Load DL model and classifiers.
    print "Load patch classifier:", dl_state
    sys.stdout.flush()
    dl_model = load_model(dl_state)
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    feature_name, nb_phm, cutoff_list, k, clf_list = \
            pickle.load(open(clf_info_state, 'rb'))

    # Load preprocess function.
    if featurewise_center:
        preprocess_input = None
    else:
        print "Load preprocess function for net:", net
        if net == 'resnet50':
            from keras.applications.resnet50 import preprocess_input
        elif net == 'vgg16':
            from keras.applications.vgg16 import preprocess_input
        elif net == 'vgg19':
            from keras.applications.vgg19 import preprocess_input
        elif net == 'xception':
            from keras.applications.xception import preprocess_input
        elif net == 'inception':
            from keras.applications.inception_v3 import preprocess_input
        else:
            raise Exception("Pretrained model is not available: " + net)

    # Print header.
    fout = open(out_pred, 'w')
    if validation_mode:
        fout.write(dminfer.INFER_HEADER_VAL)
    else:
        fout.write(dminfer.INFER_HEADER)

    print "Start inference for exam list"
    sys.stdout.flush()
    for i, e in enumerate(exam_list):
        ### DEBUG ###
        # if i >= 3:
        #    break
        ### DEBUG ###
        subj = e[0]
        exam_idx = e[1]
        if validation_mode:
            left_cancer = e[2]['L']['cancer']
            right_cancer = e[2]['R']['cancer']
            left_cancer = 0 if np.isnan(left_cancer) else left_cancer
            right_cancer = 0 if np.isnan(right_cancer) else right_cancer
        try:
            left_cc_phms = get_prob_heatmap(
                e[2]['L']['CC'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            left_cc_phms = [None]
        try:
            left_mlo_phms = get_prob_heatmap(
                e[2]['L']['MLO'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            left_mlo_phms = [None]
        try:
            right_cc_phms = get_prob_heatmap(
                e[2]['R']['CC'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            right_cc_phms = [None]
        try:
            right_mlo_phms = get_prob_heatmap(
                e[2]['R']['MLO'],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except:
            right_mlo_phms = [None]
        try:
            left_pred = dminfer.make_pred_case(left_cc_phms,
                                               left_mlo_phms,
                                               feature_name,
                                               cutoff_list,
                                               clf_list,
                                               k=k,
                                               nb_phm=nb_phm,
                                               use_mean=use_mean)
        except:
            print "Exception in predicting left breast" + \
                  " for subj:", subj, "exam:", exam_idx
            sys.stdout.flush()
            left_pred = 0.
        try:
            right_pred = dminfer.make_pred_case(right_cc_phms,
                                                right_mlo_phms,
                                                feature_name,
                                                cutoff_list,
                                                clf_list,
                                                k=k,
                                                nb_phm=nb_phm,
                                                use_mean=use_mean)
        except:
            print "Exception in predicting right breast" + \
                  " for subj:", subj, "exam:", exam_idx
            sys.stdout.flush()
            right_pred = 0.
        if validation_mode:
            fout.write("%s\t%s\tL\t%f\t%f\n" % \
                       (str(subj), str(exam_idx), left_pred, left_cancer))
            fout.write("%s\t%s\tR\t%f\t%f\n" % \
                       (str(subj), str(exam_idx), right_pred, right_cancer))
            fout.flush()
        else:
            fout.write("%s\tL\t%f\n" % (str(subj), left_pred))
            fout.write("%s\tR\t%f\n" % (str(subj), right_pred))
            fout.flush()
        print "processed %d/%d exams" % (i + 1, len(exam_list))
        sys.stdout.flush()
        with open(progress, 'w') as fpro:
            fpro.write("%f\n" % ((i + 1.) / len(exam_list)))
    print "Done."
    fout.close()