Ejemplo n.º 1
0
def run(img_folder, img_extension='dcm', 
        img_height=1024, img_scale=4095, 
        do_featurewise_norm=True, norm_fit_size=10,
        img_per_batch=2, roi_per_img=32, roi_size=(256, 256), 
        one_patch_mode=False,
        low_int_threshold=.05, blob_min_area=3, 
        blob_min_int=.5, blob_max_int=.85, blob_th_step=10,
        data_augmentation=False, roi_state=None, clf_bs=32, cutpoint=.5,
        amp_factor=1., return_sample_weight=True, auto_batch_balance=True,
        patches_per_epoch=12800, nb_epoch=20, 
        neg_vs_pos_ratio=None, all_neg_skip=0., 
        nb_init_filter=32, init_filter_size=5, init_conv_stride=2, 
        pool_size=2, pool_stride=2, 
        weight_decay=.0001, alpha=.0001, l1_ratio=.0, 
        inp_dropout=.0, hidden_dropout=.0, init_lr=.01,
        test_size=.2, val_size=.0, 
        lr_patience=3, es_patience=10, 
        resume_from=None, net='resnet50', load_val_ram=False, 
        load_train_ram=False, no_pos_skip=0., balance_classes=0.,
        pred_img_per_batch=1, pred_roi_per_img=32,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        best_model='./modelState/dm_candidROI_best_model.h5',
        final_model="NOSAVE",
        pred_trainval=False, pred_out="dl_pred_out.pkl"):
    '''Run ResNet training on candidate ROIs from mammograms
    Args:
        norm_fit_size ([int]): the number of patients used to calculate 
                feature-wise mean and std.
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    # Use of multiple CPU cores is not working!
    # When nb_worker>1 and pickle_safe=True, this error is encountered:
    # "failed to enqueue async memcpy from host to device: CUDA_ERROR_NOT_INITIALIZED"
    # To avoid the error, only this combination worked: 
    # nb_worker=1 and pickle_safe=False.
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))
    
    # Setup training and validation data.
    # Load image or exam lists and split them into train and val sets.
    meta_man = DMMetaManager(exam_tsv=exam_tsv, 
                             img_tsv=img_tsv, 
                             img_folder=img_folder, 
                             img_extension=img_extension)
    # Split data based on subjects.
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_train, subj_test, slab_train, slab_test = train_test_split(
        subj_list, subj_labs, test_size=test_size, random_state=random_seed, 
        stratify=subj_labs)
    if val_size > 0:  # train/val split.
        subj_train, subj_val, slab_train, slab_val = train_test_split(
            subj_train, slab_train, test_size=val_size, 
            random_state=random_seed, stratify=slab_train)
    else:  # use test as val. make a copy of the test list.
        subj_val = list(subj_test)
        slab_val = list(slab_test)
    # import pdb; pdb.set_trace()
    # Subset subject lists to desired ratio.
    if neg_vs_pos_ratio is not None:
        subj_train, slab_train = DMMetaManager.subset_subj_list(
            subj_train, slab_train, neg_vs_pos_ratio, random_seed)
        subj_val, slab_val = DMMetaManager.subset_subj_list(
            subj_val, slab_val, neg_vs_pos_ratio, random_seed)
    print "After sampling, Nb of subjects for train=%d, val=%d, test=%d" \
            % (len(subj_train), len(subj_val), len(subj_test))
    # Get image and label lists.
    img_train, lab_train = meta_man.get_flatten_img_list(subj_train)
    img_val, lab_val = meta_man.get_flatten_img_list(subj_val)

    # Create image generators for train, fit and val.
    imgen_trainval = DMImageDataGenerator()
    if data_augmentation:
        imgen_trainval.horizontal_flip=True 
        imgen_trainval.vertical_flip=True
        imgen_trainval.rotation_range = 45.
        imgen_trainval.shear_range = np.pi/8.
        # imgen_trainval.width_shift_range = .05
        # imgen_trainval.height_shift_range = .05
        # imgen_trainval.zoom_range = [.95, 1.05]

    if do_featurewise_norm:
        imgen_trainval.featurewise_center = True
        imgen_trainval.featurewise_std_normalization = True
        # Fit feature-wise mean and std.
        img_fit,_ = meta_man.get_flatten_img_list(
            subj_train[:norm_fit_size])  # fit on a subset.
        print ">>> Fit image generator <<<"; sys.stdout.flush()
        fit_generator = imgen_trainval.flow_from_candid_roi(
            img_fit,
            target_height=img_height, target_scale=img_scale,
            class_mode=None, validation_mode=True, 
            img_per_batch=len(img_fit), roi_per_img=roi_per_img, 
            roi_size=roi_size,
            low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
            blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
            blob_th_step=blob_th_step,
            roi_clf=None, return_sample_weight=False, seed=random_seed)
        imgen_trainval.fit(fit_generator.next())
        print "Estimates from %d images: mean=%.1f, std=%.1f." % \
            (len(img_fit), imgen_trainval.mean, imgen_trainval.std)
        sys.stdout.flush()
    else:
        imgen_trainval.samplewise_center = True
        imgen_trainval.samplewise_std_normalization = True

    # Load ROI classifier.
    if roi_state is not None:
        roi_clf = load_model(
            roi_state, 
            custom_objects={
                'sensitivity': DMMetrics.sensitivity, 
                'specificity': DMMetrics.specificity
            }
        )
        graph = tf.get_default_graph()
    else:
        roi_clf = None
        graph = None

    # Set some DL training related parameters.
    if one_patch_mode:
        class_mode = 'binary'
        loss = 'binary_crossentropy'
        metrics = [DMMetrics.sensitivity, DMMetrics.specificity]
    else:
        class_mode = 'categorical'
        loss = 'categorical_crossentropy'
        metrics = ['accuracy', 'precision', 'recall']
    if load_train_ram:
        validation_mode = True
        return_raw_img = True
    else:
        validation_mode = False
        return_raw_img = False

    # Create train and val generators.
    print ">>> Train image generator <<<"; sys.stdout.flush()
    train_generator = imgen_trainval.flow_from_candid_roi(
        img_train, lab_train, 
        target_height=img_height, target_scale=img_scale,
        class_mode=class_mode, validation_mode=validation_mode, 
        img_per_batch=img_per_batch, roi_per_img=roi_per_img, 
        roi_size=roi_size, one_patch_mode=one_patch_mode,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step,
        tf_graph=graph, roi_clf=roi_clf, clf_bs=clf_bs, cutpoint=cutpoint,
        amp_factor=amp_factor, return_sample_weight=return_sample_weight,
        auto_batch_balance=auto_batch_balance,
        all_neg_skip=all_neg_skip, shuffle=True, seed=random_seed,
        return_raw_img=return_raw_img)

    print ">>> Validation image generator <<<"; sys.stdout.flush()
    val_generator = imgen_trainval.flow_from_candid_roi(
        img_val, lab_val, 
        target_height=img_height, target_scale=img_scale,
        class_mode=class_mode, validation_mode=True, 
        img_per_batch=img_per_batch, roi_per_img=roi_per_img, 
        roi_size=roi_size, one_patch_mode=one_patch_mode,
        low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
        blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
        blob_th_step=blob_th_step,
        tf_graph=graph, roi_clf=roi_clf, clf_bs=clf_bs, cutpoint=cutpoint,
        amp_factor=amp_factor, return_sample_weight=False, 
        auto_batch_balance=False,
        seed=random_seed)

    # Load train and validation set into RAM.
    if one_patch_mode:
        nb_train_samples = len(img_train)
        nb_val_samples = len(img_val)
    else:
        nb_train_samples = len(img_train)*roi_per_img
        nb_val_samples = len(img_val)*roi_per_img
    if load_val_ram:
        print "Loading validation data into RAM.",
        sys.stdout.flush()
        validation_set = load_dat_ram(val_generator, nb_val_samples)
        print "Done."; sys.stdout.flush()
        sparse_y = to_sparse(validation_set[1])
        for uy in np.unique(sparse_y):
            print "Nb of samples for class:%d = %d" % \
                    (uy, (sparse_y==uy).sum())
        sys.stdout.flush()
    if load_train_ram:
        print "Loading train data into RAM.",
        sys.stdout.flush()
        train_set = load_dat_ram(train_generator, nb_train_samples)
        print "Done."; sys.stdout.flush()
        sparse_y = to_sparse(train_set[1])
        for uy in np.unique(sparse_y):
            print "Nb of samples for class:%d = %d" % \
                    (uy, (sparse_y==uy).sum())
        sys.stdout.flush()
        train_generator = imgen_trainval.flow(
            train_set[0], train_set[1], batch_size=clf_bs, 
            auto_batch_balance=auto_batch_balance, no_pos_skip=no_pos_skip,
            balance_classes=balance_classes, shuffle=True, seed=random_seed)

    # Load or create model.
    if resume_from is not None:
        model = load_model(
            resume_from,
            custom_objects={
                'sensitivity': DMMetrics.sensitivity, 
                'specificity': DMMetrics.specificity
            }
        )
    else:
        builder = ResNetBuilder
        if net == 'resnet18':
            model = builder.build_resnet_18(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet34':
            model = builder.build_resnet_34(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet50':
            model = builder.build_resnet_50(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet101':
            model = builder.build_resnet_101(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
        elif net == 'resnet152':
            model = builder.build_resnet_152(
                (1, roi_size[0], roi_size[1]), 3, nb_init_filter, init_filter_size, 
                init_conv_stride, pool_size, pool_stride, weight_decay, alpha, l1_ratio, 
                inp_dropout, hidden_dropout)
    
    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    # Model training.
    sgd = SGD(lr=init_lr, momentum=0.9, decay=0.0, nesterov=True)
    model.compile(optimizer=sgd, loss=loss, metrics=metrics)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, 
                                  patience=lr_patience, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=es_patience, 
                                   verbose=1)
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, validation_set, batch_size=clf_bs)
    else:
        auc_checkpointer = DMAucModelCheckpoint(
            best_model, val_generator, nb_test_samples=nb_val_samples)
    hist = model.fit_generator(
        train_generator, 
        samples_per_epoch=patches_per_epoch, 
        nb_epoch=nb_epoch,
        validation_data=validation_set if load_val_ram else val_generator, 
        nb_val_samples=nb_val_samples, 
        callbacks=[reduce_lr, early_stopping, auc_checkpointer],
        # nb_worker=1, pickle_safe=False,
        nb_worker=nb_worker if load_train_ram else 1,
        pickle_safe=True if load_train_ram else False,
        verbose=2)

    if final_model != "NOSAVE":
        print "Saving final model to:", final_model; sys.stdout.flush()
        model.save(final_model)
    
    # Training report.
    min_loss_locs, = np.where(hist.history['val_loss'] == min(hist.history['val_loss']))
    best_val_loss = hist.history['val_loss'][min_loss_locs[0]]
    if one_patch_mode:
        best_val_sensitivity = hist.history['val_sensitivity'][min_loss_locs[0]]
        best_val_specificity = hist.history['val_specificity'][min_loss_locs[0]]
    else:
        best_val_precision = hist.history['val_precision'][min_loss_locs[0]]
        best_val_recall = hist.history['val_recall'][min_loss_locs[0]]
        best_val_accuracy = hist.history['val_acc'][min_loss_locs[0]]
    print "\n==== Training summary ===="
    print "Minimum val loss achieved at epoch:", min_loss_locs[0] + 1
    print "Best val loss:", best_val_loss
    if one_patch_mode:
        print "Best val sensitivity:", best_val_sensitivity
        print "Best val specificity:", best_val_specificity
    else:
        print "Best val precision:", best_val_precision
        print "Best val recall:", best_val_recall
        print "Best val accuracy:", best_val_accuracy

    # Make predictions on train, val, test exam lists.
    if best_model != 'NOSAVE':
        print "\n==== Making predictions ===="
        print "Load best model for prediction:", best_model
        sys.stdout.flush()
        pred_model = load_model(best_model)
        if gpu_count > 1:
            pred_model = make_parallel(pred_model, gpu_count)
        
        if pred_trainval:
            print "Load exam lists for train, val sets"; sys.stdout.flush()
            exam_train = meta_man.get_flatten_exam_list(
                subj_train, flatten_img_list=True)
            print "Train exam list length=", len(exam_train); sys.stdout.flush()
            exam_val = meta_man.get_flatten_exam_list(
                subj_val, flatten_img_list=True)
            print "Val exam list length=", len(exam_val); sys.stdout.flush()
        print "Load exam list for test set"; sys.stdout.flush()
        exam_test = meta_man.get_flatten_exam_list(
            subj_test, flatten_img_list=True)
        print "Test exam list length=", len(exam_test); sys.stdout.flush()
        
        if do_featurewise_norm:
            imgen_pred = DMImageDataGenerator()
            imgen_pred.featurewise_center = True
            imgen_pred.featurewise_std_normalization = True
            imgen_pred.mean = imgen_trainval.mean
            imgen_pred.std = imgen_trainval.std
        else:
            imgen_pred.samplewise_center = True
            imgen_pred.samplewise_std_normalization = True
        
        if pred_trainval:
            print "Make predictions on train exam list"; sys.stdout.flush()
            meta_prob_train = get_exam_pred(
                exam_train, pred_roi_per_img, imgen_pred, 
                target_height=img_height, target_scale=img_scale,
                img_per_batch=pred_img_per_batch, roi_size=roi_size,
                low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
                blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
                blob_th_step=blob_th_step, seed=random_seed, 
                dl_model=pred_model)
            print "Train prediction list length=", len(meta_prob_train)
            
            print "Make predictions on val exam list"; sys.stdout.flush()
            meta_prob_val = get_exam_pred(
                exam_val, pred_roi_per_img, imgen_pred, 
                target_height=img_height, target_scale=img_scale,
                img_per_batch=pred_img_per_batch, roi_size=roi_size,
                low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
                blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
                blob_th_step=blob_th_step, seed=random_seed, 
                dl_model=pred_model)
            print "Val prediction list length=", len(meta_prob_val)
        
        print "Make predictions on test exam list"; sys.stdout.flush()
        meta_prob_test = get_exam_pred(
            exam_test, pred_roi_per_img, imgen_pred, 
            target_height=img_height, target_scale=img_scale,
            img_per_batch=pred_img_per_batch, roi_size=roi_size,
            low_int_threshold=low_int_threshold, blob_min_area=blob_min_area, 
            blob_min_int=blob_min_int, blob_max_int=blob_max_int, 
            blob_th_step=blob_th_step, seed=random_seed, 
            dl_model=pred_model)
        print "Test prediction list length=", len(meta_prob_test)
        
        if pred_trainval:
            pickle.dump((meta_prob_train, meta_prob_val, meta_prob_test), 
                        open(pred_out, 'w'))
        else:
            pickle.dump(meta_prob_test, open(pred_out, 'w'))

    return hist
Ejemplo n.º 2
0
def run(img_folder,
        dl_state,
        img_extension='dcm',
        img_height=1024,
        img_scale=4095,
        val_size=.2,
        neg_vs_pos_ratio=10.,
        do_featurewise_norm=True,
        featurewise_mean=873.6,
        featurewise_std=739.3,
        img_per_batch=2,
        roi_per_img=32,
        roi_size=(256, 256),
        low_int_threshold=.05,
        blob_min_area=3,
        blob_min_int=.5,
        blob_max_int=.85,
        blob_th_step=10,
        layer_name=['flatten_1', 'dense_1'],
        layer_index=None,
        roi_state=None,
        roi_clf_bs=32,
        pc_components=.95,
        pc_whiten=True,
        nb_words=[512],
        km_max_iter=100,
        km_bs=1000,
        km_patience=20,
        km_init=10,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        pca_km_states='./modelState/dlrepr_pca_km_models.pkl',
        bow_train_out='./modelState/bow_dat_train.pkl',
        bow_test_out='./modelState/bow_dat_test.pkl'):
    '''Calculate bag of deep visual words count matrix for all breasts
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    rng = RandomState(random_seed)  # an rng used across board.

    # Load and split image and label lists.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    subj_list, subj_labs = meta_man.get_subj_labs()
    subj_train, subj_test, labs_train, labs_test = train_test_split(
        subj_list,
        subj_labs,
        test_size=val_size,
        stratify=subj_labs,
        random_state=random_seed)
    if neg_vs_pos_ratio is not None:

        def subset_subj(subj, labs):
            subj = np.array(subj)
            labs = np.array(labs)
            pos_idx = np.where(labs == 1)[0]
            neg_idx = np.where(labs == 0)[0]
            nb_neg_desired = int(len(pos_idx) * neg_vs_pos_ratio)
            if nb_neg_desired >= len(neg_idx):
                return subj.tolist()
            else:
                neg_chosen = rng.choice(neg_idx, nb_neg_desired, replace=False)
                subset_idx = np.concatenate([pos_idx, neg_chosen])
                return subj[subset_idx].tolist()

        subj_train = subset_subj(subj_train, labs_train)
        subj_test = subset_subj(subj_test, labs_test)

    img_list, lab_list = meta_man.get_flatten_img_list(subj_train)
    lab_list = np.array(lab_list)
    print "Train set - Nb of positive images: %d, Nb of negative images: %d" \
            % ( (lab_list==1).sum(), (lab_list==0).sum())
    sys.stdout.flush()

    # Create image generator for ROIs for representation extraction.
    print "Create an image generator for ROIs"
    sys.stdout.flush()
    if do_featurewise_norm:
        imgen = DMImageDataGenerator(featurewise_center=True,
                                     featurewise_std_normalization=True)
        imgen.mean = featurewise_mean
        imgen.std = featurewise_std
    else:
        imgen = DMImageDataGenerator(samplewise_center=True,
                                     samplewise_std_normalization=True)

    # Load ROI classifier.
    if roi_state is not None:
        print "Load ROI classifier"
        sys.stdout.flush()
        roi_clf = load_model(roi_state,
                             custom_objects={
                                 'sensitivity': dmm.sensitivity,
                                 'specificity': dmm.specificity
                             })
        graph = tf.get_default_graph()
    else:
        roi_clf = None
        graph = None

    # Create ROI generators for pos and neg images separately.
    print "Create ROI generators for pos and neg images"
    sys.stdout.flush()
    roi_generator = imgen.flow_from_candid_roi(
        img_list,
        target_height=img_height,
        target_scale=img_scale,
        class_mode=None,
        validation_mode=True,
        img_per_batch=img_per_batch,
        roi_per_img=roi_per_img,
        roi_size=roi_size,
        low_int_threshold=low_int_threshold,
        blob_min_area=blob_min_area,
        blob_min_int=blob_min_int,
        blob_max_int=blob_max_int,
        blob_th_step=blob_th_step,
        tf_graph=graph,
        roi_clf=roi_clf,
        clf_bs=roi_clf_bs,
        return_sample_weight=False,
        seed=random_seed)

    # Generate image patches and extract their DL representations.
    print "Load DL representation model"
    sys.stdout.flush()
    dlrepr_model = DLRepr(dl_state,
                          custom_objects={
                              'sensitivity': dmm.sensitivity,
                              'specificity': dmm.specificity
                          },
                          layer_name=layer_name,
                          layer_index=layer_index)
    last_output_size = dlrepr_model.get_output_shape()[-1][-1]
    if last_output_size != 3 and last_output_size != 1:
        raise Exception("The last output must be prob outputs (size=3 or 1)")

    nb_tot_samples = len(img_list) * roi_per_img
    print "Extract ROIs from pos and neg images"
    sys.stdout.flush()
    pred = dlrepr_model.predict_generator(roi_generator,
                                          val_samples=nb_tot_samples)
    for i, d in enumerate(pred):
        print "Shape of representation/output data %d:" % (i), d.shape
    sys.stdout.flush()

    # Flatten feature maps, e.g. an 8x8 feature map will become a 64-d vector.
    pred = [d.reshape((-1, d.shape[-1])) for d in pred]
    for i, d in enumerate(pred):
        print "Shape of flattened data %d:" % (i), d.shape
    sys.stdout.flush()

    # Split representations and prob outputs.
    dl_repr = pred[0]
    prob_out = pred[1]
    if prob_out.shape[1] == 3:
        prob_out = prob_out[:, 1]  # pos class.
    prob_out = prob_out.reshape((len(img_list), -1))
    print "Reshape prob output to:", prob_out.shape
    sys.stdout.flush()

    # Use PCA to reduce dimension of the representation data.
    if pc_components is not None:
        print "Start PCA dimension reduction on DL representation"
        sys.stdout.flush()
        pca = PCA(n_components=pc_components, whiten=pc_whiten)
        pca.fit(dl_repr)
        print "Nb of PCA components:", pca.n_components_
        print "Total explained variance ratio: %.4f" % \
                (pca.explained_variance_ratio_.sum())
        dl_repr_pca = pca.transform(dl_repr)
        print "Shape of transformed representation data:", dl_repr_pca.shape
        sys.stdout.flush()
    else:
        pca = None

    # Use K-means to create a codebook for deep visual words.
    print "Start K-means training on DL representation"
    sys.stdout.flush()
    clf_list = []
    clust_list = []
    # Shuffling indices for mini-batches learning.
    perm_idx = rng.permutation(len(dl_repr))
    for n in nb_words:
        print "Train K-means with %d cluster centers" % (n)
        sys.stdout.flush()
        clf = MiniBatchKMeans(n_clusters=n,
                              init='k-means++',
                              max_iter=km_max_iter,
                              batch_size=km_bs,
                              compute_labels=True,
                              random_state=random_seed,
                              tol=0.0,
                              max_no_improvement=km_patience,
                              init_size=None,
                              n_init=km_init,
                              reassignment_ratio=0.01,
                              verbose=0)
        clf.fit(dl_repr[perm_idx])
        clf_list.append(clf)
        clust = np.zeros_like(clf.labels_)
        clust[perm_idx] = clf.labels_
        clust = clust.reshape((len(img_list), -1))
        clust_list.append(clust)

    if pca is not None:
        print "Start K-means training on transformed representation"
        sys.stdout.flush()
        clf_list_pca = []
        clust_list_pca = []
        # Shuffling indices for mini-batches learning.
        perm_idx = rng.permutation(len(dl_repr_pca))
        for n in nb_words:
            print "Train K-means with %d cluster centers" % (n)
            sys.stdout.flush()
            clf = MiniBatchKMeans(n_clusters=n,
                                  init='k-means++',
                                  max_iter=km_max_iter,
                                  batch_size=km_bs,
                                  compute_labels=True,
                                  random_state=random_seed,
                                  tol=0.0,
                                  max_no_improvement=km_patience,
                                  init_size=None,
                                  n_init=km_init,
                                  reassignment_ratio=0.01,
                                  verbose=0)
            clf.fit(dl_repr_pca[perm_idx])
            clf_list_pca.append(clf)
            clust = np.zeros_like(clf.labels_)
            clust[perm_idx] = clf.labels_
            clust = clust.reshape((len(img_list), -1))
            clust_list_pca.append(clust)

    # Read exam lists.
    exam_train = meta_man.get_flatten_exam_list(subj_train,
                                                flatten_img_list=True)
    exam_test = meta_man.get_flatten_exam_list(subj_test,
                                               flatten_img_list=True)
    exam_labs_train = np.array(meta_man.exam_labs(exam_train))
    exam_labs_test = np.array(meta_man.exam_labs(exam_test))
    nb_pos_exams_train = (exam_labs_train == 1).sum()
    nb_neg_exams_train = (exam_labs_train == 0).sum()
    nb_pos_exams_test = (exam_labs_test == 1).sum()
    nb_neg_exams_test = (exam_labs_test == 0).sum()
    print "Train set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_train, nb_neg_exams_train)
    print "Test set - Nb of pos exams: %d, Nb of neg exams: %d" % \
            (nb_pos_exams_test, nb_neg_exams_test)

    # Do BoW counts for each breast.
    print "BoW counting for train exam list"
    sys.stdout.flush()
    bow_dat_train = get_exam_bow_dat(exam_train,
                                     nb_words,
                                     roi_per_img,
                                     img_list=img_list,
                                     prob_out=prob_out,
                                     clust_list=clust_list)
    for i, d in enumerate(bow_dat_train[1]):
        print "Shape of train BoW matrix %d:" % (i), d.shape
    sys.stdout.flush()

    print "BoW counting for test exam list"
    sys.stdout.flush()
    bow_dat_test = get_exam_bow_dat(exam_test,
                                    nb_words,
                                    roi_per_img,
                                    imgen=imgen,
                                    clf_list=clf_list,
                                    transformer=None,
                                    target_height=img_height,
                                    target_scale=img_scale,
                                    img_per_batch=img_per_batch,
                                    roi_size=roi_size,
                                    low_int_threshold=low_int_threshold,
                                    blob_min_area=blob_min_area,
                                    blob_min_int=blob_min_int,
                                    blob_max_int=blob_max_int,
                                    blob_th_step=blob_th_step,
                                    seed=random_seed,
                                    dlrepr_model=dlrepr_model)
    for i, d in enumerate(bow_dat_test[1]):
        print "Shape of test BoW matrix %d:" % (i), d.shape
    sys.stdout.flush()

    if pca is not None:
        print "== Do same BoW counting on PCA transformed data =="
        print "BoW counting for train exam list"
        sys.stdout.flush()
        bow_dat_train_pca = get_exam_bow_dat(exam_train,
                                             nb_words,
                                             roi_per_img,
                                             img_list=img_list,
                                             prob_out=prob_out,
                                             clust_list=clust_list_pca)
        for i, d in enumerate(bow_dat_train_pca[1]):
            print "Shape of train BoW matrix %d:" % (i), d.shape
        sys.stdout.flush()

        print "BoW counting for test exam list"
        sys.stdout.flush()
        bow_dat_test_pca = get_exam_bow_dat(
            exam_test,
            nb_words,
            roi_per_img,
            imgen=imgen,
            clf_list=clf_list_pca,
            transformer=pca,
            target_height=img_height,
            target_scale=img_scale,
            img_per_batch=img_per_batch,
            roi_size=roi_size,
            low_int_threshold=low_int_threshold,
            blob_min_area=blob_min_area,
            blob_min_int=blob_min_int,
            blob_max_int=blob_max_int,
            blob_th_step=blob_th_step,
            seed=random_seed,
            dlrepr_model=dlrepr_model)
        for i, d in enumerate(bow_dat_test_pca[1]):
            print "Shape of test BoW matrix %d:" % (i), d.shape
        sys.stdout.flush()

    # Save K-means model and BoW count data.
    if pca is None:
        pickle.dump(clf_list, open(pca_km_states, 'w'))
        pickle.dump(bow_dat_train, open(bow_train_out, 'w'))
        pickle.dump(bow_dat_test, open(bow_test_out, 'w'))
    else:
        pickle.dump((pca, clf_list), open(pca_km_states, 'w'))
        pickle.dump((bow_dat_train, bow_dat_train_pca),
                    open(bow_train_out, 'w'))
        pickle.dump((bow_dat_test, bow_dat_test_pca), open(bow_test_out, 'w'))

    print "Done."