def run(img_folder,
        img_size=(288, 224),
        do_featurewise_norm=True,
        featurewise_mean=485.9,
        featurewise_std=765.2,
        img_tsv='./metadata/images_crosswalk.tsv',
        exam_tsv='./metadata/exams_metadata.tsv',
        dl_state=None,
        enet_state=None,
        xgb_state=None,
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC2 inference.

    For every subject's last exam (plus the prior exam, when there is one),
    the CC/MLO images of each breast are scored by the image model. The
    current, prior and annualised-change scores are merged into the exam
    metadata and scored by an XGBoost model. One line per breast (L, R) is
    written to `out_pred`.

    Args:
        img_folder (str): folder holding the (dcm) images.
        img_size (sequence of 2 ints): passed to `flow_from_exam_list` as
                `target_size=(img_size[0], img_size[1])`.
        do_featurewise_norm (bool): if True, normalise with the fixed
                `featurewise_mean`/`featurewise_std`; otherwise use
                per-sample centering and std normalisation.
        featurewise_mean, featurewise_std (float): they are estimated from
                1152 x 896 images. Using different sized images give very
                close results. For png, mean=7772, std=12187.
        img_tsv, exam_tsv (str): metadata tables for DMMetaManager.
        dl_state (str): model file passed to `load_model`; used only when
                `enet_state` is None.
        enet_state (sequence): positional arguments for
                `MultiViewDLElasticNet`; takes precedence over `dl_state`.
        xgb_state (str): path to the pickled XGBoost model. The model must
                expose `predict` and `best_ntree_limit`.
        validation_mode (bool): if True, labels are requested from the
                generator (class_mode='binary') and written next to each
                prediction.
        use_mean (bool): forwarded to `dminfer.pred_2view_img_list`.
        out_pred (str): output TSV path.

    Raises:
        Exception: if neither `enet_state` nor `dl_state` is given, or if
                `xgb_state` is missing.
    '''

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    last2_exgen = meta_man.last_2_exam_generator()
    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)
    class_mode = 'binary' if validation_mode else None

    # Image prediction model.
    if enet_state is not None:
        model = MultiViewDLElasticNet(*enet_state)
    elif dl_state is not None:
        model = load_model(dl_state)
    else:
        raise Exception('At least one image model state must be specified.')

    # XGB model. Pickles must be read in binary mode. The handle is closed
    # right away.
    # NOTE(review): unpickling runs arbitrary code; only load trusted files.
    if xgb_state is None:
        raise Exception('An XGB model state must be specified.')
    with open(xgb_state, 'rb') as fxgb:
        xgb_clf = pickle.load(fxgb)

    # The output file is closed even if an exam raises midway.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        # Loop through all last 2 exam pairs.
        for subj_id, curr_idx, curr_dat, prior_idx, prior_dat in last2_exgen:
            # Get meta info for both breasts.
            left_record, right_record = meta_man.get_info_exam_pair(
                curr_dat, prior_dat)
            nb_days = left_record['daysSincePreviousExam']

            # Get image data for the current exam and the prior exam (if any)
            # in a single batch.
            exam_list = [
                (subj_id, curr_idx, meta_man.get_info_per_exam(curr_dat))]
            if prior_idx is not None:
                exam_list.append(
                    (subj_id, prior_idx, meta_man.get_info_per_exam(prior_dat)))
            datgen_exam = img_gen.flow_from_exam_list(
                exam_list,
                target_size=(img_size[0], img_size[1]),
                class_mode=class_mode,
                prediction_mode=True,
                batch_size=len(exam_list),
                verbose=False)
            ebat = next(datgen_exam)
            if class_mode is not None:
                bat_x, bat_y = ebat[0], ebat[1]
            else:
                bat_x = ebat
            cc_batch = bat_x[2]
            mlo_batch = bat_x[3]

            # Rows are indexed below as curr-L=0, curr-R=1, prior-L=2,
            # prior-R=3. This ordering is assumed from the generator; confirm.
            curr_left_score = dminfer.pred_2view_img_list(
                cc_batch[0], mlo_batch[0], model, use_mean)
            curr_right_score = dminfer.pred_2view_img_list(
                cc_batch[1], mlo_batch[1], model, use_mean)
            if prior_idx is not None:
                prior_left_score = dminfer.pred_2view_img_list(
                    cc_batch[2], mlo_batch[2], model, use_mean)
                prior_right_score = dminfer.pred_2view_img_list(
                    cc_batch[3], mlo_batch[3], model, use_mean)
                # Score change per year between the two exams.
                diff_left_score = (curr_left_score -
                                   prior_left_score) / nb_days * 365
                diff_right_score = (curr_right_score -
                                    prior_right_score) / nb_days * 365
            else:
                prior_left_score = np.nan
                prior_right_score = np.nan
                diff_left_score = np.nan
                diff_right_score = np.nan

            # Merge image scores into meta info. Chained assign() keeps the
            # column order (curr, prior, diff) stable for the XGB model.
            left_record = left_record\
                    .assign(curr_score=curr_left_score)\
                    .assign(prior_score=prior_left_score)\
                    .assign(diff_score=diff_left_score)
            right_record = right_record\
                    .assign(curr_score=curr_right_score)\
                    .assign(prior_score=prior_right_score)\
                    .assign(diff_score=diff_right_score)
            dsubj = xgb.DMatrix(
                pd.concat([left_record, right_record], ignore_index=True))

            # Predict using XGB; row 0 is the left breast, row 1 the right.
            pred = xgb_clf.predict(dsubj, ntree_limit=xgb_clf.best_ntree_limit)

            # Output.
            if validation_mode:
                fout.write("%s\t%s\tL\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[0], bat_y[0]))
                fout.write("%s\t%s\tR\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[1], bat_y[1]))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj_id), pred[0]))
                fout.write("%s\tR\t%f\n" % (str(subj_id), pred[1]))
def run(img_folder,
        dl_state,
        clf_info_state,
        meta_clf_state,
        img_extension='dcm',
        img_height=4096,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        net='resnet50',
        batch_size=64,
        patch_size=256,
        stride=64,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv',
        progress='./progress.txt'):
    '''Run SC2 inference based on prob heatmap.

    For each subject's last exam (and prior exam, if any):
      1. A patch classifier computes probability heatmaps for each CC/MLO view.
      2. Each breast's heatmaps are turned into a score by the heatmap
         classifiers.
      3. The current, prior and annualised-change scores are merged into the
         exam metadata and scored by a meta classifier.

    One line per breast (L, R) is written to `out_pred`. After every exam,
    the fraction of exams processed is rewritten to `progress`.

    Failures on a single view or breast are tolerated and reported on
    stderr:
      - The affected score becomes NaN.
      - If meta prediction itself fails, both breasts get 0.0.
    This keeps one bad exam from aborting the whole run.

    Args:
        img_folder (str): folder holding the images.
        dl_state (str): patch classifier file passed to `load_model`.
        clf_info_state (str): pickle holding the tuple
                (feature_name, nb_phm, cutoff_list, k, clf_list).
        meta_clf_state (str): pickled meta classifier exposing
                `predict_proba`.
        img_extension (str): image file extension passed to DMMetaManager.
        img_height, img_scale, patch_size, stride, batch_size,
        equalize_hist: forwarded to `get_prob_heatmap`.
        featurewise_center (bool), featurewise_mean (float): forwarded to
                `get_prob_heatmap`. If `featurewise_center` is True, no
                network-specific preprocess function is loaded.
        net (str): selects the keras.applications `preprocess_input`; one of
                'resnet50', 'vgg16', 'vgg19', 'xception', 'inception'.
        exam_tsv, img_tsv (str): metadata tables for DMMetaManager.
        validation_mode (bool): also write each breast's cancer label
                (NaN is written as 0).
        use_mean (bool): forwarded to `dminfer.make_pred_case`.
        out_pred (str): output TSV path.
        progress (str): file rewritten after each exam with the processed
                fraction.

    Environment:
        NUM_GPU_DEVICES: if > 1, the patch classifier is wrapped with
                `make_parallel` over that many GPUs.

    Raises:
        Exception: if `net` is not a supported pretrained model.
    '''
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    # Materialised so the total is known for progress reporting.
    last2_exam_list = list(meta_man.last_2_exam_generator())
    nb_exams = len(last2_exam_list)

    # Load DL model and classifiers.
    print("Load patch classifier: %s" % (dl_state,))
    sys.stdout.flush()
    dl_model = load_model(dl_state)
    parallelized = gpu_count > 1
    if parallelized:
        print("Make the model parallel on %d GPUs" % gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
    # Pickles are read in binary mode and closed immediately.
    # NOTE(review): unpickling runs arbitrary code; only load trusted files.
    with open(clf_info_state, 'rb') as fclf:
        feature_name, nb_phm, cutoff_list, k, clf_list = pickle.load(fclf)
    with open(meta_clf_state, 'rb') as fmeta:
        meta_model = pickle.load(fmeta)

    # Load preprocess function.
    if featurewise_center:
        preprocess_input = None
    else:
        print("Load preprocess function for net: %s" % (net,))
        if net == 'resnet50':
            from keras.applications.resnet50 import preprocess_input
        elif net == 'vgg16':
            from keras.applications.vgg16 import preprocess_input
        elif net == 'vgg19':
            from keras.applications.vgg19 import preprocess_input
        elif net == 'xception':
            from keras.applications.xception import preprocess_input
        elif net == 'inception':
            from keras.applications.inception_v3 import preprocess_input
        else:
            raise Exception("Pretrained model is not available: " + net)

    def _warn(msg, err):
        '''Report a tolerated failure on stderr without aborting the run.'''
        sys.stderr.write("WARNING: %s: %r\n" % (msg, err))

    def _heatmaps(exam, side, view):
        '''Heatmaps for one view of one breast; [None] if anything fails.'''
        try:
            return get_prob_heatmap(
                exam[side][view],
                img_height,
                img_scale,
                patch_size,
                stride,
                dl_model,
                batch_size,
                featurewise_center=featurewise_center,
                featurewise_mean=featurewise_mean,
                preprocess=preprocess_input,
                parallelized=parallelized,
                equalize_hist=equalize_hist)
        except Exception as err:
            _warn("prob heatmap failed for %s-%s" % (side, view), err)
            return [None]

    def _breast_pred(cc_phms, mlo_phms):
        '''Breast-level score from its CC/MLO heatmaps; NaN on failure.'''
        try:
            return dminfer.make_pred_case(cc_phms,
                                          mlo_phms,
                                          feature_name,
                                          cutoff_list,
                                          clf_list,
                                          k=k,
                                          nb_phm=nb_phm,
                                          use_mean=use_mean)
        except Exception as err:
            _warn("breast prediction failed", err)
            return np.nan

    def _exam_preds(exam):
        '''(left, right) breast scores for one exam.'''
        # All four heatmaps are computed before the two breast predictions.
        left_cc = _heatmaps(exam, 'L', 'CC')
        left_mlo = _heatmaps(exam, 'L', 'MLO')
        right_cc = _heatmaps(exam, 'R', 'CC')
        right_mlo = _heatmaps(exam, 'R', 'MLO')
        return (_breast_pred(left_cc, left_mlo),
                _breast_pred(right_cc, right_mlo))

    def _annual_change(curr, prior, nb_days):
        '''Score change per year between prior and current exam; NaN on failure.'''
        try:
            return (curr - prior) / nb_days * 365
        except Exception:
            return np.nan

    # The output file is closed even if an exam raises midway.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        # Loop through all last 2 exam pairs.
        for i, (subj_id, curr_idx, curr_dat, prior_idx, prior_dat) in \
                enumerate(last2_exam_list):
            # Get meta info for both breasts.
            left_record, right_record = meta_man.get_info_exam_pair(
                curr_dat, prior_dat)
            nb_days = left_record['daysSincePreviousExam']

            # Get image info for the current and (optional) prior exam.
            current_exam = meta_man.get_info_per_exam(curr_dat,
                                                      cc_mlo_only=True)
            if prior_idx is not None:
                prior_exam = meta_man.get_info_per_exam(prior_dat,
                                                        cc_mlo_only=True)

            if validation_mode:
                left_cancer = current_exam['L']['cancer']
                right_cancer = current_exam['R']['cancer']
                left_cancer = 0 if np.isnan(left_cancer) else left_cancer
                right_cancer = 0 if np.isnan(right_cancer) else right_cancer

            # Breast-level scores from prob heatmaps.
            curr_left_pred, curr_right_pred = _exam_preds(current_exam)
            if prior_idx is not None:
                prior_left_pred, prior_right_pred = _exam_preds(prior_exam)
                diff_left_pred = _annual_change(curr_left_pred,
                                                prior_left_pred, nb_days)
                diff_right_pred = _annual_change(curr_right_pred,
                                                 prior_right_pred, nb_days)
            else:
                prior_left_pred = np.nan
                prior_right_pred = np.nan
                diff_left_pred = np.nan
                diff_right_pred = np.nan

            try:
                # Merge image scores into meta info. Chained assign() keeps
                # the column order (curr, prior, diff), which the meta
                # classifier may depend on.
                left_record = left_record\
                        .assign(curr_score=curr_left_pred)\
                        .assign(prior_score=prior_left_pred)\
                        .assign(diff_score=diff_left_pred)
                right_record = right_record\
                        .assign(curr_score=curr_right_pred)\
                        .assign(prior_score=prior_right_pred)\
                        .assign(diff_score=diff_right_pred)
                dsubj = pd.concat([left_record, right_record],
                                  ignore_index=True)
                # Predict using meta classifier; row 0 = L, row 1 = R.
                pred = meta_model.predict_proba(dsubj)[:, 1]
            except Exception as err:
                _warn("meta prediction failed for subject %s" % (subj_id,),
                      err)
                pred = [0., 0.]

            # Output.
            if validation_mode:
                fout.write("%s\t%s\tL\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[0], left_cancer))
                fout.write("%s\t%s\tR\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[1], right_cancer))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj_id), pred[0]))
                fout.write("%s\tR\t%f\n" % (str(subj_id), pred[1]))
            fout.flush()

            print("processed %d/%d exams" % (i + 1, nb_exams))
            sys.stdout.flush()
            with open(progress, 'w') as fpro:
                fpro.write("%f\n" % ((i + 1.) / nb_exams))

    print("Done.")