def run(img_folder,
        img_height=1024,
        img_scale=4095,
        roi_per_img=32,
        roi_size=(256, 256),
        low_int_threshold=.05,
        blob_min_area=3,
        blob_min_int=.5,
        blob_max_int=.85,
        blob_th_step=10,
        roi_state=None,
        roi_bs=32,
        do_featurewise_norm=True,
        featurewise_mean=884.7,
        featurewise_std=745.3,
        img_tsv='./metadata/images_crosswalk_prediction.tsv',
        exam_tsv=None,
        dl_state=None,
        dl_bs=32,
        nb_top_avg=1,
        validation_mode=False,
        val_size=None,
        img_voting=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC1 inference using the candidate ROI approach.

    Candidate ROIs are extracted from each breast's images, scored by the
    DL model loaded from ``dl_state`` and reduced to one cancer score per
    breast. One line per breast (L and R) of every exam is written to
    ``out_pred``.

    Args:
        img_folder: folder holding the (DICOM) images.
        img_height, img_scale: target height / intensity scale handed to
            the candidate ROI generator.
        roi_per_img, roi_size, low_int_threshold, blob_min_area,
        blob_min_int, blob_max_int, blob_th_step: candidate ROI extraction
            settings, forwarded to ``flow_from_candid_roi``.
        roi_state: optional saved ROI classifier. When given, the ROIs with
            the largest sample weights are the ones averaged.
        roi_bs: batch size used by the ROI classifier.
        do_featurewise_norm: normalize with the fixed ``featurewise_mean`` /
            ``featurewise_std`` instead of per-sample normalization.
        img_tsv, exam_tsv: metadata files for ``DMMetaManager``.
        dl_state: saved DL model; required.
        dl_bs: batch size for ``model.predict``.
        nb_top_avg: number of top-scoring ROIs averaged into a score.
        validation_mode: use ``get_flatten_exam_list`` instead of
            ``get_last_exam_list`` and also write the cancer labels.
        val_size: if given, only subjects from a stratified
            ``train_test_split`` test split of this size are scored.
        img_voting: average per-image top-ROI scores instead of pooling all
            ROIs of a breast (ignored when an ROI classifier is used).
        out_pred: path of the output TSV file.

    Raises:
        Exception: if ``dl_state`` is None.

    Notes:
        "mean=884.7, std=745.3" are estimated from 20 subjects on the
        training data.
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    if val_size is not None:  # Use a subset for validation.
        subj_list, subj_labs = meta_man.get_subj_labs()
        _, subj_test = train_test_split(subj_list,
                                        test_size=val_size,
                                        random_state=random_seed,
                                        stratify=subj_labs)
    else:
        subj_test = None

    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list(subj_list=subj_test,
                                                   flatten_img_list=True)
    else:
        exam_list = meta_man.get_last_exam_list(subj_list=subj_test,
                                                flatten_img_list=True)

    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)
    if validation_mode:
        class_mode = 'categorical'
    else:
        class_mode = None

    # Load ROI classifier.
    if roi_state is not None:
        roi_clf = load_model(roi_state,
                             custom_objects={
                                 'sensitivity': DMMetrics.sensitivity,
                                 'specificity': DMMetrics.specificity
                             })
        if gpu_count > 1:
            roi_clf = make_parallel(roi_clf, gpu_count)
    else:
        roi_clf = None

    # Load model.
    if dl_state is not None:
        model = load_model(dl_state)
    else:
        raise Exception('At least one model state must be specified.')
    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    def pred_img_list(img_list):
        '''Score the candidate ROIs of ``img_list`` and reduce to one score.'''
        roi_generator = img_gen.flow_from_candid_roi(
            img_list,
            target_height=img_height,
            target_scale=img_scale,
            class_mode=class_mode,
            validation_mode=True,
            img_per_batch=len(img_list),
            roi_per_img=roi_per_img,
            roi_size=roi_size,
            low_int_threshold=low_int_threshold,
            blob_min_area=blob_min_area,
            blob_min_int=blob_min_int,
            blob_max_int=blob_max_int,
            blob_th_step=blob_th_step,
            roi_clf=roi_clf,
            clf_bs=roi_bs,
            return_sample_weight=True,
            seed=random_seed)
        roi_dat, roi_w = roi_generator.next()
        pred = model.predict(roi_dat, batch_size=dl_bs)
        pred = pred[:, 1]  # cancer class predictions.
        if roi_clf is not None:
            # Average the predictions of the nb_top_avg highest-weighted ROIs.
            return pred[np.argsort(roi_w)[-nb_top_avg:]].mean()
        elif img_voting:
            pred = pred.reshape((-1, roi_per_img))
            img_preds = [np.sort(row)[-nb_top_avg:].mean() for row in pred]
            return np.mean(img_preds)
        else:
            return np.sort(pred)[-nb_top_avg:].mean()

    def pred_breast(exam, side):
        '''Score one breast; 0 when the exam has no images for that side.'''
        # Only the lookup is guarded: a KeyError raised while predicting is
        # a real error and must not be silently reported as a zero score.
        try:
            img_list = exam[side]['img']
        except KeyError:
            return .0
        return pred_img_list(img_list)

    def breast_label(exam, side):
        '''Integer cancer label of one breast; 0 if missing or unparsable.'''
        # KeyError covers exams without this side (mirrors pred_breast);
        # ValueError/TypeError cover missing or non-numeric labels.
        try:
            return int(exam[side]['cancer'])
        except (KeyError, TypeError, ValueError):
            return 0

    # The context manager closes the output even if a prediction fails.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        for subj, exidx, exam in exam_list:
            predL = pred_breast(exam, 'L')
            predR = pred_breast(exam, 'R')

            if validation_mode:
                cancerL = breast_label(exam, 'L')
                cancerR = breast_label(exam, 'R')
                fout.write("%s\t%s\tL\t%f\t%d\n" %
                           (str(subj), str(exidx), predL, cancerL))
                fout.write("%s\t%s\tR\t%f\t%d\n" %
                           (str(subj), str(exidx), predR, cancerR))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj), predL))
                fout.write("%s\tR\t%f\n" % (str(subj), predR))
# Example #2
# 0
from meta import DMMetaManager
# Build exam-level and image-level data generators over the preprocessed
# 288x224 PNG images.
meta_man = DMMetaManager(img_folder='preprocessedData/png_288x224/',
                         img_extension='png')
exam_list = meta_man.get_flatten_exam_list()
# get_flatten_img_list() is indexed below as img_list[0] / img_list[1];
# presumably (images, labels) -- TODO confirm against DMMetaManager.
img_list = meta_man.get_flatten_img_list()
from dm_image import DMImageDataGenerator
# Feature-wise normalization with fixed statistics for the PNG images
# (mean=7772, std=12187).
img_gen = DMImageDataGenerator(featurewise_center=True,
                               featurewise_std_normalization=True)
img_gen.mean = 7772.
img_gen.std = 12187.
# Exam-level generator: unshuffled batches of 8 exams at 288x224.
datgen_exam = img_gen.flow_from_exam_list(exam_list,
                                          target_size=(288, 224),
                                          batch_size=8,
                                          shuffle=False,
                                          seed=123)
# Image-level generator: unshuffled batches of 32 images at 288x224.
datgen_image = img_gen.flow_from_img_list(img_list[0],
                                          img_list[1],
                                          target_size=(288, 224),
                                          batch_size=32,
                                          shuffle=False,
                                          seed=123)
