Example #1
def run(img_folder, dl_state, fprop_mode=False,
        img_size=(1152, 896), img_height=None, img_scale=None, 
        rescale_factor=None,
        equalize_hist=False, featurewise_center=False, featurewise_mean=71.8,
        net='vgg19', batch_size=128, patch_size=256, stride=8,
        avg_pool_size=(7, 7), hm_strides=(1, 1),
        pat_csv='./full_img/pat.csv', pat_list=None,
        out='./output/prob_heatmap.pkl'):
    '''Sweep mammograms with a trained DL model to create probability heatmaps.
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Create image generator.
    imgen = DMImageDataGenerator(featurewise_center=featurewise_center)
    imgen.mean = featurewise_mean

    # Get image and label lists.
    df = pd.read_csv(pat_csv, header=0)
    df = df.set_index(['patient_id', 'side'])
    df.sort_index(inplace=True)
    if pat_list is not None:
        pat_ids = pd.read_csv(pat_list, header=0).values.ravel()
        pat_ids = pat_ids.tolist()
        print ("Read %d patient IDs" % (len(pat_ids)))
        df = df.loc[pat_ids]

    # Load DL model, preprocess.
    print ("Load patch classifier:", dl_state)
    sys.stdout.flush()
    dl_model, preprocess_input, _ = get_dl_model(net, resume_from=dl_state)
    if fprop_mode:
        dl_model = add_top_layers(dl_model, img_size, patch_net=net, 
                                  avg_pool_size=avg_pool_size, 
                                  return_heatmap=True, hm_strides=hm_strides)
    if gpu_count > 1:
        print ("Make the model parallel on %d GPUs" % (gpu_count))
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
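    # With featurewise centering, mean subtraction is handled by the image
    # generator's standardize(), so the net-specific preprocessing is disabled.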
    if featurewise_center:
        preprocess_input = None

    # Sweep whole images and classify patches.
    def const_filename(pat, side, view):
        basename = '_'.join([pat, side, view]) + '.png'
        return os.path.join(img_folder, basename)

    print ("Generate prob heatmaps")
    sys.stdout.flush()
    heatmaps = []
    cases_seen = 0
    nb_cases = len(df.index.unique())
    for i, (pat,side) in enumerate(df.index.unique()):
        ## DEBUG ##
        #if i >= 10:
        #    break
        ## DEBUG ##
        cancer = df.loc[pat].loc[side]['cancer']
        cc_fn = const_filename(pat, side, 'CC')
        if os.path.isfile(cc_fn):
            if fprop_mode:
                cc_x = read_img_for_pred(
                    cc_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                cc_x = cc_x.reshape((1,) + cc_x.shape)
                cc_hm = dl_model.predict_on_batch(cc_x)[0]
                # import pdb; pdb.set_trace()
            else:
                cc_hm = get_prob_heatmap(
                    cc_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            cc_hm = None
        mlo_fn = const_filename(pat, side, 'MLO')
        if os.path.isfile(mlo_fn):
            if fprop_mode:
                mlo_x = read_img_for_pred(
                    mlo_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                mlo_x = mlo_x.reshape((1,) + mlo_x.shape)
                mlo_hm = dl_model.predict_on_batch(mlo_x)[0]
            else:
                mlo_hm = get_prob_heatmap(
                    mlo_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            mlo_hm = None
        heatmaps.append({'patient_id':pat, 'side':side, 'cancer':cancer, 
                         'cc':cc_hm, 'mlo':mlo_hm})
        print ("scored %d/%d cases" % (i + 1, nb_cases))
        sys.stdout.flush()
    print ("Done.")

    # Save the result.
    print ("Saving result to external files.",)
    sys.stdout.flush()
    pickle.dump(heatmaps, open(out, 'w'))
    print ("Done.")
Example #2
def run(img_folder, dl_state, fprop_mode=False,
        img_size=(1152, 896), img_height=None, img_scale=None, 
        rescale_factor=None,
        equalize_hist=False, featurewise_center=False, featurewise_mean=71.8,
        net='vgg19', batch_size=128, patch_size=256, stride=8,
        avg_pool_size=(7, 7), hm_strides=(1, 1),
        pat_csv='./full_img/pat.csv', pat_list=None,
        out='./output/prob_heatmap.pkl'):
    '''Sweep mammograms with a trained DL model to create probability heatmaps.
    '''
    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an RNG used across the board.
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Create image generator.
    imgen = DMImageDataGenerator(featurewise_center=featurewise_center)
    imgen.mean = featurewise_mean

    # Get image and label lists.
    df = pd.read_csv(pat_csv, header=0)
    df = df.set_index(['patient_id', 'side'])
    df.sort_index(inplace=True)
    if pat_list is not None:
        pat_ids = pd.read_csv(pat_list, header=0).values.ravel()
        pat_ids = pat_ids.tolist()
        print "Read %d patient IDs" % (len(pat_ids))
        df = df.loc[pat_ids]

    # Load DL model, preprocess.
    print "Load patch classifier:", dl_state; sys.stdout.flush()
    dl_model, preprocess_input, _ = get_dl_model(net, resume_from=dl_state)
    if fprop_mode:
        dl_model = add_top_layers(dl_model, img_size, patch_net=net, 
                                  avg_pool_size=avg_pool_size, 
                                  return_heatmap=True, hm_strides=hm_strides)
    if gpu_count > 1:
        print "Make the model parallel on %d GPUs" % (gpu_count)
        sys.stdout.flush()
        dl_model, _ = make_parallel(dl_model, gpu_count)
        parallelized = True
    else:
        parallelized = False
    if featurewise_center:
        preprocess_input = None

    # Sweep whole images and classify patches.
    def const_filename(pat, side, view):
        basename = '_'.join([pat, side, view]) + '.png'
        return os.path.join(img_folder, basename)

    print "Generate prob heatmaps"; sys.stdout.flush()
    heatmaps = []
    cases_seen = 0
    nb_cases = len(df.index.unique())
    for i, (pat,side) in enumerate(df.index.unique()):
        ## DEBUG ##
        #if i >= 10:
        #    break
        ## DEBUG ##
        cancer = df.loc[pat].loc[side]['cancer']
        cc_fn = const_filename(pat, side, 'CC')
        if os.path.isfile(cc_fn):
            if fprop_mode:
                cc_x = read_img_for_pred(
                    cc_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                cc_x = cc_x.reshape((1,) + cc_x.shape)
                cc_hm = dl_model.predict_on_batch(cc_x)[0]
                # import pdb; pdb.set_trace()
            else:
                cc_hm = get_prob_heatmap(
                    cc_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            cc_hm = None
        mlo_fn = const_filename(pat, side, 'MLO')
        if os.path.isfile(mlo_fn):
            if fprop_mode:
                mlo_x = read_img_for_pred(
                    mlo_fn, equalize_hist=equalize_hist, data_format=data_format,
                    dup_3_channels=True, 
                    transformer=imgen.random_transform,
                    standardizer=imgen.standardize,
                    target_size=img_size, target_scale=img_scale,
                    rescale_factor=rescale_factor)
                mlo_x = mlo_x.reshape((1,) + mlo_x.shape)
                mlo_hm = dl_model.predict_on_batch(mlo_x)[0]
            else:
                mlo_hm = get_prob_heatmap(
                    mlo_fn, img_height, img_scale, patch_size, stride, 
                    dl_model, batch_size, featurewise_center=featurewise_center, 
                    featurewise_mean=featurewise_mean, preprocess=preprocess_input, 
                    parallelized=parallelized, equalize_hist=equalize_hist)
        else:
            mlo_hm = None
        heatmaps.append({'patient_id':pat, 'side':side, 'cancer':cancer, 
                         'cc':cc_hm, 'mlo':mlo_hm})
        print "scored %d/%d cases" % (i + 1, nb_cases)
        sys.stdout.flush()
    print "Done."

    # Save the result.
    print "Saving result to external files.",
    sys.stdout.flush()
    pickle.dump(heatmaps, open(out, 'w'))
    print "Done."
Example #3
def run(train_dir,
        val_dir,
        test_dir,
        patch_model_state=None,
        resume_from=None,
        img_size=[1152, 896],
        img_scale=None,
        rescale_factor=None,
        featurewise_center=True,
        featurewise_mean=52.16,
        equalize_hist=False,
        augmentation=True,
        class_list=['neg', 'pos'],
        patch_net='resnet50',
        block_type='resnet',
        top_depths=[512, 512],
        top_repetitions=[3, 3],
        bottleneck_enlarge_factor=4,
        add_heatmap=False,
        avg_pool_size=[7, 7],
        add_conv=True,
        add_shortcut=False,
        hm_strides=(1, 1),
        hm_pool_size=(5, 5),
        fc_init_units=64,
        fc_layers=2,
        top_layer_nb=None,
        batch_size=64,
        train_bs_multiplier=.5,
        nb_epoch=5,
        all_layer_epochs=20,
        load_val_ram=False,
        load_train_ram=False,
        weight_decay=.0001,
        hidden_dropout=.0,
        weight_decay2=.0001,
        hidden_dropout2=.0,
        optim='sgd',
        init_lr=.01,
        lr_patience=10,
        es_patience=25,
        auto_batch_balance=False,
        pos_cls_weight=1.0,
        neg_cls_weight=1.0,
        all_layer_multiplier=.1,
        best_model='./modelState/image_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for image classification.
    '''

    # ======= Environment variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    if resume_from is not None:
        image_model = load_model(resume_from, compile=False)
    else:
        patch_model = load_model(patch_model_state, compile=False)
        image_model, top_layer_nb = add_top_layers(
            patch_model,
            img_size,
            patch_net,
            block_type,
            top_depths,
            top_repetitions,
            bottleneck_org,
            nb_class=len(class_list),
            shortcut_with_bn=True,
            bottleneck_enlarge_factor=bottleneck_enlarge_factor,
            dropout=hidden_dropout,
            weight_decay=weight_decay,
            add_heatmap=add_heatmap,
            avg_pool_size=avg_pool_size,
            add_conv=add_conv,
            add_shortcut=add_shortcut,
            hm_strides=hm_strides,
            hm_pool_size=hm_pool_size,
            fc_init_units=fc_init_units,
            fc_layers=fc_layers)
    if gpu_count > 1:
        image_model, org_model = make_parallel(image_model, gpu_count)
    else:
        org_model = image_model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size * train_bs_multiplier)
    dup_3_channels = True
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            batch_size=train_bs,
            shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."
        sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0],
            raw_set[1],
            batch_size=train_bs,
            auto_batch_balance=auto_batch_balance,
            shuffle=True,
            seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir,
            target_size=img_size,
            target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist,
            dup_3_channels=dup_3_channels,
            classes=class_list,
            class_mode='categorical',
            auto_batch_balance=auto_batch_balance,
            batch_size=train_bs,
            shuffle=True,
            seed=random_seed)

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."
        sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 2-stage training.
    train_batches = int(train_generator.nb_sample / train_bs) + 1
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    validation_steps = int(val_samples / batch_size)
    #### DEBUG ####
    # train_batches = 1
    # val_samples = batch_size*5
    # validation_steps = 5
    #### DEBUG ####
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                test_samples=val_samples)
    # import pdb; pdb.set_trace()
    image_model, loss_hist, acc_hist = do_2stage_training(
        image_model,
        org_model,
        train_generator,
        validation_set,
        validation_steps,
        best_model,
        train_batches,
        top_layer_nb,
        nb_epoch=nb_epoch,
        all_layer_epochs=all_layer_epochs,
        optim=optim,
        init_lr=init_lr,
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience,
        lr_patience=lr_patience,
        auto_batch_balance=auto_batch_balance,
        pos_cls_weight=pos_cls_weight,
        neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker,
        auc_checkpointer=auc_checkpointer,
        weight_decay=weight_decay,
        hidden_dropout=hidden_dropout,
        weight_decay2=weight_decay2,
        hidden_dropout2=hidden_dropout2,
    )

    # Training report.
    if len(loss_hist) > 0:
        min_loss_locs, = np.where(loss_hist == min(loss_hist))
        best_val_loss = loss_hist[min_loss_locs[0]]
        best_val_accuracy = acc_hist[min_loss_locs[0]]
        print "\n==== Training summary ===="
        print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
        print "Best val loss:", best_val_loss
        print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        image_model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir,
        target_size=img_size,
        target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist,
        dup_3_channels=dup_3_channels,
        classes=class_list,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=False)
    test_samples = test_generator.nb_sample
    #### DEBUG ####
    # test_samples = 5
    #### DEBUG ####
    print "Test samples =", test_samples
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    # test_steps = int(test_generator.nb_sample/batch_size)
    # test_res = image_model.evaluate_generator(
    #     test_generator, test_steps, nb_worker=nb_worker,
    #     pickle_safe=True if nb_worker > 1 else False)
    test_auc = DMAucModelCheckpoint.calc_test_auc(test_generator,
                                                  image_model,
                                                  test_samples=test_samples)
    print "AUROC on test set:", test_auc
Example #4
def run(train_dir, val_dir, test_dir, patch_model_state=None, resume_from=None,
        img_size=[1152, 896], img_scale=None, rescale_factor=None,
        featurewise_center=True, featurewise_mean=52.16, 
        equalize_hist=False, augmentation=True,
        class_list=['neg', 'pos'], patch_net='resnet50',
        block_type='resnet', top_depths=[512, 512], top_repetitions=[3, 3], 
        bottleneck_enlarge_factor=4, 
        add_heatmap=False, avg_pool_size=[7, 7], 
        add_conv=True, add_shortcut=False,
        hm_strides=(1,1), hm_pool_size=(5,5),
        fc_init_units=64, fc_layers=2,
        top_layer_nb=None,
        batch_size=64, train_bs_multiplier=.5, 
        nb_epoch=5, all_layer_epochs=20,
        load_val_ram=False, load_train_ram=False,
        weight_decay=.0001, hidden_dropout=.0, 
        weight_decay2=.0001, hidden_dropout2=.0, 
        optim='sgd', init_lr=.01, lr_patience=10, es_patience=25,
        auto_batch_balance=False, pos_cls_weight=1.0, neg_cls_weight=1.0,
        all_layer_multiplier=.1,
        best_model='./modelState/image_clf.h5',
        final_model="NOSAVE"):
    '''Train a deep learning model for image classification.
    '''

    # ======= Environment variables ======== #
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # ========= Image generator ============== #
    if featurewise_center:
        train_imgen = DMImageDataGenerator(featurewise_center=True)
        val_imgen = DMImageDataGenerator(featurewise_center=True)
        test_imgen = DMImageDataGenerator(featurewise_center=True)
        train_imgen.mean = featurewise_mean
        val_imgen.mean = featurewise_mean
        test_imgen.mean = featurewise_mean
    else:
        train_imgen = DMImageDataGenerator()
        val_imgen = DMImageDataGenerator()
        test_imgen = DMImageDataGenerator()

    # Add augmentation options.
    if augmentation:
        train_imgen.horizontal_flip = True 
        train_imgen.vertical_flip = True
        train_imgen.rotation_range = 25.  # in degree.
        train_imgen.shear_range = .2  # in radians.
        train_imgen.zoom_range = [.8, 1.2]  # in proportion.
        train_imgen.channel_shift_range = 20.  # in pixel intensity values.

    # ================= Model creation ============== #
    if resume_from is not None:
        image_model = load_model(resume_from, compile=False)
    else:
        patch_model = load_model(patch_model_state, compile=False)
        image_model, top_layer_nb = add_top_layers(
            patch_model, img_size, patch_net, block_type, 
            top_depths, top_repetitions, bottleneck_org,
            nb_class=len(class_list), shortcut_with_bn=True, 
            bottleneck_enlarge_factor=bottleneck_enlarge_factor,
            dropout=hidden_dropout, weight_decay=weight_decay,
            add_heatmap=add_heatmap, avg_pool_size=avg_pool_size,
            add_conv=add_conv, add_shortcut=add_shortcut,
            hm_strides=hm_strides, hm_pool_size=hm_pool_size, 
            fc_init_units=fc_init_units, fc_layers=fc_layers)
    if gpu_count > 1:
        image_model, org_model = make_parallel(image_model, gpu_count)
    else:
        org_model = image_model

    # ============ Train & validation set =============== #
    train_bs = int(batch_size*train_bs_multiplier)
    if patch_net != 'yaroslav':
        dup_3_channels = True
    else:
        dup_3_channels = False
    if load_train_ram:
        raw_imgen = DMImageDataGenerator()
        print "Create generator for raw train set"
        raw_generator = raw_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale, 
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical', 
            batch_size=train_bs, shuffle=False)
        print "Loading raw train set into RAM.",
        sys.stdout.flush()
        raw_set = load_dat_ram(raw_generator, raw_generator.nb_sample)
        print "Done."; sys.stdout.flush()
        print "Create generator for train set"
        train_generator = train_imgen.flow(
            raw_set[0], raw_set[1], batch_size=train_bs, 
            auto_batch_balance=auto_batch_balance, 
            shuffle=True, seed=random_seed)
    else:
        print "Create generator for train set"
        train_generator = train_imgen.flow_from_directory(
            train_dir, target_size=img_size, target_scale=img_scale,
            rescale_factor=rescale_factor,
            equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
            classes=class_list, class_mode='categorical', 
            auto_batch_balance=auto_batch_balance, batch_size=train_bs, 
            shuffle=True, seed=random_seed)

    print "Create generator for val set"
    validation_set = val_imgen.flow_from_directory(
        val_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels,
        classes=class_list, class_mode='categorical', 
        batch_size=batch_size, shuffle=False)
    sys.stdout.flush()
    if load_val_ram:
        print "Loading validation set into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(validation_set, validation_set.nb_sample)
        print "Done."; sys.stdout.flush()

    # ==================== Model training ==================== #
    # Do 2-stage training.
    train_batches = int(train_generator.nb_sample/train_bs) + 1
    if isinstance(validation_set, tuple):
        val_samples = len(validation_set[0])
    else:
        val_samples = validation_set.nb_sample
    validation_steps = int(val_samples/batch_size)
    #### DEBUG ####
    # train_batches = 1
    # val_samples = batch_size*5
    # validation_steps = 5
    #### DEBUG ####
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, test_samples=val_samples)
    # import pdb; pdb.set_trace()
    image_model, loss_hist, acc_hist = do_2stage_training(
        image_model, org_model, train_generator, validation_set, validation_steps, 
        best_model, train_batches, top_layer_nb, nb_epoch=nb_epoch,
        all_layer_epochs=all_layer_epochs,
        optim=optim, init_lr=init_lr, 
        all_layer_multiplier=all_layer_multiplier,
        es_patience=es_patience, lr_patience=lr_patience, 
        auto_batch_balance=auto_batch_balance, 
        pos_cls_weight=pos_cls_weight, neg_cls_weight=neg_cls_weight,
        nb_worker=nb_worker, auc_checkpointer=auc_checkpointer,
        weight_decay=weight_decay, hidden_dropout=hidden_dropout,
        weight_decay2=weight_decay2, hidden_dropout2=hidden_dropout2,)

    # Training report.
    if len(loss_hist) > 0:
        min_loss_locs, = np.where(loss_hist == min(loss_hist))
        best_val_loss = loss_hist[min_loss_locs[0]]
        best_val_accuracy = acc_hist[min_loss_locs[0]]
        print "\n==== Training summary ===="
        print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
        print "Best val loss:", best_val_loss
        print "Best val accuracy:", best_val_accuracy

    if final_model != "NOSAVE":
        image_model.save(final_model)

    # ==== Predict on test set ==== #
    print "\n==== Predicting on test set ===="
    test_generator = test_imgen.flow_from_directory(
        test_dir, target_size=img_size, target_scale=img_scale,
        rescale_factor=rescale_factor,
        equalize_hist=equalize_hist, dup_3_channels=dup_3_channels, 
        classes=class_list, class_mode='categorical', batch_size=batch_size, 
        shuffle=False)
    test_samples = test_generator.nb_sample
    #### DEBUG ####
    # test_samples = 5
    #### DEBUG ####
    print "Test samples =", test_samples
    print "Load saved best model:", best_model + '.',
    sys.stdout.flush()
    org_model.load_weights(best_model)
    print "Done."
    # test_steps = int(test_generator.nb_sample/batch_size)
    # test_res = image_model.evaluate_generator(
    #     test_generator, test_steps, nb_worker=nb_worker, 
    #     pickle_safe=True if nb_worker > 1 else False)
    test_auc = DMAucModelCheckpoint.calc_test_auc(
        test_generator, image_model, test_samples=test_samples)
    print "AUROC on test set:", test_auc