import numpy as np
def run(train_dir,
        val_dir,
        test_dir,
        patch_model_state=None,
        resume_from=None,
        img_size=[1152, 896],
        img_scale=None,
        rescale_factor=None,
        featurewise_center=True,
        featurewise_mean=52.16,
        equalize_hist=False,
        augmentation=True,
        class_list=['neg', 'pos'],
        patch_net='resnet50',
        block_type='resnet',
        top_depths=[512, 512],
        top_repetitions=[3, 3],
        bottleneck_enlarge_factor=4,
        add_heatmap=False,
        avg_pool_size=[7, 7],
        add_conv=True,
        add_shortcut=False,
        hm_strides=(1, 1),
        hm_pool_size=(5, 5),
        fc_init_units=64,
        fc_layers=2,
        top_layer_nb=None,
        batch_size=64,
        train_bs_multiplier=.5,
        nb_epoch=5,
        all_layer_epochs=20,
        load_val_ram=False,
        load_train_ram=False,
        weight_decay=.0001,
        hidden_dropout=.0,
        weight_decay2=.0001,
        hidden_dropout2=.0,
        optim='sgd',
        init_lr=.01,
        lr_patience=10,
        es_patience=25,
        auto_batch_balance=False,
        pos_cls_weight=1.0,
        neg_cls_weight=1.0,
        all_layer_multiplier=.1,
        best_model='./modelState/image_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for image classifications
    '''

    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    if resume_from is not None:
        image_model = load_model(resume_from, compile=False)
    else:
        patch_model = load_model(patch_model_state, compile=False)
        image_model, top_layer_nb = add_top_layers(
            patch_model,
            img_size,
            patch_net,
            block_type,
            top_depths,
            top_repetitions,
            bottleneck_org,
            nb_class=len(class_list),
            shortcut_with_bn=True,
            bottleneck_enlarge_factor=bottleneck_enlarge_factor,
            dropout=hidden_dropout,
            weight_decay=weight_decay,
            add_heatmap=add_heatmap,
            avg_pool_size=avg_pool_size,
            add_conv=add_conv,
            add_shortcut=add_shortcut,
            hm_strides=hm_strides,
            hm_pool_size=hm_pool_size,
            fc_init_units=fc_init_units,
            fc_layers=fc_layers)
    if gpu_count > 1:
        image_model, org_model = make_parallel(image_model, gpu_count)
    else:
        org_model = image_model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    dup_3_channels = True
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            batch_size=train_bs,
            shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."
        sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0],
            raw_set[1],
            batch_size=train_bs,
            auto_batch_balance=auto_batch_balance,
            shuffle=True,
            seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            auto_batch_balance=auto_batch_balance,
            batch_size=train_bs,
            shuffle=True,
            seed=random_seed)

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 2-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    validation_steps = int(val_samples / batch_size)
    #### DEBUG ####
    # train_batches = 1
    # val_samples = batch_size*5
    # validation_steps = 5
    #### DEBUG ####
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                test_samples=val_samples)
    # import pdb; pdb.set_trace()
    image_model, loss_hist, acc_hist = do_2stage_training(
        image_model,
        org_model,
        train_generator,
        validation_set,
        validation_steps,
        best_model,
        train_batches,
        top_layer_nb,
        nb_epoch=nb_epoch,
        all_layer_epochs=all_layer_epochs,
        optim=optim,
        init_lr=init_lr,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience,
        lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance,
        pos_cls_weight=pos_cls_weight,
        neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker,
        auc_checkpointer=auc_checkpointer,
        weight_decay=weight_decay,
        hidden_dropout=hidden_dropout,
        weight_decay2=weight_decay2,
        hidden_dropout2=hidden_dropout2,
    )

    # Training report.
    if len(loss_hist) > 0:
        min_loss_locs, = np.where(loss_hist == min(loss_hist))
        best_val_loss = loss_hist[min_loss_locs[0]]
        best_val_accuracy = acc_hist[min_loss_locs[0]]
        print "\n==== Training summary ===="
        print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
        print "Best val loss:", best_val_loss
        print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        image_model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=False)
    test_samples = test_generator.nb_sample
    #### DEBUG ####
    # test_samples = 5
    #### DEBUG ####
    print "Test samples =", test_samples
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    # test_steps = int(test_generator.nb_sample/batch_size)
    # test_res = image_model.evaluate_generator(
    #     test_generator, test_steps, nb_worker=nb_worker,
    #     pickle_safe=True if nb_worker > 1 else False)
    test_auc = DMAucModelCheckpoint.calc_test_auc(test_generator,
                                                  image_model,
                                                  test_samples=test_samples)
    print "AUROC on test set:", test_auc
def run(img_folder, dl_state, fprop_mode=False,
        img_size=(1152, 896), img_height=None, img_scale=None, 
        rescale_factor=None,
        equalize_hist=False, featurewise_center=False, featurewise_mean=71.8,
        net='vgg19', batch_size=128, patch_size=256, stride=8,
        avg_pool_size=(7, 7), hm_strides=(1, 1),
        pat_csv='./full_img/pat.csv', pat_list=None,
        out='./output/prob_heatmap.pkl'):
    '''Sweep mammograms with trained DL model to create prob heatmaps
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an rng used across board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Create image generator.
    imgen = DMImageDataGenerator(featurewise_center=featurewise_center)
    imgen.mean = featurewise_mean

    # Get image and label lists.
    df = pd.read_csv(pat_csv, header=0)
    df = df.set_index(['patient_id', 'side'])
    df.sort_index(inplace=True)
    if pat_list is not None:
        pat_ids = pd.read_csv(pat_list, header=0).values.ravel()
        pat_ids = pat_ids.tolist()
        print "Read %d patient IDs" % (len(pat_ids))
        df = df.loc[pat_ids]

    # Load DL model, preprocess.
    print "Load patch classifier:", dl_state; sys.stdout.flush()
    dl_model, preprocess_input, _ = get_dl_model(net, resume_from=dl_state)
    if fprop_mode:
        dl_model = add_top_layers(dl_model, img_size, patch_net=net, 
                                  avg_pool_size=avg_pool_size, 
                                  return_heatmap=True, hm_strides=hm_strides)
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    if featurewise_center:
        preprocess_input = None

    # Sweep the whole images and classify patches.
    def const_filename(pat, side, view):
        basename = '_'.join([pat, side, view]) + '.png'
        return os.path.join(img_folder, basename)

    print "Generate prob heatmaps"; sys.stdout.flush()
    heatmaps = []
    cases_seen = 0
    nb_cases = len(df.index.unique())
    for i, (pat,side) in enumerate(df.index.unique()):
        ## DEBUG ##
        #if i >= 10:
        #    break
        ## DEBUG ##
        cancer = df.loc[pat].loc[side]['cancer']
        cc_fn = const_filename(pat, side, 'CC')
        if os.path.isfile(cc_fn):
            if fprop_mode:
                cc_x = read_img_for_pred(
                    cc_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                cc_x = cc_x.reshape((1,) + cc_x.shape)
                cc_hm = dl_model.predict_on_batch(cc_x)[0]
                # import pdb; pdb.set_trace()
            else:
                cc_hm = get_prob_heatmap(
                    cc_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            cc_hm = None
        mlo_fn = const_filename(pat, side, 'MLO')
        if os.path.isfile(mlo_fn):
            if fprop_mode:
                mlo_x = read_img_for_pred(
                    mlo_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                mlo_x = mlo_x.reshape((1,) + mlo_x.shape)
                mlo_hm = dl_model.predict_on_batch(mlo_x)[0]
            else:
                mlo_hm = get_prob_heatmap(
                    mlo_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            mlo_hm = None
        heatmaps.append({'patient_id':pat, 'side':side, 'cancer':cancer, 
                         'cc':cc_hm, 'mlo':mlo_hm})
        print "scored %d/%d cases" % (i + 1, nb_cases)
        sys.stdout.flush()
    print "Done."

    # Save the result.
    print "Saving result to external files.",
    sys.stdout.flush()
    pickle.dump(heatmaps, open(out, 'w'))
    print "Done."
# Example #5
# 0
def run(img_folder,
        img_size=[288, 224],
        do_featurewise_norm=True,
        featurewise_mean=485.9,
        featurewise_std=765.2,
        batch_size=16,
        img_tsv='./metadata/images_crosswalk_prediction.tsv',
        exam_tsv=None,
        dl_state=None,
        enet_state=None,
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC1 inference.

    Exams are read in batches from ``flow_from_exam_list`` and each breast
    (left and right) is scored from its CC and MLO images with
    ``dminfer.pred_2view_img_list``. One line per breast is written to
    ``out_pred``. In ``validation_mode`` the exam index and label are
    written as well.

    Args:
        img_size: target (height, width) of the images.
        featurewise_mean, featurewise_std ([float]): they are estimated from
                1152 x 896 images. Using different sized images give very close
                results. For png, mean=7772, std=12187.
        dl_state: saved Keras model, used when ``enet_state`` is None.
        enet_state: arguments for ``MultiViewDLElasticNet``; takes priority
            over ``dl_state``.
        validation_mode: use ``get_flatten_exam_list`` instead of
            ``get_last_exam_list`` and also write the labels.
        use_mean: forwarded to ``dminfer.pred_2view_img_list``.

    Raises:
        Exception: if neither ``dl_state`` nor ``enet_state`` is given.
        RuntimeError: if the exam generator yields an empty batch.
    '''

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    if validation_mode:
        exam_list = meta_man.get_flatten_exam_list()
    else:
        exam_list = meta_man.get_last_exam_list()
    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)
    if validation_mode:
        class_mode = 'binary'
    else:
        class_mode = None
    datgen_exam = img_gen.flow_from_exam_list(exam_list,
                                              target_size=(img_size[0],
                                                           img_size[1]),
                                              class_mode=class_mode,
                                              prediction_mode=True,
                                              batch_size=batch_size)

    if enet_state is not None:
        model = MultiViewDLElasticNet(*enet_state)
    elif dl_state is not None:
        model = load_model(dl_state)
    else:
        raise Exception('At least one model state must be specified.')

    # The context manager closes the output even if prediction fails.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        exams_seen = 0
        while exams_seen < len(exam_list):
            ebat = next(datgen_exam)
            if class_mode is not None:
                bat_x = ebat[0]
                bat_y = ebat[1]
            else:
                bat_x = ebat
            subj_batch = bat_x[0]
            exam_batch = bat_x[1]
            cc_batch = bat_x[2]
            mlo_batch = bat_x[3]
            if len(subj_batch) == 0:
                # An empty batch never advances exams_seen; fail loudly
                # instead of looping forever.
                raise RuntimeError('Exam generator returned an empty batch.')
            for i, subj in enumerate(subj_batch):
                exam = exam_batch[i]
                li = i * 2  # left breast index.
                ri = i * 2 + 1  # right breast index.
                left_pred = dminfer.pred_2view_img_list(cc_batch[li],
                                                        mlo_batch[li], model,
                                                        use_mean)
                right_pred = dminfer.pred_2view_img_list(cc_batch[ri],
                                                         mlo_batch[ri], model,
                                                         use_mean)
                if validation_mode:
                    fout.write("%s\t%s\tL\t%f\t%f\n" %
                               (str(subj), str(exam), left_pred, bat_y[li]))
                    fout.write("%s\t%s\tR\t%f\t%f\n" %
                               (str(subj), str(exam), right_pred, bat_y[ri]))
                else:
                    fout.write("%s\tL\t%f\n" % (str(subj), left_pred))
                    fout.write("%s\tR\t%f\n" % (str(subj), right_pred))

            exams_seen += len(subj_batch)
# Example #6
# 0
def run(img_folder,
        dl_state,
        img_extension='dcm',
        img_height=1024,
        img_scale=4095,
        val_size=.2,
        neg_vs_pos_ratio=10.,
        do_featurewise_norm=True,
        featurewise_mean=873.6,
        featurewise_std=739.3,
        img_per_batch=2,
        roi_per_img=32,
        roi_size=(256, 256),
        low_int_threshold=.05,
        blob_min_area=3,
        blob_min_int=.5,
        blob_max_int=.85,
        blob_th_step=10,
        layer_name=['flatten_1', 'dense_1'],
        layer_index=None,
        roi_state=None,
        roi_clf_bs=32,
        pc_components=.95,
        pc_whiten=True,
        nb_words=[512],
        km_max_iter=100,
        km_bs=1000,
        km_patience=20,
        km_init=10,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        pca_km_states='./modelState/dlrepr_pca_km_models.pkl',
        bow_train_out='./modelState/bow_dat_train.pkl',
        bow_test_out='./modelState/bow_dat_test.pkl'):
    '''Calculate bag of deep visual words count matrix for all breasts
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an rng used across board.

    # Load and split image and label lists.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_train, subj_test, labs_train, labs_test = train_test_split(
        subj_list,
        subj_labs,
        test_size=val_size,
        stratify=subj_labs,
        random_state=random_seed)
    if neg_vs_pos_ratio is not None:

        def subset_subj(subj, labs):
            subj = np.array(subj)
            labs = np.array(labs)
            pos_idx = np.where(labs == 1)[0]
            neg_idx = np.where(labs == 0)[0]
            nb_neg_desired = int(len(pos_idx) * neg_vs_pos_ratio)
            if nb_neg_desired >= len(neg_idx):
                return subj.tolist()
            else:
                neg_chosen = rng.choice(neg_idx, nb_neg_desired, replace=False)
                subset_idx = np.concatenate([pos_idx, neg_chosen])
                return subj[subset_idx].tolist()

        subj_train = subset_subj(subj_train, labs_train)
        subj_test = subset_subj(subj_test, labs_test)

    img_list, lab_list = meta_man.get_flatten_img_list(subj_train)
    lab_list = np.array(lab_list)
    print "Train set - Nb of positive images: %d, Nb of negative images: %d" \
            % ( (lab_list==1).sum(), (lab_list==0).sum())
    sys.stdout.flush()

    # Create image generator for ROIs for representation extraction.
    print "Create an image generator for ROIs"
    sys.stdout.flush()
    if do_featurewise_norm:
        imgen = DMImageDataGenerator(featurewise_center=True,
                                     featurewise_std_normalization=True)
        imgen.mean = featurewise_mean
        imgen.std = featurewise_std
    else:
        imgen = DMImageDataGenerator(samplewise_center=True,
                                     samplewise_std_normalization=True)

    # Load ROI classifier.
    if roi_state is not None:
        print "Load ROI classifier"
        sys.stdout.flush()
        roi_clf = load_model(roi_state,
                             custom_objects={
                                 'sensitivity': dmm.sensitivity,
                                 'specificity': dmm.specificity
                             })
        graph = tf.get_default_graph()
    else:
        roi_clf = None
        graph = None

    # Create ROI generators for pos and neg images separately.
    print "Create ROI generators for pos and neg images"
    sys.stdout.flush()
    roi_generator = imgen.flow_from_candid_roi(
        img_list,
        target_height=img_height,
        target_scale=img_scale,
        class_mode=None,
        validation_mode=True,
        img_per_batch=img_per_batch,
        roi_per_img=roi_per_img,
        roi_size=roi_size,
        low_int_threshold=low_int_threshold,
        blob_min_area=blob_min_area,
        blob_min_int=blob_min_int,
        blob_max_int=blob_max_int,
        blob_th_step=blob_th_step,
        tf_graph=graph,
        roi_clf=roi_clf,
        clf_bs=roi_clf_bs,
        return_sample_weight=False,
        seed=random_seed)

    # Generate image patches and extract their DL representations.
    print "Load DL representation model"
    sys.stdout.flush()
    dlrepr_model = DLRepr(dl_state,
                          custom_objects={
                              'sensitivity': dmm.sensitivity,
                              'specificity': dmm.specificity
                          },
                          layer_name=layer_name,
                          layer_index=layer_index)
    last_output_size = dlrepr_model.get_output_shape()[-1][-1]
    if last_output_size != 3 and last_output_size != 1:
        raise Exception("The last output must be prob outputs (size=3 or 1)")

    nb_tot_samples = len(img_list) * roi_per_img
    print "Extract ROIs from pos and neg images"
    sys.stdout.flush()
    pred = dlrepr_model.predict_generator(roi_generator,
                                          val_samples=nb_tot_samples)
    for i, d in enumerate(pred):
        print "Shape of representation/output data %d:" % (i), d.shape
    sys.stdout.flush()

    # Flatten feature maps, e.g. an 8x8 feature map will become a 64-d vector.
    pred = [d.reshape((-1, d.shape[-1])) for d in pred]
    for i, d in enumerate(pred):
        print "Shape of flattened data %d:" % (i), d.shape
    sys.stdout.flush()

    # Split representations and prob outputs.
    dl_repr = pred[0]
    prob_out = pred[1]
    if prob_out.shape[1] == 3:
        prob_out = prob_out[:, 1]  # pos class.
    prob_out = prob_out.reshape((len(img_list), -1))
    print "Reshape prob output to:", prob_out.shape
    sys.stdout.flush()

    # Use PCA to reduce dimension of the representation data.
    if pc_components is not None:
        print "Start PCA dimension reduction on DL representation"
        sys.stdout.flush()
        pca = PCA(n_components=pc_components, whiten=pc_whiten)
        pca.fit(dl_repr)
        print "Nb of PCA components:", pca.n_components_
        print "Total explained variance ratio: %.4f" % \
                (pca.explained_variance_ratio_.sum())
        dl_repr_pca = pca.transform(dl_repr)
        print "Shape of transformed representation data:", dl_repr_pca.shape
        sys.stdout.flush()
    else:
        pca = None

    # Use K-means to create a codebook for deep visual words.
    print "Start K-means training on DL representation"
    sys.stdout.flush()
    clf_list = []
    clust_list = []
    # Shuffling indices for mini-batches learning.
    perm_idx = rng.permutation(len(dl_repr))
    for n in nb_words:
        print "Train K-means with %d cluster centers" % (n)
        sys.stdout.flush()
        clf = MiniBatchKMeans(n_clusters=n,
                              init='k-means++',
                              max_iter=km_max_iter,
                              batch_size=km_bs,
                              compute_labels=True,
                              random_state=random_seed,
                              tol=0.0,
                              max_no_improvement=km_patience,
                              init_size=None,
                              n_init=km_init,
                              reassignment_ratio=0.01,
                              verbose=0)
        clf.fit(dl_repr[perm_idx])
        clf_list.append(clf)
        clust = np.zeros_like(clf.labels_)
        clust[perm_idx] = clf.labels_
        clust = clust.reshape((len(img_list), -1))
        clust_list.append(clust)

    if pca is not None:
        print "Start K-means training on transformed representation"
        sys.stdout.flush()
        clf_list_pca = []
        clust_list_pca = []
        # Shuffling indices for mini-batches learning.
        perm_idx = rng.permutation(len(dl_repr_pca))
        for n in nb_words:
            print "Train K-means with %d cluster centers" % (n)
            sys.stdout.flush()
            clf = MiniBatchKMeans(n_clusters=n,
                                  init='k-means++',
                                  max_iter=km_max_iter,
                                  batch_size=km_bs,
                                  compute_labels=True,
                                  random_state=random_seed,
                                  tol=0.0,
                                  max_no_improvement=km_patience,
                                  init_size=None,
                                  n_init=km_init,
                                  reassignment_ratio=0.01,
                                  verbose=0)
            clf.fit(dl_repr_pca[perm_idx])
            clf_list_pca.append(clf)
            clust = np.zeros_like(clf.labels_)
            clust[perm_idx] = clf.labels_
            clust = clust.reshape((len(img_list), -1))
            clust_list_pca.append(clust)

    # Read exam lists.
    exam_train = meta_man.get_flatten_exam_list(subj_train,
                                                flatten_img_list=True)
    exam_test = meta_man.get_flatten_exam_list(subj_test,
                                               flatten_img_list=True)
    exam_labs_train = np.array(meta_man.exam_labs(exam_train))
    exam_labs_test = np.array(meta_man.exam_labs(exam_test))
    nb_pos_exams_train = (exam_labs_train == 1).sum()
    nb_neg_exams_train = (exam_labs_train == 0).sum()
    nb_pos_exams_test = (exam_labs_test == 1).sum()
    nb_neg_exams_test = (exam_labs_test == 0).sum()
    print "Train set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_train, nb_neg_exams_train)
    print "Test set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_test, nb_neg_exams_test)

    # Do BoW counts for each breast.
    print "BoW counting for train exam list"
    sys.stdout.flush()
    bow_dat_train = get_exam_bow_dat(exam_train,
                                     nb_words,
                                     roi_per_img,
                                     img_list=img_list,
                                     prob_out=prob_out,
                                     clust_list=clust_list)
    for i, d in enumerate(bow_dat_train[1]):
        print "Shape of train BoW matrix %d:" % (i), d.shape
    sys.stdout.flush()

    print "BoW counting for test exam list"
    sys.stdout.flush()
    bow_dat_test = get_exam_bow_dat(exam_test,
                                    nb_words,
                                    roi_per_img,
                                    imgen=imgen,
                                    clf_list=clf_list,
                                    transformer=None,
                                    target_height=img_height,
                                    target_scale=img_scale,
                                    img_per_batch=img_per_batch,
                                    roi_size=roi_size,
                                    low_int_threshold=low_int_threshold,
                                    blob_min_area=blob_min_area,
                                    blob_min_int=blob_min_int,
                                    blob_max_int=blob_max_int,
                                    blob_th_step=blob_th_step,
                                    seed=random_seed,
                                    dlrepr_model=dlrepr_model)
    for i, d in enumerate(bow_dat_test[1]):
        print "Shape of test BoW matrix %d:" % (i), d.shape
    sys.stdout.flush()

    if pca is not None:
        print "== Do same BoW counting on PCA transformed data =="
        print "BoW counting for train exam list"
        sys.stdout.flush()
        bow_dat_train_pca = get_exam_bow_dat(exam_train,
                                             nb_words,
                                             roi_per_img,
                                             img_list=img_list,
                                             prob_out=prob_out,
                                             clust_list=clust_list_pca)
        for i, d in enumerate(bow_dat_train_pca[1]):
            print "Shape of train BoW matrix %d:" % (i), d.shape
        sys.stdout.flush()

        print "BoW counting for test exam list"
        sys.stdout.flush()
        bow_dat_test_pca = get_exam_bow_dat(
            exam_test,
            nb_words,
            roi_per_img,
            imgen=imgen,
            clf_list=clf_list_pca,
            transformer=pca,
            target_height=img_height,
            target_scale=img_scale,
            img_per_batch=img_per_batch,
            roi_size=roi_size,
            low_int_threshold=low_int_threshold,
            blob_min_area=blob_min_area,
            blob_min_int=blob_min_int,
            blob_max_int=blob_max_int,
            blob_th_step=blob_th_step,
            seed=random_seed,
            dlrepr_model=dlrepr_model)
        for i, d in enumerate(bow_dat_test_pca[1]):
            print "Shape of test BoW matrix %d:" % (i), d.shape
        sys.stdout.flush()

    # Save K-means model and BoW count data.
    if pca is None:
        pickle.dump(clf_list, open(pca_km_states, 'w'))
        pickle.dump(bow_dat_train, open(bow_train_out, 'w'))
        pickle.dump(bow_dat_test, open(bow_test_out, 'w'))
    else:
        pickle.dump((pca, clf_list), open(pca_km_states, 'w'))
        pickle.dump((bow_dat_train, bow_dat_train_pca),
                    open(bow_train_out, 'w'))
        pickle.dump((bow_dat_test, bow_dat_test_pca), open(bow_test_out, 'w'))

    print "Done."
def run(train_dir, val_dir, test_dir,
        img_size=[256, 256], img_scale=None, rescale_factor=None,
        featurewise_center=True, featurewise_mean=59.6,
        equalize_hist=True, augmentation=False,
        class_list=['background', 'malignant', 'benign'],
        batch_size=64, train_bs_multiplier=.5, nb_epoch=5,
        top_layer_epochs=10, all_layer_epochs=20,
        load_val_ram=False, load_train_ram=False,
        net='resnet50', use_pretrained=True,
        nb_init_filter=32, init_filter_size=5, init_conv_stride=2,
        pool_size=2, pool_stride=2,
        weight_decay=.0001, weight_decay2=.0001,
        alpha=.0001, l1_ratio=.0,
        inp_dropout=.0, hidden_dropout=.0, hidden_dropout2=.0,
        optim='sgd', init_lr=.01, lr_patience=10, es_patience=25,
        resume_from=None, auto_batch_balance=False,
        pos_cls_weight=1.0, neg_cls_weight=1.0,
        top_layer_nb=None, top_layer_multiplier=.1, all_layer_multiplier=.01,
        best_model='./modelState/patch_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for patch classifications

    Builds image generators for the train/val/test directories, creates the
    network with ``get_dl_model``, trains it with ``do_3stage_training``,
    optionally saves the final model, then reloads the weights stored at
    ``best_model`` and evaluates on the test directory.

    Args:
        train_dir, val_dir, test_dir: directories read with
            ``flow_from_directory``; presumably one sub-folder per entry of
            ``class_list`` (TODO confirm against DMImageDataGenerator).
        img_size, img_scale, rescale_factor, equalize_hist: image loading
            options forwarded to every ``flow_from_directory`` call
            (``img_size`` is also passed to ``get_dl_model``).
        featurewise_center: if True, all three generators are built with
            ``featurewise_center=True`` and their ``mean`` is set to the
            fixed ``featurewise_mean``; the network's own
            ``preprocess_input`` is then disabled.
        augmentation: if True, random flips, rotation, shear, zoom and
            channel shift are enabled on the training generator only.
        class_list: class names; ``len(class_list)`` is the number of
            output classes. NOTE: list defaults are shared between calls;
            this function does not mutate them.
        batch_size: batch size for validation and test; training uses
            ``int(batch_size * train_bs_multiplier)``.
        load_train_ram, load_val_ram: read the whole train / val set into
            memory with ``load_dat_ram`` before training.
        net: network name for ``get_dl_model``; every net except
            'yaroslav' receives its input duplicated into 3 channels.
        best_model: weights file loaded into the model before the test
            evaluation (presumably written by ``do_3stage_training`` as its
            best checkpoint -- TODO confirm).
        final_model: path passed to ``model.save`` after training, or
            "NOSAVE" to skip saving.
        Remaining arguments are forwarded unchanged to ``get_dl_model``
        and/or ``do_3stage_training``.

    Environment variables:
        RANDOM_SEED (default 12345), NUM_CPU_CORES (default 4, used as the
        worker count), NUM_GPU_DEVICES (default 1; > 1 wraps the model with
        ``make_parallel``).
    '''
    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    # Feature-wise centering uses one fixed, pre-computed mean for every
    # split instead of fitting it on the data.
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options (training generator only).
    if augmentation:
        train_imgen.horizontal_flip = True  # random horizontal flips.
        train_imgen.vertical_flip = True  # random vertical flips.
        train_imgen.rotation_range = 25.  # in degrees.
        train_imgen.shear_range = .2  # in radians.
        # [lower, upper] zoom factors (a scalar z would mean [1 - z, 1 + z]).
        train_imgen.zoom_range = [.8, 1.2]
        # Random offset applied to pixel intensity values.
        train_imgen.channel_shift_range = 20.

    # ================= Model creation ============== #
    # weight_decay and the dropout rates are regularization settings;
    # alpha / l1_ratio are forwarded to get_dl_model (presumably the
    # strength and L1/L2 mix of an elastic-net penalty -- TODO confirm).
    model, preprocess_input, top_layer_nb = get_dl_model(
        net, nb_class=len(class_list), use_pretrained=use_pretrained,
        resume_from=resume_from, img_size=img_size, top_layer_nb=top_layer_nb,
        weight_decay=weight_decay, hidden_dropout=hidden_dropout,
        nb_init_filter=nb_init_filter, init_filter_size=init_filter_size,
        init_conv_stride=init_conv_stride, pool_size=pool_size,
        pool_stride=pool_stride, alpha=alpha, l1_ratio=l1_ratio,
        inp_dropout=inp_dropout)
    # The fixed feature-wise mean replaces the network's own preprocessing.
    if featurewise_center:
        preprocess_input = None
    if gpu_count > 1:
        # Multi-GPU: `model` becomes the parallel wrapper and `org_model`
        # keeps the original single-device model (weights are loaded into
        # it before testing).
        model, org_model = make_parallel(model, gpu_count)
    else:
        org_model = model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    dup_3_channels = net != 'yaroslav'
    if load_train_ram:
        # Read the raw train set once (no augmentation, no shuffling), then
        # serve shuffled / augmented batches from the in-memory arrays.
        raw_imgen = DMImageDataGenerator()
        print("Create generator for raw train set")
        raw_generator = raw_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical',
            batch_size=train_bs, shuffle=False)
        sys.stdout.write("Loading raw train set into RAM. ")
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print("Done.")
        sys.stdout.flush()
        print("Create generator for train set")
        # raw_set[0] / raw_set[1] are presumably (images, labels).
        train_generator = train_imgen.flow(
            raw_set[0], raw_set[1], batch_size=train_bs,
            auto_batch_balance=auto_batch_balance, preprocess=preprocess_input,
            shuffle=True, seed=random_seed)
    else:
        print("Create generator for train set")
        train_generator = train_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical',
            auto_batch_balance=auto_batch_balance, batch_size=train_bs,
            preprocess=preprocess_input, shuffle=True, seed=random_seed)

    print("Create generator for val set")
    validation_set = val_imgen.flow_from_directory(
        val_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical',
        batch_size=batch_size, preprocess=preprocess_input, shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        sys.stdout.write("Loading validation set into RAM. ")
        sys.stdout.flush()
        # Replaces the generator with the in-memory data returned by
        # load_dat_ram (handled by the isinstance(..., tuple) check below).
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print("Done.")
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 3-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    if isinstance(validation_set, tuple):  # loaded into RAM.
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    # Round up so the trailing partial batch is validated as well; floor
    # division skipped it and yielded 0 steps for val_samples < batch_size.
    validation_steps = (val_samples + batch_size - 1) // batch_size
    model, loss_hist, acc_hist = do_3stage_training(
        model, org_model, train_generator, validation_set, validation_steps,
        best_model, train_batches, top_layer_nb, net, nb_epoch=nb_epoch,
        top_layer_epochs=top_layer_epochs, all_layer_epochs=all_layer_epochs,
        use_pretrained=use_pretrained, optim=optim, init_lr=init_lr,
        top_layer_multiplier=top_layer_multiplier,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience, lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance, nb_class=len(class_list),
        pos_cls_weight=pos_cls_weight, neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker, weight_decay2=weight_decay2,
        hidden_dropout2=hidden_dropout2)

    # Training report.
    if len(loss_hist) > 0:
        # First epoch reaching the minimum val loss; np.argmin works whether
        # the history is a list or a numpy array.
        best_idx = int(np.argmin(loss_hist))
        best_val_loss = loss_hist[best_idx]
        best_val_accuracy = acc_hist[best_idx]
        print("\n==== Training summary ====")
        print("Minimum val loss achieved at epoch: %d" % (best_idx + 1))
        print("Best val loss: %s" % (best_val_loss,))
        print("Best val accuracy: %s" % (best_val_accuracy,))

    # Save the final model.
    if final_model != "NOSAVE":
        model.save(final_model)

    # ==== Predict on test set ==== #
    print("\n==== Predicting on test set ====")
    test_generator = test_imgen.flow_from_directory(
        test_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical', batch_size=batch_size,
        preprocess=preprocess_input, shuffle=False)
    print("Test samples = %d" % test_generator.nb_sample)
    sys.stdout.write("Load saved best model: %s. " % best_model)
    sys.stdout.flush()
    # Load the best weights into the original (single-device) model.
    org_model.load_weights(best_model)
    print("Done.")
    # Round up so every test sample is evaluated exactly once.
    test_steps = (test_generator.nb_sample + batch_size - 1) // batch_size
    test_res = model.evaluate_generator(
        test_generator, test_steps, nb_worker=nb_worker,
        pickle_safe=nb_worker > 1)
    print("Evaluation result on test set: %s" % (test_res,))
def run(train_dir, val_dir, test_dir,
        img_size=[256, 256], img_scale=None, rescale_factor=None,
        featurewise_center=True, featurewise_mean=59.6,
        equalize_hist=True, augmentation=False,
        class_list=['background', 'malignant', 'benign'],
        batch_size=64, train_bs_multiplier=.5, nb_epoch=5,
        top_layer_epochs=10, all_layer_epochs=20,
        load_val_ram=False, load_train_ram=False,
        net='resnet50', use_pretrained=True,
        nb_init_filter=32, init_filter_size=5, init_conv_stride=2,
        pool_size=2, pool_stride=2,
        weight_decay=.0001, weight_decay2=.0001,
        alpha=.0001, l1_ratio=.0,
        inp_dropout=.0, hidden_dropout=.0, hidden_dropout2=.0,
        optim='sgd', init_lr=.01, lr_patience=10, es_patience=25,
        resume_from=None, auto_batch_balance=False,
        pos_cls_weight=1.0, neg_cls_weight=1.0,
        top_layer_nb=None, top_layer_multiplier=.1, all_layer_multiplier=.01,
        best_model='./modelState/patch_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for patch classifications

    Builds image generators for the train/val/test directories, creates the
    network with ``get_dl_model``, trains it with ``do_3stage_training``,
    optionally saves the final model, then reloads the weights stored at
    ``best_model`` and evaluates on the test directory.

    Args:
        train_dir, val_dir, test_dir: directories read with
            ``flow_from_directory``; presumably one sub-folder per entry of
            ``class_list`` (TODO confirm against DMImageDataGenerator).
        featurewise_center: if True, all generators use the fixed
            ``featurewise_mean`` and the network's ``preprocess_input`` is
            disabled.
        augmentation: enable random flips/rotation/shear/zoom/channel shift
            on the training generator only.
        batch_size: validation/test batch size; training uses
            ``int(batch_size * train_bs_multiplier)``.
        load_train_ram, load_val_ram: read the train / val set into memory
            with ``load_dat_ram`` first.
        net: network name; every net except 'yaroslav' gets its input
            duplicated into 3 channels.
        best_model: weights file loaded before the test evaluation
            (presumably the best checkpoint written during training).
        final_model: ``model.save`` path, or "NOSAVE" to skip saving.
        Remaining arguments are forwarded to ``get_dl_model`` and/or
        ``do_3stage_training``. List defaults are not mutated here.

    Environment variables:
        RANDOM_SEED (default 12345), NUM_CPU_CORES (default 4, worker
        count), NUM_GPU_DEVICES (default 1; > 1 uses ``make_parallel``).
    '''

    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options (training generator only).
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    model, preprocess_input, top_layer_nb = get_dl_model(
        net, nb_class=len(class_list), use_pretrained=use_pretrained,
        resume_from=resume_from, img_size=img_size, top_layer_nb=top_layer_nb,
        weight_decay=weight_decay, hidden_dropout=hidden_dropout,
        nb_init_filter=nb_init_filter, init_filter_size=init_filter_size,
        init_conv_stride=init_conv_stride, pool_size=pool_size,
        pool_stride=pool_stride, alpha=alpha, l1_ratio=l1_ratio,
        inp_dropout=inp_dropout)
    # The fixed feature-wise mean replaces the network's own preprocessing.
    if featurewise_center:
        preprocess_input = None
    if gpu_count > 1:
        # `org_model` keeps the single-device model; weights are loaded
        # into it before testing.
        model, org_model = make_parallel(model, gpu_count)
    else:
        org_model = model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    dup_3_channels = net != 'yaroslav'
    if load_train_ram:
        # Read the raw (unaugmented, unshuffled) train set into RAM once,
        # then serve shuffled / augmented batches from memory.
        raw_imgen = DMImageDataGenerator()
        print("Create generator for raw train set")
        raw_generator = raw_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical',
            batch_size=train_bs, shuffle=False)
        sys.stdout.write("Loading raw train set into RAM. ")
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print("Done.")
        sys.stdout.flush()
        print("Create generator for train set")
        train_generator = train_imgen.flow(
            raw_set[0], raw_set[1], batch_size=train_bs,
            auto_batch_balance=auto_batch_balance, preprocess=preprocess_input,
            shuffle=True, seed=random_seed)
    else:
        print("Create generator for train set")
        train_generator = train_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical',
            auto_batch_balance=auto_batch_balance, batch_size=train_bs,
            preprocess=preprocess_input, shuffle=True, seed=random_seed)

    print("Create generator for val set")
    validation_set = val_imgen.flow_from_directory(
        val_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical',
        batch_size=batch_size, preprocess=preprocess_input, shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        sys.stdout.write("Loading validation set into RAM. ")
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print("Done.")
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 3-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    if isinstance(validation_set, tuple):  # loaded into RAM.
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    # Round up so the trailing partial batch is validated as well; floor
    # division skipped it and yielded 0 steps for val_samples < batch_size.
    validation_steps = (val_samples + batch_size - 1) // batch_size
    model, loss_hist, acc_hist = do_3stage_training(
        model, org_model, train_generator, validation_set, validation_steps,
        best_model, train_batches, top_layer_nb, net, nb_epoch=nb_epoch,
        top_layer_epochs=top_layer_epochs, all_layer_epochs=all_layer_epochs,
        use_pretrained=use_pretrained, optim=optim, init_lr=init_lr,
        top_layer_multiplier=top_layer_multiplier,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience, lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance, nb_class=len(class_list),
        pos_cls_weight=pos_cls_weight, neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker, weight_decay2=weight_decay2,
        hidden_dropout2=hidden_dropout2)

    # Training report.
    if len(loss_hist) > 0:
        # First epoch reaching the minimum val loss; np.argmin accepts both
        # lists and numpy arrays.
        best_idx = int(np.argmin(loss_hist))
        best_val_loss = loss_hist[best_idx]
        best_val_accuracy = acc_hist[best_idx]
        print("\n==== Training summary ====")
        print("Minimum val loss achieved at epoch: %d" % (best_idx + 1))
        print("Best val loss: %s" % (best_val_loss,))
        print("Best val accuracy: %s" % (best_val_accuracy,))

    if final_model != "NOSAVE":
        model.save(final_model)

    # ==== Predict on test set ==== #
    print("\n==== Predicting on test set ====")
    test_generator = test_imgen.flow_from_directory(
        test_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical', batch_size=batch_size,
        preprocess=preprocess_input, shuffle=False)
    print("Test samples = %d" % test_generator.nb_sample)
    sys.stdout.write("Load saved best model: %s. " % best_model)
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print("Done.")
    # Round up so every test sample is evaluated exactly once.
    test_steps = (test_generator.nb_sample + batch_size - 1) // batch_size
    test_res = model.evaluate_generator(
        test_generator, test_steps, nb_worker=nb_worker,
        pickle_safe=nb_worker > 1)
    print("Evaluation result on test set: %s" % (test_res,))
def run(img_folder, img_extension='dcm', 
        img_height=1024, img_scale=4095, 
        do_featurewise_norm=True, norm_fit_size=10,
        img_per_batch=2, roi_per_img=32, roi_size=(256, 256), 
        one_patch_mode=False,
        low_int_threshold=.05, blob_min_area=3, 
        blob_min_int=.5, blob_max_int=.85, blob_th_step=10,
        data_augmentation=False, roi_state=None, clf_bs=32, cutpoint=.5,
        amp_factor=1., return_sample_weight=True, auto_batch_balance=True,
        patches_per_epoch=12800, nb_epoch=20, 
        neg_vs_pos_ratio=None, all_neg_skip=0., 
        nb_init_filter=32, init_filter_size=5, init_conv_stride=2, 
        pool_size=2, pool_stride=2, 
        weight_decay=.0001, alpha=.0001, l1_ratio=.0, 
        inp_dropout=.0, hidden_dropout=.0, init_lr=.01,
        test_size=.2, val_size=.0, 
        lr_patience=3, es_patience=10, 
        resume_from=None, net='resnet50', load_val_ram=False, 
        load_train_ram=False, no_pos_skip=0., balance_classes=0.,
        pred_img_per_batch=1, pred_roi_per_img=32,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        best_model='./modelState/dm_candidROI_best_model.h5',
        final_model="NOSAVE",
        pred_trainval=False, pred_out="dl_pred_out.pkl"):
    '''Run ResNet training on candidate ROIs from mammograms
    Args:
        norm_fit_size ([int]): the number of patients used to calculate 
                feature-wise mean and std.
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    # Use of multiple CPU cores is not working!
    # When nb_worker>1 and pickle_safe=True, this error is encountered:
    # "failed to enqueue async memcpy from host to device: CUDA_ERROR_NOT_INITIALIZED"
    # To avoid the error, only this combination worked: 
    # nb_worker=1 and pickle_safe=False.
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))
    
    # Setup training and validation data.
    # Load image or exam lists and split them into train and val sets.
    meta_man = DMMetaManager(exam_tsv=exam_tsv, 
                             img_tsv=img_tsv, 
                             img_folder=img_folder, 
                             img_extension=img_extension)
    # Split data based on subjects.
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_train, subj_test, slab_train, slab_test = train_test_split(
        subj_list, subj_labs, test_size=test_size, random_state=random_seed, 
        stratify=subj_labs)
    if val_size > 0:  # train/val split.
        subj_train, subj_val, slab_train, slab_val = train_test_split(
            subj_train, slab_train, test_size=val_size, 
            random_state=random_seed, stratify=slab_train)
    else:  # use test as val. make a copy of the test list.
        subj_val = list(subj_test)
        slab_val = list(slab_test)
    # import pdb; pdb.set_trace()
    # Subset subject lists to desired ratio.
    if neg_vs_pos_ratio is not None:
        subj_train, slab_train = DMMetaManager.subset_subj_list(
            subj_train, slab_train, neg_vs_pos_ratio, random_seed)
        subj_val, slab_val = DMMetaManager.subset_subj_list(
            subj_val, slab_val, neg_vs_pos_ratio, random_seed)
    print "After sampling, Nb of subjects for train=%d, val=%d, test=%d" \
            % (len(subj_train), len(subj_val), len(subj_test))
    # Get image and label lists.
    img_train, lab_train = meta_man.get_flatten_img_list(subj_train)
    img_val, lab_val = meta_man.get_flatten_img_list(subj_val)

    # Create image generators for train, fit and val.
    imgen_trainval = DMImageDataGenerator()
    if data_augmentation:
        imgen_trainval.horizontal_flip=True 
        imgen_trainval.vertical_flip=True
        imgen_trainval.rotation_range = 45.
        imgen_trainval.shear_range = np.pi/8.
        # imgen_trainval.width_shift_range = .05
        # imgen_trainval.height_shift_range = .05
        # imgen_trainval.zoom_range = [.95, 1.05]

    if do_featurewise_norm:
        imgen_trainval.featurewise_center = True
        imgen_trainval.featurewise_std_normalization = True
        # Fit feature-wise mean and std.
        img_fit,_ = meta_man.get_flatten_img_list(
            subj_train[:norm_fit_size])  # fit on a subset.
        print ">>> Fit image generator <<<"; sys.stdout.flush()
        fit_generator = imgen_trainval.flow_from_candid_roi(
            img_fit,
            target_height=img_height, target_scale=img_scale,
            class_mode=None, validation_mode=True, 
            img_per_batch=len(img_fit), roi_per_img=roi_per_img, 
            roi_size=roi_size,
            low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
            blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
            blob_th_step=blob_th_step,
            roi_clf=None, return_sample_weight=False, seed=random_seed)
        imgen_trainval.fit(fit_generator.next())
        print "Estimates from %d images: mean=%.1f, std=%.1f." % \
            (len(img_fit), imgen_trainval.mean, imgen_trainval.std)
        sys.stdout.flush()
    else:
        imgen_trainval.samplewise_center = True
        imgen_trainval.samplewise_std_normalization = True

    # Load ROI classifier.
    if roi_state is not None:
        roi_clf = load_model(
            roi_state, 
            custom_objects={
                'sensitivity': DMMetrics.sensitivity, 
                'specificity': DMMetrics.specificity
            }
        )
        graph = tf.get_default_graph()
    else:
        roi_clf = None
        graph = None

    # Set some DL training related parameters.
    if one_patch_mode:
        class_mode = 'binary'
        loss = 'binary_crossentropy'
        metrics = [DMMetrics.sensitivity, DMMetrics.specificity]
    else:
        class_mode = 'categorical'
        loss = 'categorical_crossentropy'
        metrics = ['accuracy', 'precision', 'recall']
    if load_train_ram:
        validation_mode = True
        return_raw_img = True
    else:
        validation_mode = False
        return_raw_img = False

    # Create train and val generators.
    print ">>> Train image generator <<<"; sys.stdout.flush()
    train_generator = imgen_trainval.flow_from_candid_roi(
        img_train, lab_train, 
        target_height=img_height, target_scale=img_scale,
        class_mode=class_mode, validation_mode=validation_mode, 
        img_per_batch=img_per_batch, roi_per_img=roi_per_img, 
        roi_size=roi_size, one_patch_mode=one_patch_mode,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step,
        tf_graph=graph, roi_clf=roi_clf, clf_bs=clf_bs, cutpoint=cutpoint,
        amp_factor=amp_factor, return_sample_weight=return_sample_weight,
        auto_batch_balance=auto_batch_balance,
        all_neg_skip=all_neg_skip, shuffle=True, seed=random_seed,
        return_raw_img=return_raw_img)

    print ">>> Validation image generator <<<"; sys.stdout.flush()
    val_generator = imgen_trainval.flow_from_candid_roi(
        img_val, lab_val, 
        target_height=img_height, target_scale=img_scale,
        class_mode=class_mode, validation_mode=True, 
        img_per_batch=img_per_batch, roi_per_img=roi_per_img, 
        roi_size=roi_size, one_patch_mode=one_patch_mode,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step,
        tf_graph=graph, roi_clf=roi_clf, clf_bs=clf_bs, cutpoint=cutpoint,
        amp_factor=amp_factor, return_sample_weight=False, 
        auto_batch_balance=False,
        seed=random_seed)

    # Load train and validation set into RAM.
    if one_patch_mode:
        nb_train_samples = len(img_train)
        nb_val_samples = len(img_val)
    else:
        nb_train_samples = len(img_train)*roi_per_img
        nb_val_samples = len(img_val)*roi_per_img
    if load_val_ram:
        print "Loading validation data into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(val_generator, nb_val_samples)
        print "Done."; sys.stdout.flush()
        sparse_y = to_sparse(validation_set[1])
        for uy in np.unique(sparse_y):
            print "Nb of samples for class:%d = %d" % \
                    (uy, (sparse_y==uy).sum())
        sys.stdout.flush()
    if load_train_ram:
        print "Loading train data into RAM.",
        sys.stdout.flush()
        train_set = load_dat_ram(train_generator, nb_train_samples)
        print "Done."; sys.stdout.flush()
        sparse_y = to_sparse(train_set[1])
        for uy in np.unique(sparse_y):
            print "Nb of samples for class:%d = %d" % \
                    (uy, (sparse_y==uy).sum())
        sys.stdout.flush()
        train_generator = imgen_trainval.flow(
            train_set[0], train_set[1], batch_size=clf_bs, 
            auto_batch_balance=auto_batch_balance, no_pos_skip=no_pos_skip,
            balance_classes=balance_classes, shuffle=True, seed=random_seed)

    # Load or create model.
    if resume_from is not None:
        model = load_model(
            resume_from,
            custom_objects={
                'sensitivity': DMMetrics.sensitivity, 
                'specificity': DMMetrics.specificity
            }
        )
    else:
        builder = ResNetBuilder
        if net == 'resnet18':
            model = builder.build_resnet_18(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet34':
            model = builder.build_resnet_34(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet50':
            model = builder.build_resnet_50(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet101':
            model = builder.build_resnet_101(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet152':
            model = builder.build_resnet_152(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
    
    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    # Model training.
    sgd = SGD(lr=init_lr, momentum=0.9, decay=0.0, nesterov=True)
    model.compile(optimizer=sgd, loss=loss, metrics=metrics)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 
                                  patience=lr_patience, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=es_patience, 
                                   verbose=1)
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, batch_size=clf_bs)
    else:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, val_generator, nb_test_samples=nb_val_samples)
    hist = model.fit_generator(
        train_generator, 
        samples_per_epoch=patches_per_epoch, 
        nb_epoch=nb_epoch,
        validation_data=validation_set if load_val_ram else val_generator, 
        nb_val_samples=nb_val_samples, 
        callbacks=[reduce_lr, early_stopping, auc_checkpointer],
        # nb_worker=1, pickle_safe=False,
        nb_worker=nb_worker if load_train_ram else 1,
        pickle_safe=True if load_train_ram else False,
        verbose=2)

    if final_model != "NOSAVE":
        print "Saving final model to:", final_model; sys.stdout.flush()
        model.save(final_model)
    
    # Training report.
    min_loss_locs, = np.where(hist.history['val_loss'] == min(hist.history['val_loss']))
    best_val_loss = hist.history['val_loss'][min_loss_locs[0]]
    if one_patch_mode:
        best_val_sensitivity = hist.history['val_sensitivity'][min_loss_locs[0]]
        best_val_specificity = hist.history['val_specificity'][min_loss_locs[0]]
    else:
        best_val_precision = hist.history['val_precision'][min_loss_locs[0]]
        best_val_recall = hist.history['val_recall'][min_loss_locs[0]]
        best_val_accuracy = hist.history['val_acc'][min_loss_locs[0]]
    print "\n==== Training summary ===="
    print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
    print "Best val loss:", best_val_loss
    if one_patch_mode:
        print "Best val sensitivity:", best_val_sensitivity
        print "Best val specificity:", best_val_specificity
    else:
        print "Best val precision:", best_val_precision
        print "Best val recall:", best_val_recall
        print "Best val accuracy:", best_val_accuracy

    # Make predictions on train, val, test exam lists.
    if best_model != 'NOSAVE':
        print "\n==== Making predictions ===="
        print "Load best model for prediction:", best_model
        sys.stdout.flush()
        pred_model = load_model(best_model)
        if gpu_count > 1:
            pred_model = make_parallel(pred_model, gpu_count)
        
        if pred_trainval:
            print "Load exam lists for train, val sets"; sys.stdout.flush()
            exam_train = meta_man.get_flatten_exam_list(
                subj_train, flatten_img_list=True)
            print "Train exam list length=", len(exam_train); sys.stdout.flush()
            exam_val = meta_man.get_flatten_exam_list(
                subj_val, flatten_img_list=True)
            print "Val exam list length=", len(exam_val); sys.stdout.flush()
        print "Load exam list for test set"; sys.stdout.flush()
        exam_test = meta_man.get_flatten_exam_list(
            subj_test, flatten_img_list=True)
        print "Test exam list length=", len(exam_test); sys.stdout.flush()
        
        if do_featurewise_norm:
            imgen_pred = DMImageDataGenerator()
            imgen_pred.featurewise_center = True
            imgen_pred.featurewise_std_normalization = True
            imgen_pred.mean = imgen_trainval.mean
            imgen_pred.std = imgen_trainval.std
        else:
            imgen_pred.samplewise_center = True
            imgen_pred.samplewise_std_normalization = True
        
        if pred_trainval:
            print "Make predictions on train exam list"; sys.stdout.flush()
            meta_prob_train = get_exam_pred(
                exam_train, pred_roi_per_img, imgen_pred, 
                target_height=img_height, target_scale=img_scale,
                img_per_batch=pred_img_per_batch, roi_size=roi_size,
                low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
                blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
                blob_th_step=blob_th_step, seed=random_seed, 
                dl_model=pred_model)
            print "Train prediction list length=", len(meta_prob_train)
            
            print "Make predictions on val exam list"; sys.stdout.flush()
            meta_prob_val = get_exam_pred(
                exam_val, pred_roi_per_img, imgen_pred, 
                target_height=img_height, target_scale=img_scale,
                img_per_batch=pred_img_per_batch, roi_size=roi_size,
                low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
                blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
                blob_th_step=blob_th_step, seed=random_seed, 
                dl_model=pred_model)
            print "Val prediction list length=", len(meta_prob_val)
        
        print "Make predictions on test exam list"; sys.stdout.flush()
        meta_prob_test = get_exam_pred(
            exam_test, pred_roi_per_img, imgen_pred, 
            target_height=img_height, target_scale=img_scale,
            img_per_batch=pred_img_per_batch, roi_size=roi_size,
            low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
            blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
            blob_th_step=blob_th_step, seed=random_seed, 
            dl_model=pred_model)
        print "Test prediction list length=", len(meta_prob_test)
        
        if pred_trainval:
            pickle.dump((meta_prob_train, meta_prob_val, meta_prob_test), 
                        open(pred_out, 'w'))
        else:
            pickle.dump(meta_prob_test, open(pred_out, 'w'))

    return hist
Example #10
0
def run(img_folder,
        img_size=(288, 224),
        do_featurewise_norm=True,
        featurewise_mean=485.9,
        featurewise_std=765.2,
        img_tsv='./metadata/images_crosswalk.tsv',
        exam_tsv='./metadata/exams_metadata.tsv',
        dl_state=None,
        enet_state=None,
        xgb_state=None,
        validation_mode=False,
        use_mean=False,
        out_pred='./output/predictions.tsv'):
    '''Run SC2 inference

    For every subject, take the last two exams (the current one and, when
    present, the prior one). Score each breast of each exam with an image
    model. Derive per-breast features: the current score, the prior score and
    the score change annualized by ``daysSincePreviousExam``. Merge them into
    the exam meta info and feed them to an XGBoost classifier. Two lines
    (L and R breast) per subject are written to ``out_pred``.

    Args:
        img_folder (str): folder holding the ``dcm`` images.
        img_size ((int, int)): target (height, width) of the loaded images.
        do_featurewise_norm (bool): normalize with the fixed
                ``featurewise_mean``/``featurewise_std``; otherwise use
                samplewise centering and std normalization.
        featurewise_mean, featurewise_std ([float]): they are estimated from
                1152 x 896 images. Using different sized images give very close
                results. For png, mean=7772, std=12187.
        img_tsv, exam_tsv (str): metadata tables read by ``DMMetaManager``.
        dl_state (str): saved model loaded with ``load_model``; used only when
                ``enet_state`` is None.
        enet_state (sequence): positional arguments for
                ``MultiViewDLElasticNet``; takes precedence over ``dl_state``.
        xgb_state (str): path of the pickled XGBoost model (required).
        validation_mode (bool): also fetch the labels and write the exam index
                and label next to each prediction.
        use_mean (bool): forwarded to ``dminfer.pred_2view_img_list``.
        out_pred (str): output TSV path.
    Raises:
        Exception: if ``xgb_state`` is None, or if neither ``enet_state`` nor
                ``dl_state`` is given.
    '''
    # Fail fast, before the costly setup below: open(None) would otherwise
    # raise an unhelpful TypeError only after the image model is loaded.
    if xgb_state is None:
        raise Exception('An XGB model state must be specified.')

    # Setup data generator for inference.
    meta_man = DMMetaManager(img_tsv=img_tsv,
                             exam_tsv=exam_tsv,
                             img_folder=img_folder,
                             img_extension='dcm')
    last2_exgen = meta_man.last_2_exam_generator()
    if do_featurewise_norm:
        img_gen = DMImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True)
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen = DMImageDataGenerator(samplewise_center=True,
                                       samplewise_std_normalization=True)
    # Labels are only requested (and later written out) in validation mode.
    class_mode = 'binary' if validation_mode else None

    # Image prediction model; the elastic-net state takes precedence.
    if enet_state is not None:
        model = MultiViewDLElasticNet(*enet_state)
    elif dl_state is not None:
        model = load_model(dl_state)
    else:
        raise Exception('At least one image model state must be specified.')

    # XGB model. NOTE: unpickling can execute arbitrary code; only load model
    # files from a trusted source. Binary mode avoids newline translation of
    # the pickle, and the context manager closes the handle.
    with open(xgb_state, 'rb') as xgb_file:
        xgb_clf = pickle.load(xgb_file)

    # The context manager guarantees the output file is closed even if a
    # prediction fails midway.
    with open(out_pred, 'w') as fout:
        # Print header.
        if validation_mode:
            fout.write(dminfer.INFER_HEADER_VAL)
        else:
            fout.write(dminfer.INFER_HEADER)

        # Loop through all last 2 exam pairs.
        for subj_id, curr_idx, curr_dat, prior_idx, prior_dat in last2_exgen:
            has_prior = prior_idx is not None

            # Get meta info for both breasts.
            left_record, right_record = meta_man.get_info_exam_pair(
                curr_dat, prior_dat)
            nb_days = left_record['daysSincePreviousExam']

            # Current exam first, then the prior exam when it exists.
            exam_list = [
                (subj_id, curr_idx, meta_man.get_info_per_exam(curr_dat))]
            if has_prior:
                exam_list.append(
                    (subj_id, prior_idx, meta_man.get_info_per_exam(prior_dat)))
            datgen_exam = img_gen.flow_from_exam_list(
                exam_list,
                target_size=(img_size[0], img_size[1]),
                class_mode=class_mode,
                prediction_mode=True,
                batch_size=len(exam_list),
                verbose=False)
            ebat = next(datgen_exam)
            if class_mode is not None:
                bat_x, bat_y = ebat[0], ebat[1]
            else:
                bat_x = ebat

            # bat_x[2] / bat_x[3] are used as the CC / MLO view batches; from
            # the indexing below, their entries are ordered
            # [curr left, curr right, prior left, prior right].
            cc_batch = bat_x[2]
            mlo_batch = bat_x[3]
            curr_left_score = dminfer.pred_2view_img_list(
                cc_batch[0], mlo_batch[0], model, use_mean)
            curr_right_score = dminfer.pred_2view_img_list(
                cc_batch[1], mlo_batch[1], model, use_mean)
            if has_prior:
                prior_left_score = dminfer.pred_2view_img_list(
                    cc_batch[2], mlo_batch[2], model, use_mean)
                prior_right_score = dminfer.pred_2view_img_list(
                    cc_batch[3], mlo_batch[3], model, use_mean)
                # Score change between the two exams, scaled to a per-year rate.
                diff_left_score = (curr_left_score -
                                   prior_left_score) / nb_days * 365
                diff_right_score = (curr_right_score -
                                    prior_right_score) / nb_days * 365
            else:
                prior_left_score = prior_right_score = np.nan
                diff_left_score = diff_right_score = np.nan

            # Merge image scores into meta info. The chained assigns are kept
            # on purpose: they fix the column order curr, prior, diff seen by
            # the XGB model (a single multi-kwarg assign may sort the columns
            # on older Python/pandas).
            left_record = left_record\
                    .assign(curr_score=curr_left_score)\
                    .assign(prior_score=prior_left_score)\
                    .assign(diff_score=diff_left_score)
            right_record = right_record\
                    .assign(curr_score=curr_right_score)\
                    .assign(prior_score=prior_right_score)\
                    .assign(diff_score=diff_right_score)
            dsubj = xgb.DMatrix(
                pd.concat([left_record, right_record], ignore_index=True))

            # Predict using XGB: row 0 is the left breast, row 1 the right.
            pred = xgb_clf.predict(dsubj, ntree_limit=xgb_clf.best_ntree_limit)

            # Output.
            if validation_mode:
                fout.write("%s\t%s\tL\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[0], bat_y[0]))
                fout.write("%s\t%s\tR\t%f\t%f\n" %
                           (str(subj_id), str(curr_idx), pred[1], bat_y[1]))
            else:
                fout.write("%s\tL\t%f\n" % (str(subj_id), pred[0]))
                fout.write("%s\tR\t%f\n" % (str(subj_id), pred[1]))
Example #11
0
def run(img_folder,
        dl_state,
        best_model,
        img_extension='dcm',
        img_height=1024,
        img_scale=255.,
        equalize_hist=False,
        featurewise_center=False,
        featurewise_mean=91.6,
        neg_vs_pos_ratio=1.,
        val_size=.1,
        test_size=.15,
        net='vgg19',
        batch_size=128,
        train_bs_multiplier=.5,
        patch_size=256,
        stride=8,
        roi_cutoff=.9,
        bkg_cutoff=[.5, 1.],
        sample_bkg=True,
        train_out='./scratch/train',
        val_out='./scratch/val',
        test_out='./scratch/test',
        out_img_ext='png',
        neg_name='benign',
        pos_name='malignant',
        bkg_name='background',
        augmentation=True,
        load_train_ram=False,
        load_val_ram=False,
        top_layer_nb=None,
        nb_epoch=10,
        top_layer_epochs=0,
        all_layer_epochs=0,
        optim='sgd',
        init_lr=.01,
        top_layer_multiplier=.01,
        all_layer_multiplier=.0001,
        es_patience=5,
        lr_patience=2,
        weight_decay2=.01,
        bias_multiplier=.1,
        hidden_dropout2=.0,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        out='./modelState/subj_lists.pkl'):
    '''Finetune a trained DL model on a different dataset
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an rng used across board.
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Load and split image and label lists.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_labs = np.array(subj_labs)
    print "Found %d subjests" % (len(subj_list))
    print "cancer patients=%d, normal patients=%d" \
            % ((subj_labs==1).sum(), (subj_labs==0).sum())
    if neg_vs_pos_ratio is not None:
        subj_list, subj_labs = DMMetaManager.subset_subj_list(
            subj_list, subj_labs, neg_vs_pos_ratio, random_seed)
        subj_labs = np.array(subj_labs)
        print "After subsetting, there are %d subjects" % (len(subj_list))
        print "cancer patients=%d, normal patients=%d" \
                % ((subj_labs==1).sum(), (subj_labs==0).sum())
    subj_train, subj_test, labs_train, labs_test = train_test_split(
        subj_list,
        subj_labs,
        test_size=test_size,
        stratify=subj_labs,
        random_state=random_seed)
    subj_train, subj_val, labs_train, labs_val = train_test_split(
        subj_train,
        labs_train,
        test_size=val_size,
        stratify=labs_train,
        random_state=random_seed)

    # Get image lists.
    # >>>> Debug <<<< #
    # subj_train = subj_train[:5]
    # subj_val = subj_val[:5]
    # subj_test = subj_test[:5]
    # >>>> Debug <<<< #
    print "Get flattened image lists"
    img_train, ilab_train = meta_man.get_flatten_img_list(subj_train)
    img_val, ilab_val = meta_man.get_flatten_img_list(subj_val)
    img_test, ilab_test = meta_man.get_flatten_img_list(subj_test)
    ilab_train = np.array(ilab_train)
    ilab_val = np.array(ilab_val)
    ilab_test = np.array(ilab_test)
    print "On train set, positive img=%d, negative img=%d" \
            % ((ilab_train==1).sum(), (ilab_train==0).sum())
    print "On val set, positive img=%d, negative img=%d" \
            % ((ilab_val==1).sum(), (ilab_val==0).sum())
    print "On test set, positive img=%d, negative img=%d" \
            % ((ilab_test==1).sum(), (ilab_test==0).sum())
    sys.stdout.flush()

    # Save the subj lists.
    print "Saving subject lists to external files.",
    sys.stdout.flush()
    pickle.dump((subj_train, subj_val, subj_test), open(out, 'w'))
    print "Done."

    # Load DL model, preprocess function.
    print "Load patch classifier:", dl_state
    sys.stdout.flush()
    dl_model, preprocess_input, top_layer_nb = get_dl_model(
        net,
        use_pretrained=True,
        resume_from=dl_state,
        top_layer_nb=top_layer_nb)
    if featurewise_center:
        preprocess_input = None
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, org_model = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        org_model = dl_model
        parallelized = False

    # Sweep the whole images and classify patches.
    print "Score image patches and write them to:", train_out
    sys.stdout.flush()
    nb_roi_train, nb_bkg_train = score_write_patches(
        img_train,
        ilab_train,
        img_height,
        img_scale,
        patch_size,
        stride,
        dl_model,
        batch_size,
        neg_out=os.path.join(train_out, neg_name),
        pos_out=os.path.join(train_out, pos_name),
        bkg_out=os.path.join(train_out, bkg_name),
        preprocess=preprocess_input,
        equalize_hist=equalize_hist,
        featurewise_center=featurewise_center,
        featurewise_mean=featurewise_mean,
        roi_cutoff=roi_cutoff,
        bkg_cutoff=bkg_cutoff,
        sample_bkg=sample_bkg,
        img_ext=out_img_ext,
        random_seed=random_seed,
        parallelized=parallelized)
    print "Wrote %d ROI and %d bkg patches" % (nb_roi_train, nb_bkg_train)
    ####
    print "Score image patches and write them to:", val_out
    sys.stdout.flush()
    nb_roi_val, nb_bkg_val = score_write_patches(
        img_val,
        ilab_val,
        img_height,
        img_scale,
        patch_size,
        stride,
        dl_model,
        batch_size,
        neg_out=os.path.join(val_out, neg_name),
        pos_out=os.path.join(val_out, pos_name),
        bkg_out=os.path.join(val_out, bkg_name),
        preprocess=preprocess_input,
        equalize_hist=equalize_hist,
        featurewise_center=featurewise_center,
        featurewise_mean=featurewise_mean,
        roi_cutoff=roi_cutoff,
        bkg_cutoff=bkg_cutoff,
        sample_bkg=sample_bkg,
        img_ext=out_img_ext,
        random_seed=random_seed,
        parallelized=parallelized)
    print "Wrote %d ROI and %d bkg patches" % (nb_roi_val, nb_bkg_val)
    ####
    print "Score image patches and write them to:", test_out
    sys.stdout.flush()
    nb_roi_test, nb_bkg_test = score_write_patches(
        img_test,
        ilab_test,
        img_height,
        img_scale,
        patch_size,
        stride,
        dl_model,
        batch_size,
        neg_out=os.path.join(test_out, neg_name),
        pos_out=os.path.join(test_out, pos_name),
        bkg_out=os.path.join(test_out, bkg_name),
        preprocess=preprocess_input,
        equalize_hist=equalize_hist,
        featurewise_center=featurewise_center,
        featurewise_mean=featurewise_mean,
        roi_cutoff=roi_cutoff,
        bkg_cutoff=bkg_cutoff,
        sample_bkg=sample_bkg,
        img_ext=out_img_ext,
        random_seed=random_seed,
        parallelized=parallelized)
    print "Wrote %d ROI and %d bkg patches" % (nb_roi_test, nb_bkg_test)
    sys.stdout.flush()

    # ==== Image generators ==== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 45.
        train_imgen.shear_range = np.pi / 8.

    # ==== Train & val set ==== #
    # Note: the images are histogram equalized before they were written to
    # external folders.
    train_bs = int(batch_size * train_bs_multiplier)
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_out,
            target_size=(patch_size, patch_size),
            target_scale=img_scale,
            equalize_hist=False,
            dup_3_channels=True,
            classes=[bkg_name, pos_name, neg_name],
            class_mode='categorical',
            batch_size=train_bs,
            shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."
        sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(raw_set[0],
                                           raw_set[1],
                                           batch_size=train_bs,
                                           auto_batch_balance=True,
                                           preprocess=preprocess_input,
                                           shuffle=True,
                                           seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_out,
            target_size=(patch_size, patch_size),
            target_scale=img_scale,
            equalize_hist=False,
            dup_3_channels=True,
            classes=[bkg_name, pos_name, neg_name],
            class_mode='categorical',
            auto_batch_balance=True,
            batch_size=train_bs,
            preprocess=preprocess_input,
            shuffle=True,
            seed=random_seed)

    print "Create generator for val set"
    sys.stdout.flush()
    validation_set = val_imgen.flow_from_directory(
        val_out,
        target_size=(patch_size, patch_size),
        target_scale=img_scale,
        equalize_hist=False,
        dup_3_channels=True,
        classes=[bkg_name, pos_name, neg_name],
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    val_samples = validation_set.nb_sample
    if parallelized and val_samples % batch_size != 0:
        val_samples -= val_samples % batch_size
    print "Validation samples =", val_samples
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, val_samples)
        print "Done."
        print "Loaded %d val samples" % (len(validation_set[0]))
        sys.stdout.flush()

    # ==== Model finetuning ==== #
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    samples_per_epoch = train_bs * train_batches
    # import pdb; pdb.set_trace()
    dl_model, loss_hist, acc_hist = do_3stage_training(
        dl_model,
        org_model,
        train_generator,
        validation_set,
        val_samples,
        best_model,
        samples_per_epoch,
        top_layer_nb,
        net,
        nb_epoch=nb_epoch,
        top_layer_epochs=top_layer_epochs,
        all_layer_epochs=all_layer_epochs,
        use_pretrained=True,
        optim=optim,
        init_lr=init_lr,
        top_layer_multiplier=top_layer_multiplier,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience,
        lr_patience=lr_patience,
        auto_batch_balance=True,
        nb_worker=nb_worker,
        weight_decay2=weight_decay2,
        bias_multiplier=bias_multiplier,
        hidden_dropout2=hidden_dropout2)

    # Training report.
    min_loss_locs, = np.where(loss_hist == min(loss_hist))
    best_val_loss = loss_hist[min_loss_locs[0]]
    best_val_accuracy = acc_hist[min_loss_locs[0]]
    print "\n==== Training summary ===="
    print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
    print "Best val loss:", best_val_loss
    print "Best val accuracy:", best_val_accuracy

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    print "Create generator for test set"
    test_generator = test_imgen.flow_from_directory(
        test_out,
        target_size=(patch_size, patch_size),
        target_scale=img_scale,
        equalize_hist=False,
        dup_3_channels=True,
        classes=[bkg_name, pos_name, neg_name],
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    test_samples = test_generator.nb_sample
    if parallelized and test_samples % batch_size != 0:
        test_samples -= test_samples % batch_size
    print "Test samples =", test_samples
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    test_res = dl_model.evaluate_generator(
        test_generator,
        test_samples,
        nb_worker=nb_worker,
        pickle_safe=True if nb_worker > 1 else False)
    print "Evaluation result on test set:", test_res
def run(img_folder,
        img_extension='dcm',
        img_size=[288, 224],
        img_scale=4095,
        multi_view=False,
        do_featurewise_norm=True,
        featurewise_mean=398.5,
        featurewise_std=627.8,
        batch_size=16,
        samples_per_epoch=160,
        nb_epoch=20,
        balance_classes=.0,
        all_neg_skip=0.,
        pos_cls_weight=1.0,
        nb_init_filter=64,
        init_filter_size=7,
        init_conv_stride=2,
        pool_size=3,
        pool_stride=2,
        weight_decay=.0001,
        alpha=1.,
        l1_ratio=.5,
        inp_dropout=.0,
        hidden_dropout=.0,
        init_lr=.01,
        val_size=.2,
        lr_patience=5,
        es_patience=10,
        resume_from=None,
        net='resnet50',
        load_val_ram=False,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        best_model='./modelState/dm_resnet_best_model.h5',
        final_model="NOSAVE"):
    '''Run ResNet training on mammograms using an exam or image list
    Args:
        featurewise_mean, featurewise_std ([float]): they are estimated from 
                1152 x 896 images. Using different sized images give very close
                results. For png, mean=7772, std=12187.
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup training and validation data.
    # Load image or exam lists and split them into train and val sets.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    if multi_view:
        exam_list = meta_man.get_flatten_exam_list()
        exam_train, exam_val = train_test_split(
            exam_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=meta_man.exam_labs(exam_list))
        val_size_ = len(exam_val) * 2  # L and R.
    else:
        img_list, lab_list = meta_man.get_flatten_img_list()
        img_train, img_val, lab_train, lab_val = train_test_split(
            img_list,
            lab_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=lab_list)
        val_size_ = len(img_val)

    # Create image generator.
    img_gen = DMImageDataGenerator(horizontal_flip=True, vertical_flip=True)
    if do_featurewise_norm:
        img_gen.featurewise_center = True
        img_gen.featurewise_std_normalization = True
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen.samplewise_center = True
        img_gen.samplewise_std_normalization = True

    if multi_view:
        train_generator = img_gen.flow_from_exam_list(
            exam_train,
            target_size=(img_size[0], img_size[1]),
            target_scale=img_scale,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        if load_val_ram:
            val_generator = img_gen.flow_from_exam_list(
                exam_val,
                target_size=(img_size[0], img_size[1]),
                target_scale=img_scale,
                batch_size=val_size_,
                validation_mode=True,
                class_mode='binary')
        else:
            val_generator = img_gen.flow_from_exam_list(
                exam_val,
                target_size=(img_size[0], img_size[1]),
                target_scale=img_scale,
                batch_size=batch_size,
                validation_mode=True,
                class_mode='binary')
    else:
        train_generator = img_gen.flow_from_img_list(
            img_train,
            lab_train,
            target_size=(img_size[0], img_size[1]),
            target_scale=img_scale,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        if load_val_ram:
            val_generator = img_gen.flow_from_img_list(
                img_val,
                lab_val,
                target_size=(img_size[0], img_size[1]),
                target_scale=img_scale,
                batch_size=val_size_,
                validation_mode=True,
                class_mode='binary')
        else:
            val_generator = img_gen.flow_from_img_list(
                img_val,
                lab_val,
                target_size=(img_size[0], img_size[1]),
                target_scale=img_scale,
                batch_size=batch_size,
                validation_mode=True,
                class_mode='binary')

    # Load validation set into RAM.
    if load_val_ram:
        validation_set = next(val_generator)
        if not multi_view and len(validation_set[0]) != val_size_:
            raise Exception
        elif len(validation_set[0][0]) != val_size_ \
                or len(validation_set[0][1]) != val_size_:
            raise Exception

    # Create model.
    if resume_from is not None:
        model = load_model(resume_from,
                           custom_objects={
                               'sensitivity': DMMetrics.sensitivity,
                               'specificity': DMMetrics.specificity
                           })
    else:
        if multi_view:
            builder = MultiViewResNetBuilder
        else:
            builder = ResNetBuilder
        if net == 'resnet18':
            model = builder.build_resnet_18(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'resnet34':
            model = builder.build_resnet_34(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'resnet50':
            model = builder.build_resnet_50(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'dmresnet14':
            model = builder.build_dm_resnet_14(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'dmresnet47rb5':
            model = builder.build_dm_resnet_47rb5(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'dmresnet56rb6':
            model = builder.build_dm_resnet_56rb6(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'dmresnet65rb7':
            model = builder.build_dm_resnet_65rb7(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'resnet101':
            model = builder.build_resnet_101(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)
        elif net == 'resnet152':
            model = builder.build_resnet_152(
                (1, img_size[0], img_size[1]), 1, nb_init_filter,
                init_filter_size, init_conv_stride, pool_size, pool_stride,
                weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)

    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    # Model training.
    sgd = SGD(lr=init_lr, momentum=0.9, decay=0.0, nesterov=True)
    model.compile(optimizer=sgd,
                  loss='binary_crossentropy',
                  metrics=[DMMetrics.sensitivity, DMMetrics.specificity])
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.1,
                                  patience=lr_patience,
                                  verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=es_patience,
                                   verbose=1)
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                val_generator,
                                                nb_test_samples=val_size_)
    # checkpointer = ModelCheckpoint(
    #     best_model, monitor='val_loss', verbose=1, save_best_only=True)
    hist = model.fit_generator(
        train_generator,
        samples_per_epoch=samples_per_epoch,
        nb_epoch=nb_epoch,
        class_weight={
            0: 1.0,
            1: pos_cls_weight
        },
        validation_data=validation_set if load_val_ram else val_generator,
        nb_val_samples=val_size_,
        callbacks=[reduce_lr, early_stopping, auc_checkpointer],
        nb_worker=nb_worker,
        pickle_safe=True,  # turn on pickle_safe to avoid a strange error.
        verbose=2)

    # Training report.
    min_loss_locs, = np.where(
        hist.history['val_loss'] == min(hist.history['val_loss']))
    best_val_loss = hist.history['val_loss'][min_loss_locs[0]]
    best_val_sensitivity = hist.history['val_sensitivity'][min_loss_locs[0]]
    best_val_specificity = hist.history['val_specificity'][min_loss_locs[0]]
    print "\n==== Training summary ===="
    print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
    print "Best val loss:", best_val_loss
    print "Best val sensitivity:", best_val_sensitivity
    print "Best val specificity:", best_val_specificity

    if final_model != "NOSAVE":
        model.save(final_model)

    return hist
def run(img_folder, dl_state, img_extension='dcm', 
        img_height=1024, img_scale=4095, val_size=.2, neg_vs_pos_ratio=10., 
        do_featurewise_norm=True, featurewise_mean=873.6, featurewise_std=739.3,
        img_per_batch=2, roi_per_img=32, roi_size=(256, 256), 
        low_int_threshold=.05, blob_min_area=3, 
        blob_min_int=.5, blob_max_int=.85, blob_th_step=10,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        train_out='./modelState/meta_prob_train.pkl',
        test_out='./modelState/meta_prob_test.pkl'):
    '''Calculate bag of deep visual words count matrix for all breasts
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an rng used across board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Load and split image and label lists.
    meta_man = DMMetaManager(exam_tsv=exam_tsv, 
                             img_tsv=img_tsv, 
                             img_folder=img_folder, 
                             img_extension=img_extension)
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_train, subj_test, labs_train, labs_test = train_test_split(
        subj_list, subj_labs, test_size=val_size, stratify=subj_labs, 
        random_state=random_seed)
    if neg_vs_pos_ratio is not None:
        def subset_subj(subj, labs):
            subj = np.array(subj)
            labs = np.array(labs)
            pos_idx = np.where(labs==1)[0]
            neg_idx = np.where(labs==0)[0]
            nb_neg_desired = int(len(pos_idx)*neg_vs_pos_ratio)
            if nb_neg_desired >= len(neg_idx):
                return subj.tolist()
            else:
                neg_chosen = rng.choice(neg_idx, nb_neg_desired, replace=False)
                subset_idx = np.concatenate([pos_idx, neg_chosen])
                return subj[subset_idx].tolist()

        subj_train = subset_subj(subj_train, labs_train)
        subj_test = subset_subj(subj_test, labs_test)

    # Create image generator for ROIs for representation extraction.
    print "Create an image generator for ROIs"; sys.stdout.flush()
    if do_featurewise_norm:
        imgen = DMImageDataGenerator(
            featurewise_center=True, 
            featurewise_std_normalization=True)
        imgen.mean = featurewise_mean
        imgen.std = featurewise_std
    else:
        imgen = DMImageDataGenerator(
            samplewise_center=True, 
            samplewise_std_normalization=True)

    # Load DL model.
    print "Load DL classification model:", dl_state; sys.stdout.flush()
    dl_model = load_model(
        dl_state, 
        custom_objects={
            'sensitivity': dmm.sensitivity, 
            'specificity': dmm.specificity
        }
    )
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model = make_parallel(dl_model, gpu_count)

    # Read exam lists.
    exam_train = meta_man.get_flatten_exam_list(
        subj_train, flatten_img_list=True)
    exam_test = meta_man.get_flatten_exam_list(
        subj_test, flatten_img_list=True)
    exam_labs_train = np.array(meta_man.exam_labs(exam_train))
    exam_labs_test = np.array(meta_man.exam_labs(exam_test))
    nb_pos_exams_train = (exam_labs_train==1).sum()
    nb_neg_exams_train = (exam_labs_train==0).sum()
    nb_pos_exams_test = (exam_labs_test==1).sum()
    nb_neg_exams_test = (exam_labs_test==0).sum()
    print "Train set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_train, nb_neg_exams_train)
    print "Test set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_test, nb_neg_exams_test)

    # Make predictions for exam lists.
    print "Predicting for train exam list"; sys.stdout.flush()
    meta_prob_train = get_exam_pred(
        exam_train, roi_per_img, imgen, 
        target_height=img_height, target_scale=img_scale,
        img_per_batch=img_per_batch, roi_size=roi_size,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step, seed=random_seed, 
        dl_model=dl_model)
    print "Length of train prediction list:", len(meta_prob_train)
    sys.stdout.flush()

    print "Predicting for test exam list"; sys.stdout.flush()
    meta_prob_test = get_exam_pred(
        exam_test, roi_per_img, imgen, 
        target_height=img_height, target_scale=img_scale,
        img_per_batch=img_per_batch, roi_size=roi_size,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step, seed=random_seed, 
        dl_model=dl_model)
    print "Length of test prediction list:", len(meta_prob_test)
    sys.stdout.flush()

    pickle.dump(meta_prob_train, open(train_out, 'w'))
    pickle.dump(meta_prob_test, open(test_out, 'w'))
    print "Done."
Example #14
0
def run(img_folder,
        img_extension='png',
        img_size=[288, 224],
        multi_view=False,
        do_featurewise_norm=True,
        featurewise_mean=7772.,
        featurewise_std=12187.,
        batch_size=16,
        samples_per_epoch=160,
        nb_epoch=20,
        val_size=.2,
        balance_classes=0.,
        all_neg_skip=False,
        pos_cls_weight=1.0,
        alpha=1.,
        l1_ratio=.5,
        init_lr=.01,
        lr_patience=2,
        es_patience=4,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        dl_state='./modelState/resnet50_288_best_model.h5',
        best_model='./modelState/enet_288_best_model.h5',
        final_model="NOSAVE"):

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup training and validation data.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)

    if multi_view:
        exam_list = meta_man.get_flatten_exam_list()
        exam_train, exam_val = train_test_split(
            exam_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=meta_man.exam_labs(exam_list))
        val_size_ = len(exam_val) * 2  # L and R.
    else:
        img_list, lab_list = meta_man.get_flatten_img_list()
        img_train, img_val, lab_train, lab_val = train_test_split(
            img_list,
            lab_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=lab_list)
        val_size_ = len(img_val)

    img_gen = DMImageDataGenerator(horizontal_flip=True, vertical_flip=True)
    if do_featurewise_norm:
        img_gen.featurewise_center = True
        img_gen.featurewise_std_normalization = True
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen.samplewise_center = True
        img_gen.samplewise_std_normalization = True

    if multi_view:
        train_generator = img_gen.flow_from_exam_list(
            exam_train,
            target_size=(img_size[0], img_size[1]),
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_exam_list(exam_val,
                                                    target_size=(img_size[0],
                                                                 img_size[1]),
                                                    batch_size=batch_size,
                                                    validation_mode=True,
                                                    class_mode='binary')
    else:
        train_generator = img_gen.flow_from_img_list(
            img_train,
            lab_train,
            target_size=(img_size[0], img_size[1]),
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_img_list(img_val,
                                                   lab_val,
                                                   target_size=(img_size[0],
                                                                img_size[1]),
                                                   batch_size=batch_size,
                                                   validation_mode=True,
                                                   class_mode='binary')

    # Deep learning model.
    dl_model = load_model(dl_state,
                          custom_objects={
                              'sensitivity': DMMetrics.sensitivity,
                              'specificity': DMMetrics.specificity
                          })
    # Dummy compilation to turn off the "uncompiled" error when model was run on multi-GPUs.
    # dl_model.compile(optimizer='sgd', loss='binary_crossentropy')
    reprlayer_model = Model(input=dl_model.input,
                            output=dl_model.get_layer(index=-2).output)
    if gpu_count > 1:
        reprlayer_model = make_parallel(reprlayer_model, gpu_count)

    # Setup test data in RAM.
    X_test, y_test = dlrepr_generator(reprlayer_model, val_generator,
                                      val_size_)
    # import pdb; pdb.set_trace()

    # Evaluat DL model on the test data.
    val_generator.reset()
    dl_test_pred = dl_model.predict_generator(val_generator,
                                              val_samples=val_size_,
                                              nb_worker=1,
                                              pickle_safe=False)
    # Set nb_worker to >1 can cause:
    # either inconsistent result when pickle_safe is False,
    #     or broadcasting error when pickle_safe is True.
    # This seems to be a Keras bug!!
    # Further note: the broadcasting error may only happen when val_size_
    # is not divisible by batch_size.
    try:
        dl_auc = roc_auc_score(y_test, dl_test_pred)
        dl_loss = log_loss(y_test, dl_test_pred)
    except ValueError:
        dl_auc = 0.
        dl_loss = np.inf
    print "\nAUROC by the DL model: %.4f, loss: %.4f" % (dl_auc, dl_loss)
    # import pdb; pdb.set_trace()

    # Elastic net training.
    target_classes = np.array([0, 1])
    sgd_clf = SGDClassifier(loss='log',
                            penalty='elasticnet',
                            alpha=alpha,
                            l1_ratio=l1_ratio,
                            verbose=0,
                            n_jobs=nb_worker,
                            learning_rate='constant',
                            eta0=init_lr,
                            random_state=random_seed,
                            class_weight={
                                0: 1.0,
                                1: pos_cls_weight
                            })
    curr_lr = init_lr
    best_epoch = 0
    best_auc = 0.
    min_loss = np.inf
    min_loss_epoch = 0
    for epoch in xrange(nb_epoch):
        samples_seen = 0
        X_list = []
        y_list = []
        epoch_start = time.time()
        while samples_seen < samples_per_epoch:
            X, y = next(train_generator)
            X_repr = reprlayer_model.predict_on_batch(X)
            sgd_clf.partial_fit(X_repr, y, classes=target_classes)
            samples_seen += len(y)
            X_list.append(X_repr)
            y_list.append(y)
        # The training X, y are expected to change for each epoch due to
        # image random sampling and class balancing.
        X_train_epo = np.concatenate(X_list)
        y_train_epo = np.concatenate(y_list)
        # End of epoch summary.
        pred_prob = sgd_clf.predict_proba(X_test)[:, 1]
        train_prob = sgd_clf.predict_proba(X_train_epo)[:, 1]
        try:
            auc = roc_auc_score(y_test, pred_prob)
            crossentropy_loss = log_loss(y_test, pred_prob)
        except ValueError:
            auc = 0.
            crossentropy_loss = np.inf
        try:
            train_loss = log_loss(y_train_epo, train_prob)
        except ValueError:
            train_loss = np.inf
        wei_sparseness = np.mean(sgd_clf.coef_ == 0)
        epoch_span = time.time() - epoch_start
        print ("%ds - Epoch=%d, auc=%.4f, train_loss=%.4f, test_loss=%.4f, "
               "weight sparsity=%.4f") % \
            (epoch_span, epoch + 1, auc, train_loss, crossentropy_loss,
             wei_sparseness)
        # Model checkpoint, reducing learning rate and early stopping.
        if auc > best_auc:
            best_epoch = epoch + 1
            best_auc = auc
            if best_model != "NOSAVE":
                with open(best_model, 'w') as best_state:
                    pickle.dump(sgd_clf, best_state)
        if crossentropy_loss < min_loss:
            min_loss = crossentropy_loss
            min_loss_epoch = epoch + 1
        else:
            if epoch + 1 - min_loss_epoch >= es_patience:
                print 'Early stopping criterion has reached. Stop training.'
                break
            if epoch + 1 - min_loss_epoch >= lr_patience:
                curr_lr *= .1
                sgd_clf.set_params(eta0=curr_lr)
                print "Reducing learning rate to: %s" % (curr_lr)
    # End of training summary
    print ">>> Found best AUROC: %.4f at epoch: %d, saved to: %s <<<" % \
        (best_auc, best_epoch, best_model)
    print ">>> Found best val loss: %.4f at epoch: %d. <<<" % \
        (min_loss, min_loss_epoch)
    #### Save elastic net model!! ####
    if final_model != "NOSAVE":
        with open(final_model, 'w') as final_state:
            pickle.dump(sgd_clf, final_state)
Example #15
0
def run(img_folder, dl_state, fprop_mode=False,
        img_size=(1152, 896), img_height=None, img_scale=None, 
        rescale_factor=None,
        equalize_hist=False, featurewise_center=False, featurewise_mean=71.8,
        net='vgg19', batch_size=128, patch_size=256, stride=8,
        avg_pool_size=(7, 7), hm_strides=(1, 1),
        pat_csv='./full_img/pat.csv', pat_list=None,
        out='./output/prob_heatmap.pkl'):
    '''Sweep mammograms with trained DL model to create prob heatmaps

    For every (patient_id, side) in ``pat_csv`` the CC and MLO images
    ``<img_folder>/<patient_id>_<side>_<view>.png`` are scored. Views whose
    file is missing get None.

    Args:
        img_folder (str): folder containing the png images.
        dl_state (str): saved patch classifier, loaded via get_dl_model.
        fprop_mode (bool): if True, extend the patch model with
                add_top_layers(return_heatmap=True) and score each whole
                image in one forward pass; otherwise call get_prob_heatmap
                with patch_size/stride.
        featurewise_center (bool): subtract featurewise_mean; when True the
                preprocess function returned by get_dl_model is not used.
        pat_csv (str): CSV with at least patient_id, side and cancer columns.
        pat_list (str or None): optional CSV of patient IDs to restrict to.
        out (str): output pickle path. The pickle holds a list of dicts with
                keys 'patient_id', 'side', 'cancer', 'cc' and 'mlo'.
    '''
    # Read some env variables.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Create image generator.
    imgen = DMImageDataGenerator(featurewise_center=featurewise_center)
    imgen.mean = featurewise_mean

    # Get image and label lists.
    df = pd.read_csv(pat_csv, header=0)
    df = df.set_index(['patient_id', 'side'])
    df.sort_index(inplace=True)
    if pat_list is not None:
        pat_ids = pd.read_csv(pat_list, header=0).values.ravel()
        pat_ids = pat_ids.tolist()
        print ("Read %d patient IDs" % (len(pat_ids)))
        df = df.loc[pat_ids]

    # Load DL model, preprocess.
    print ("Load patch classifier:", dl_state)
    sys.stdout.flush()
    dl_model, preprocess_input, _ = get_dl_model(net, resume_from=dl_state)
    if fprop_mode:
        dl_model = add_top_layers(dl_model, img_size, patch_net=net, 
                                  avg_pool_size=avg_pool_size, 
                                  return_heatmap=True, hm_strides=hm_strides)
    if gpu_count > 1:
        print ("Make the model parallel on %d GPUs" % (gpu_count))
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    if featurewise_center:
        preprocess_input = None

    def const_filename(pat, side, view):
        '''Build the png path of one view of one breast.'''
        basename = '_'.join([pat, side, view]) + '.png'
        return os.path.join(img_folder, basename)

    def score_view(fn):
        '''Return the heatmap of image file fn, or None if it is missing.'''
        if not os.path.isfile(fn):
            return None
        if fprop_mode:
            # NOTE(review): data_format is not defined in this function;
            # presumably a module-level global -- confirm.
            x = read_img_for_pred(
                fn, equalize_hist=equalize_hist, data_format=data_format,
                dup_3_channels=True, 
                transformer=imgen.random_transform,
                standardizer=imgen.standardize,
                target_size=img_size, target_scale=img_scale,
                rescale_factor=rescale_factor)
            x = x.reshape((1,) + x.shape)  # add a batch axis of size 1
            return dl_model.predict_on_batch(x)[0]
        return get_prob_heatmap(
            fn, img_height, img_scale, patch_size, stride, 
            dl_model, batch_size, featurewise_center=featurewise_center, 
            featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
            parallelized=parallelized, equalize_hist=equalize_hist)

    # Sweep the whole images and classify patches.
    print ("Generate prob heatmaps")
    sys.stdout.flush()
    heatmaps = []
    cases = df.index.unique()
    nb_cases = len(cases)
    for i, (pat, side) in enumerate(cases):
        cancer = df.loc[pat].loc[side]['cancer']
        cc_hm = score_view(const_filename(pat, side, 'CC'))
        mlo_hm = score_view(const_filename(pat, side, 'MLO'))
        heatmaps.append({'patient_id':pat, 'side':side, 'cancer':cancer, 
                         'cc':cc_hm, 'mlo':mlo_hm})
        print ("scored %d/%d cases" % (i + 1, nb_cases))
        sys.stdout.flush()
    print ("Done.")

    # Save the result. Pickle data is binary, so open the file in 'wb' mode
    # (a text-mode handle fails on Python 3) and close it deterministically.
    print ("Saving result to external files.",)
    sys.stdout.flush()
    with open(out, 'wb') as out_file:
        pickle.dump(heatmaps, out_file)
    print ("Done.")
def run(train_dir, val_dir, test_dir, patch_model_state=None, resume_from=None,
        img_size=[1152, 896], img_scale=None, rescale_factor=None,
        featurewise_center=True, featurewise_mean=52.16, 
        equalize_hist=False, augmentation=True,
        class_list=['neg', 'pos'], patch_net='resnet50',
        block_type='resnet', top_depths=[512, 512], top_repetitions=[3, 3], 
        bottleneck_enlarge_factor=4, 
        add_heatmap=False, avg_pool_size=[7, 7], 
        add_conv=True, add_shortcut=False,
        hm_strides=(1,1), hm_pool_size=(5,5),
        fc_init_units=64, fc_layers=2,
        top_layer_nb=None,
        batch_size=64, train_bs_multiplier=.5, 
        nb_epoch=5, all_layer_epochs=20,
        load_val_ram=False, load_train_ram=False,
        weight_decay=.0001, hidden_dropout=.0, 
        weight_decay2=.0001, hidden_dropout2=.0, 
        optim='sgd', init_lr=.01, lr_patience=10, es_patience=25,
        auto_batch_balance=False, pos_cls_weight=1.0, neg_cls_weight=1.0,
        all_layer_multiplier=.1,
        best_model='./modelState/image_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for image classifications
    '''

    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True 
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    if resume_from is not None:
        image_model = load_model(resume_from, compile=False)
    else:
        patch_model = load_model(patch_model_state, compile=False)
        image_model, top_layer_nb = add_top_layers(
            patch_model, img_size, patch_net, block_type, 
            top_depths, top_repetitions, bottleneck_org,
            nb_class=len(class_list), shortcut_with_bn=True, 
            bottleneck_enlarge_factor=bottleneck_enlarge_factor,
            dropout=hidden_dropout, weight_decay=weight_decay,
            add_heatmap=add_heatmap, avg_pool_size=avg_pool_size,
            add_conv=add_conv, add_shortcut=add_shortcut,
            hm_strides=hm_strides, hm_pool_size=hm_pool_size, 
            fc_init_units=fc_init_units, fc_layers=fc_layers)
    if gpu_count > 1:
        image_model, org_model = make_parallel(image_model, gpu_count)
    else:
        org_model = image_model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size*train_bs_multiplier)
    if patch_net != 'yaroslav':
        dup_3_channels = True
    else:
        dup_3_channels = False
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale, 
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical', 
            batch_size=train_bs, shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."; sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0], raw_set[1], batch_size=train_bs, 
            auto_batch_balance=auto_batch_balance, 
            shuffle=True, seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical', 
            auto_batch_balance=auto_batch_balance, batch_size=train_bs, 
            shuffle=True, seed=random_seed)

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical', 
        batch_size=batch_size, shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."; sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 2-stage training.
    train_batches = int(train_generator.nb_sample/train_bs) + 1
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    validation_steps = int(val_samples/batch_size)
    #### DEBUG ####
    # train_batches = 1
    # val_samples = batch_size*5
    # validation_steps = 5
    #### DEBUG ####
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, test_samples=val_samples)
    # import pdb; pdb.set_trace()
    image_model, loss_hist, acc_hist = do_2stage_training(
        image_model, org_model, train_generator, validation_set, validation_steps, 
        best_model, train_batches, top_layer_nb, nb_epoch=nb_epoch,
        all_layer_epochs=all_layer_epochs,
        optim=optim, init_lr=init_lr, 
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience, lr_patience=lr_patience, 
        auto_batch_balance=auto_batch_balance, 
        pos_cls_weight=pos_cls_weight, neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker, auc_checkpointer=auc_checkpointer,
        weight_decay=weight_decay, hidden_dropout=hidden_dropout,
        weight_decay2=weight_decay2, hidden_dropout2=hidden_dropout2,)

    # Training report.
    if len(loss_hist) > 0:
        min_loss_locs, = np.where(loss_hist == min(loss_hist))
        best_val_loss = loss_hist[min_loss_locs[0]]
        best_val_accuracy = acc_hist[min_loss_locs[0]]
        print "\n==== Training summary ===="
        print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
        print "Best val loss:", best_val_loss
        print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        image_model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels, 
        classes=class_list, class_mode='categorical', batch_size=batch_size, 
        shuffle=False)
    test_samples = test_generator.nb_sample
    #### DEBUG ####
    # test_samples = 5
    #### DEBUG ####
    print "Test samples =", test_samples
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    # test_steps = int(test_generator.nb_sample/batch_size)
    # test_res = image_model.evaluate_generator(
    #     test_generator, test_steps, nb_worker=nb_worker, 
    #     pickle_safe=True if nb_worker > 1 else False)
    test_auc = DMAucModelCheckpoint.calc_test_auc(
        test_generator, image_model, test_samples=test_samples)
    print "AUROC on test set:", test_auc
def run(train_dir,
        val_dir,
        test_dir,
        img_size=[256, 256],
        img_scale=None,
        rescale_factor=None,
        featurewise_center=True,
        featurewise_mean=59.6,
        equalize_hist=True,
        augmentation=False,
        class_list=['background', 'malignant', 'benign'],
        batch_size=64,
        train_bs_multiplier=.5,
        nb_epoch=5,
        top_layer_epochs=10,
        all_layer_epochs=20,
        load_val_ram=False,
        load_train_ram=False,
        net='resnet50',
        use_pretrained=True,
        nb_init_filter=32,
        init_filter_size=5,
        init_conv_stride=2,
        pool_size=2,
        pool_stride=2,
        weight_decay=.0001,
        weight_decay2=.0001,
        alpha=.0001,
        l1_ratio=.0,
        inp_dropout=.0,
        hidden_dropout=.0,
        hidden_dropout2=.0,
        optim='sgd',
        init_lr=.01,
        lr_patience=10,
        es_patience=25,
        resume_from=None,
        auto_batch_balance=False,
        pos_cls_weight=1.0,
        neg_cls_weight=1.0,
        top_layer_nb=None,
        top_layer_multiplier=.1,
        all_layer_multiplier=.01,
        best_model='./modelState/patch_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for patch classifications
    '''
    best_model_dir = os.path.dirname(best_model)
    if not os.path.exists(best_model_dir):
        os.makedirs(best_model_dir)
    if final_model != "NOSAVE":
        final_model_dir = os.path.dirname(final_model)
        if not os.path.exists(final_model_dir):
            os.makedirs(final_model_dir)

    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        print "Using feature-wise centering, mean:", featurewise_mean
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    model, preprocess_input, top_layer_nb = get_dl_model(
        net,
        nb_class=len(class_list),
        use_pretrained=use_pretrained,
        resume_from=resume_from,
        img_size=img_size,
        top_layer_nb=top_layer_nb,
        weight_decay=weight_decay,
        hidden_dropout=hidden_dropout,
        nb_init_filter=nb_init_filter,
        init_filter_size=init_filter_size,
        init_conv_stride=init_conv_stride,
        pool_size=pool_size,
        pool_stride=pool_stride,
        alpha=alpha,
        l1_ratio=l1_ratio,
        inp_dropout=inp_dropout)
    if featurewise_center:
        preprocess_input = None
    if gpu_count > 1:
        model, org_model = make_parallel(model, gpu_count)
    else:
        org_model = model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    if net != 'yaroslav':
        dup_3_channels = True
    else:
        dup_3_channels = False
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            batch_size=train_bs,
            shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."
        sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0],
            raw_set[1],
            batch_size=train_bs,
            auto_batch_balance=auto_batch_balance,
            preprocess=preprocess_input,
            shuffle=True,
            seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            auto_batch_balance=auto_batch_balance,
            batch_size=train_bs,
            preprocess=preprocess_input,
            shuffle=True,
            seed=random_seed)

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 3-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    validation_steps = int(val_samples / batch_size)
    #### DEBUG ####
    # val_samples = 100
    #### DEBUG ####
    # import pdb; pdb.set_trace()
    model, loss_hist, acc_hist = do_3stage_training(
        model,
        org_model,
        train_generator,
        validation_set,
        validation_steps,
        best_model,
        train_batches,
        top_layer_nb,
        net,
        nb_epoch=nb_epoch,
        top_layer_epochs=top_layer_epochs,
        all_layer_epochs=all_layer_epochs,
        use_pretrained=use_pretrained,
        optim=optim,
        init_lr=init_lr,
        top_layer_multiplier=top_layer_multiplier,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience,
        lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance,
        nb_class=len(class_list),
        pos_cls_weight=pos_cls_weight,
        neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker,
        weight_decay2=weight_decay2,
        hidden_dropout2=hidden_dropout2)

    # Training report.
    if len(loss_hist) > 0:
        min_loss_locs, = np.where(loss_hist == min(loss_hist))
        best_val_loss = loss_hist[min_loss_locs[0]]
        best_val_accuracy = acc_hist[min_loss_locs[0]]
        print "\n==== Training summary ===="
        print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
        print "Best val loss:", best_val_loss
        print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    if test_generator.nb_sample:
        print "Test samples =", test_generator.nb_sample
        print "Load saved best model:", best_model + '.',
        sys.stdout.flush()
        org_model.load_weights(best_model)
        print "Done."
        test_steps = int(test_generator.nb_sample / batch_size)
        #### DEBUG ####
        # test_samples = 10
        #### DEBUG ####
        test_res = model.evaluate_generator(
            test_generator,
            test_steps,
            nb_worker=nb_worker,
            pickle_safe=True if nb_worker > 1 else False)
        print "Evaluation result on test set:", test_res
    else:
        print "Skip testing because no test sample is found."
Example #18
0
def run(train_dir,
        val_dir,
        test_dir,
        img_size=[256, 256],
        img_scale=255.,
        featurewise_center=True,
        featurewise_mean=59.6,
        equalize_hist=True,
        augmentation=False,
        class_list=['background', 'malignant', 'benign'],
        batch_size=64,
        train_bs_multiplier=.5,
        nb_epoch=5,
        top_layer_epochs=10,
        all_layer_epochs=20,
        load_val_ram=False,
        load_train_ram=False,
        net='resnet50',
        use_pretrained=True,
        nb_init_filter=32,
        init_filter_size=5,
        init_conv_stride=2,
        pool_size=2,
        pool_stride=2,
        weight_decay=.0001,
        weight_decay2=.0001,
        bias_multiplier=.1,
        alpha=.0001,
        l1_ratio=.0,
        inp_dropout=.0,
        hidden_dropout=.0,
        hidden_dropout2=.0,
        optim='sgd',
        init_lr=.01,
        lr_patience=10,
        es_patience=25,
        resume_from=None,
        auto_batch_balance=False,
        pos_cls_weight=1.0,
        neg_cls_weight=1.0,
        top_layer_nb=None,
        top_layer_multiplier=.1,
        all_layer_multiplier=.01,
        best_model='./modelState/patch_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for patch classifications
    '''

    # ======= Environmental variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    # if use_pretrained:  # use pretrained model's preprocessing.
    #     train_imgen = DMImageDataGenerator()
    #     val_imgen = DMImageDataGenerator()
    if featurewise_center:
        # fitgen = DMImageDataGenerator()
        # # Calculate pixel-level mean and std.
        # print "Create generator for mean and std fitting"
        # fit_patch_generator = fitgen.flow_from_directory(
        #     train_dir, target_size=img_size, target_scale=img_scale,
        #     classes=class_list, class_mode=None, batch_size=batch_size,
        #     shuffle=True, seed=random_seed)
        # sys.stdout.flush()
        # fit_X_lst = []
        # patches_seen = 0
        # while patches_seen < fit_size:
        #     X = fit_patch_generator.next()
        #     fit_X_lst.append(X)
        #     patches_seen += len(X)
        # fit_X_arr = np.concatenate(fit_X_lst)
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        # featurewise_std_normalization=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        # featurewise_std_normalization=True)
        # train_imgen.fit(fit_X_arr)
        # print "Found mean=%.2f, std=%.2f" % (train_imgen.mean, train_imgen.std)
        # sys.stdout.flush()
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
        # del fit_X_arr, fit_X_lst
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()
        # train_imgen = DMImageDataGenerator(
        #     samplewise_center=True,
        #     samplewise_std_normalization=True)
        # val_imgen = DMImageDataGenerator(
        #     samplewise_center=True,
        #     samplewise_std_normalization=True)

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 45.
        train_imgen.shear_range = np.pi / 8.

    # ================= Model creation ============== #
    model, preprocess_input, top_layer_nb = get_dl_model(
        net,
        nb_class=len(class_list),
        use_pretrained=use_pretrained,
        resume_from=resume_from,
        img_size=img_size,
        top_layer_nb=top_layer_nb,
        weight_decay=weight_decay,
        bias_multiplier=bias_multiplier,
        hidden_dropout=hidden_dropout,
        nb_init_filter=nb_init_filter,
        init_filter_size=init_filter_size,
        init_conv_stride=init_conv_stride,
        pool_size=pool_size,
        pool_stride=pool_stride,
        alpha=alpha,
        l1_ratio=l1_ratio,
        inp_dropout=inp_dropout)
    if featurewise_center:
        preprocess_input = None
    if gpu_count > 1:
        model, org_model = make_parallel(model, gpu_count)
    else:
        org_model = model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    if use_pretrained:
        dup_3_channels = True
    else:
        dup_3_channels = False
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            batch_size=train_bs,
            shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."
        sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0],
            raw_set[1],
            batch_size=train_bs,
            auto_batch_balance=auto_batch_balance,
            preprocess=preprocess_input,
            shuffle=True,
            seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            auto_batch_balance=auto_batch_balance,
            batch_size=train_bs,
            preprocess=preprocess_input,
            shuffle=True,
            seed=random_seed)
    # import pdb; pdb.set_trace()

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir,
        target_size=img_size,
        target_scale=img_scale,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Callbacks and class weight.
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=es_patience,
                                   verbose=1)
    checkpointer = ModelCheckpoint(best_model,
                                   monitor='val_acc',
                                   verbose=1,
                                   save_best_only=True)
    stdout_flush = DMFlush()
    callbacks = [early_stopping, checkpointer, stdout_flush]
    if optim == 'sgd':
        reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.5,
                                      patience=lr_patience,
                                      verbose=1)
        callbacks.append(reduce_lr)
    if auto_batch_balance:
        class_weight = None
    elif len(class_list) == 2:
        class_weight = {0: 1.0, 1: pos_cls_weight}
    elif len(class_list) == 3:
        class_weight = {0: 1.0, 1: pos_cls_weight, 2: neg_cls_weight}
    else:
        class_weight = None
    # Do 3-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    samples_per_epoch = train_bs * train_batches
    #### DEBUG ####
    # samples_per_epoch = train_bs*10
    #### DEBUG ####
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    #### DEBUG ####
    # val_samples = 100
    #### DEBUG ####
    model, loss_hist, acc_hist = do_3stage_training(
        model,
        org_model,
        train_generator,
        validation_set,
        val_samples,
        best_model,
        samples_per_epoch,
        top_layer_nb,
        net,
        nb_epoch=nb_epoch,
        top_layer_epochs=top_layer_epochs,
        all_layer_epochs=all_layer_epochs,
        use_pretrained=use_pretrained,
        optim=optim,
        init_lr=init_lr,
        top_layer_multiplier=top_layer_multiplier,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience,
        lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance,
        pos_cls_weight=pos_cls_weight,
        neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker,
        weight_decay2=weight_decay2,
        bias_multiplier=bias_multiplier,
        hidden_dropout2=hidden_dropout2)

    # Training report.
    min_loss_locs, = np.where(loss_hist == min(loss_hist))
    best_val_loss = loss_hist[min_loss_locs[0]]
    best_val_accuracy = acc_hist[min_loss_locs[0]]
    print "\n==== Training summary ===="
    print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
    print "Best val loss:", best_val_loss
    print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir,
        target_size=img_size,
        target_scale=img_scale,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        preprocess=preprocess_input,
        shuffle=False)
    print "Test samples =", test_generator.nb_sample
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    test_samples = test_generator.nb_sample
    #### DEBUG ####
    # test_samples = 10
    #### DEBUG ####
    test_res = model.evaluate_generator(
        test_generator,
        test_samples,
        nb_worker=nb_worker,
        pickle_safe=True if nb_worker > 1 else False)
    print "Evaluation result on test set:", test_res