コード例 #1
0
def run(img_folder,
        img_extension='dcm',
        img_size=(288, 224),
        img_scale=4095,
        multi_view=False,
        do_featurewise_norm=True,
        featurewise_mean=398.5,
        featurewise_std=627.8,
        batch_size=16,
        samples_per_epoch=160,
        nb_epoch=20,
        balance_classes=.0,
        all_neg_skip=0.,
        pos_cls_weight=1.0,
        nb_init_filter=64,
        init_filter_size=7,
        init_conv_stride=2,
        pool_size=3,
        pool_stride=2,
        weight_decay=.0001,
        alpha=1.,
        l1_ratio=.5,
        inp_dropout=.0,
        hidden_dropout=.0,
        init_lr=.01,
        val_size=.2,
        lr_patience=5,
        es_patience=10,
        resume_from=None,
        net='resnet50',
        load_val_ram=False,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        best_model='./modelState/dm_resnet_best_model.h5',
        final_model="NOSAVE"):
    '''Run ResNet training on mammograms using an exam or image list.

    The exam/image list is split (stratified by label) into train and
    validation sets, a ResNet is built (or loaded from ``resume_from``) and
    trained with Nesterov SGD and binary cross-entropy.

    Environment variables read: RANDOM_SEED (default 12345), NUM_CPU_CORES
    (default 4, used as the number of generator workers) and NUM_GPU_DEVICES
    (default 1; >1 wraps the model with ``make_parallel``).

    Args:
        img_folder (str): image folder, passed to ``DMMetaManager``.
        img_extension (str): image file extension, e.g. 'dcm' or 'png'.
        img_size: (rows, cols) target size of the images.
        img_scale: passed as ``target_scale`` to the data generators.
        multi_view (bool): train on exams (``flow_from_exam_list`` and
                ``MultiViewResNetBuilder``) instead of single images.
        do_featurewise_norm (bool): normalize with the fixed
                featurewise_mean/featurewise_std; otherwise use per-sample
                centering and std normalization.
        featurewise_mean, featurewise_std ([float]): they are estimated from
                1152 x 896 images. Using different sized images gives very
                close results. For png, mean=7772, std=12187.
        batch_size (int): training (and, unless load_val_ram, validation)
                batch size.
        samples_per_epoch, nb_epoch (int): passed to ``fit_generator``.
        balance_classes, all_neg_skip: passed to the training generator.
        pos_cls_weight (float): class weight of label 1 (label 0 has 1.0).
        nb_init_filter, init_filter_size, init_conv_stride, pool_size,
        pool_stride, weight_decay, alpha, l1_ratio, inp_dropout,
        hidden_dropout: hyper-parameters passed positionally to the
                ResNet builder (ignored when resume_from is given).
        init_lr (float): initial SGD learning rate.
        val_size (float): fraction of the list held out for validation.
        lr_patience (int): ReduceLROnPlateau patience on val_loss
                (factor 0.1).
        es_patience (int): EarlyStopping patience on val_loss.
        resume_from (str or None): path of a saved model to continue
                training; when given, ``net`` is ignored.
        net (str): architecture name; one of 'resnet18', 'resnet34',
                'resnet50', 'resnet101', 'resnet152', 'dmresnet14',
                'dmresnet47rb5', 'dmresnet56rb6', 'dmresnet65rb7'.
        load_val_ram (bool): load the whole validation set as a single
                batch into memory.
        exam_tsv, img_tsv (str): metadata files for ``DMMetaManager``.
        best_model (str): checkpoint path given to ``DMAucModelCheckpoint``.
        final_model (str): path to save the final model, or "NOSAVE".

    Returns:
        The history object returned by ``model.fit_generator``.

    Raises:
        ValueError: if ``net`` is unsupported (and resume_from is None), or
                if the in-RAM validation batch does not hold exactly the
                expected number of samples.
    '''
    # Architecture name -> builder method name. The method is looked up with
    # getattr only for the selected net, so a builder class does not need to
    # provide every architecture.
    net_builders = {
        'resnet18': 'build_resnet_18',
        'resnet34': 'build_resnet_34',
        'resnet50': 'build_resnet_50',
        'resnet101': 'build_resnet_101',
        'resnet152': 'build_resnet_152',
        'dmresnet14': 'build_dm_resnet_14',
        'dmresnet47rb5': 'build_dm_resnet_47rb5',
        'dmresnet56rb6': 'build_dm_resnet_56rb6',
        'dmresnet65rb7': 'build_dm_resnet_65rb7',
    }
    # Fail fast, before any data is loaded: an unknown name would otherwise
    # leave no model to train.
    if resume_from is None and net not in net_builders:
        raise ValueError("Unsupported net %r; expected one of: %s"
                         % (net, ', '.join(sorted(net_builders))))

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup training and validation data.
    # Load image or exam lists and split them into train and val sets.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)
    if multi_view:
        exam_list = meta_man.get_flatten_exam_list()
        exam_train, exam_val = train_test_split(
            exam_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=meta_man.exam_labs(exam_list))
        val_size_ = len(exam_val) * 2  # L and R.
    else:
        img_list, lab_list = meta_man.get_flatten_img_list()
        img_train, img_val, lab_train, lab_val = train_test_split(
            img_list,
            lab_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=lab_list)
        val_size_ = len(img_val)

    # Create image generator.
    img_gen = DMImageDataGenerator(horizontal_flip=True, vertical_flip=True)
    if do_featurewise_norm:
        img_gen.featurewise_center = True
        img_gen.featurewise_std_normalization = True
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen.samplewise_center = True
        img_gen.samplewise_std_normalization = True

    target_size = (img_size[0], img_size[1])
    # When the validation set is loaded into RAM, one batch must hold all of
    # it; otherwise validation uses the normal batch size.
    val_batch_size = val_size_ if load_val_ram else batch_size
    if multi_view:
        train_generator = img_gen.flow_from_exam_list(
            exam_train,
            target_size=target_size,
            target_scale=img_scale,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_exam_list(
            exam_val,
            target_size=target_size,
            target_scale=img_scale,
            batch_size=val_batch_size,
            validation_mode=True,
            class_mode='binary')
    else:
        train_generator = img_gen.flow_from_img_list(
            img_train,
            lab_train,
            target_size=target_size,
            target_scale=img_scale,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_img_list(
            img_val,
            lab_val,
            target_size=target_size,
            target_scale=img_scale,
            batch_size=val_batch_size,
            validation_mode=True,
            class_mode='binary')

    # Load validation set into RAM and check it is complete.
    if load_val_ram:
        validation_set = next(val_generator)
        X_val = validation_set[0]
        if multi_view:
            # X_val is indexed as two inputs here; presumably one array per
            # view -- TODO confirm against flow_from_exam_list.
            nb_loaded = [len(X_val[0]), len(X_val[1])]
        else:
            nb_loaded = [len(X_val)]
        if any(n != val_size_ for n in nb_loaded):
            raise ValueError("Expected %d validation samples in RAM, got %s"
                             % (val_size_, nb_loaded))

    # Create model.
    if resume_from is not None:
        model = load_model(resume_from,
                           custom_objects={
                               'sensitivity': DMMetrics.sensitivity,
                               'specificity': DMMetrics.specificity
                           })
    else:
        builder = MultiViewResNetBuilder if multi_view else ResNetBuilder
        build_fn = getattr(builder, net_builders[net])
        # Input shape is (1, rows, cols); presumably channels-first with a
        # single channel -- TODO confirm. One output unit (binary task).
        model = build_fn(
            (1, img_size[0], img_size[1]), 1, nb_init_filter,
            init_filter_size, init_conv_stride, pool_size, pool_stride,
            weight_decay, alpha, l1_ratio, inp_dropout, hidden_dropout)

    if gpu_count > 1:
        model = make_parallel(model, gpu_count)

    # Model training.
    sgd = SGD(lr=init_lr, momentum=0.9, decay=0.0, nesterov=True)
    model.compile(optimizer=sgd,
                  loss='binary_crossentropy',
                  metrics=[DMMetrics.sensitivity, DMMetrics.specificity])
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.1,
                                  patience=lr_patience,
                                  verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=es_patience,
                                   verbose=1)
    if load_val_ram:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                validation_set,
                                                batch_size=batch_size)
    else:
        auc_checkpointer = DMAucModelCheckpoint(best_model,
                                                val_generator,
                                                nb_test_samples=val_size_)
    hist = model.fit_generator(
        train_generator,
        samples_per_epoch=samples_per_epoch,
        nb_epoch=nb_epoch,
        class_weight={
            0: 1.0,
            1: pos_cls_weight
        },
        validation_data=validation_set if load_val_ram else val_generator,
        nb_val_samples=val_size_,
        callbacks=[reduce_lr, early_stopping, auc_checkpointer],
        nb_worker=nb_worker,
        pickle_safe=True,  # turn on pickle_safe to avoid a strange error.
        verbose=2)

    # Training report. history values are per-epoch lists, so locate the
    # minimum with argmin (first occurrence) rather than comparing the list
    # itself to a scalar.
    val_losses = hist.history['val_loss']
    best_idx = int(np.argmin(val_losses))
    print("\n==== Training summary ====")
    print("Minimum val loss achieved at epoch: %d" % (best_idx + 1))
    print("Best val loss: %s" % val_losses[best_idx])
    print("Best val sensitivity: %s"
          % hist.history['val_sensitivity'][best_idx])
    print("Best val specificity: %s"
          % hist.history['val_specificity'][best_idx])

    if final_model != "NOSAVE":
        model.save(final_model)

    return hist
コード例 #2
0
from meta import DMMetaManager
from dm_image import DMImageDataGenerator
import numpy as np

# Build the flattened exam and image lists from the pre-processed PNGs.
meta_man = DMMetaManager(img_folder='preprocessedData/png_288x224/',
                         img_extension='png')
exam_list = meta_man.get_flatten_exam_list()
# NOTE: get_flatten_img_list() returns an (images, labels) pair, so
# img_list[0] holds the images and img_list[1] the labels.
img_list = meta_man.get_flatten_img_list()

# Feature-wise normalization with the PNG statistics (mean=7772, std=12187).
img_gen = DMImageDataGenerator(featurewise_center=True,
                               featurewise_std_normalization=True)
img_gen.mean, img_gen.std = 7772., 12187.

# Deterministic (unshuffled) generators: one over exams, one over images.
datgen_exam = img_gen.flow_from_exam_list(
    exam_list, target_size=(288, 224), batch_size=8, shuffle=False, seed=123)
datgen_image = img_gen.flow_from_img_list(
    *img_list, target_size=(288, 224), batch_size=32, shuffle=False, seed=123)
コード例 #3
0
def run(img_folder,
        img_extension='png',
        img_size=(288, 224),
        multi_view=False,
        do_featurewise_norm=True,
        featurewise_mean=7772.,
        featurewise_std=12187.,
        batch_size=16,
        samples_per_epoch=160,
        nb_epoch=20,
        val_size=.2,
        balance_classes=0.,
        all_neg_skip=False,
        pos_cls_weight=1.0,
        alpha=1.,
        l1_ratio=.5,
        init_lr=.01,
        lr_patience=2,
        es_patience=4,
        exam_tsv='./metadata/exams_metadata.tsv',
        img_tsv='./metadata/images_crosswalk.tsv',
        dl_state='./modelState/resnet50_288_best_model.h5',
        best_model='./modelState/enet_288_best_model.h5',
        final_model="NOSAVE"):
    '''Train an elastic-net classifier on representations from a DL model.

    A pre-trained Keras model is loaded from ``dl_state``. Its second-to-last
    layer's output is used as the feature representation. An
    ``SGDClassifier`` (log loss, elastic-net penalty) is trained
    incrementally with ``partial_fit`` on batches from the training
    generator. After every epoch it is evaluated (AUROC, log loss) on the
    held-out validation set, which is kept in RAM.

    Environment variables read: RANDOM_SEED (default 12345), NUM_CPU_CORES
    (default 4, used as the classifier's n_jobs) and NUM_GPU_DEVICES
    (default 1; >1 wraps the representation model with ``make_parallel``).

    Args:
        img_folder (str): image folder, passed to ``DMMetaManager``.
        img_extension (str): image file extension.
        img_size: (rows, cols) target size of the images.
        multi_view (bool): use exams (``flow_from_exam_list``) instead of
                single images.
        do_featurewise_norm (bool): normalize with featurewise_mean/std;
                otherwise per-sample centering and std normalization.
        featurewise_mean, featurewise_std (float): normalization constants.
        batch_size (int): generator batch size.
        samples_per_epoch (int): minimum training samples seen per epoch.
        nb_epoch (int): maximum number of epochs.
        val_size (float): fraction held out (stratified) for validation.
        balance_classes, all_neg_skip: passed to the training generator.
        pos_cls_weight (float): class weight of label 1 (label 0 has 1.0).
        alpha, l1_ratio (float): elastic-net regularization parameters.
        init_lr (float): initial constant learning rate (``eta0``).
        lr_patience (int): epochs without val-loss improvement after which
                the learning rate is multiplied by 0.1.
        es_patience (int): epochs without val-loss improvement after which
                training stops.
        exam_tsv, img_tsv (str): metadata files for ``DMMetaManager``.
        dl_state (str): path of the saved Keras model.
        best_model (str): where to pickle the classifier whenever the val
                AUROC improves, or "NOSAVE".
        final_model (str): where to pickle the final classifier, or "NOSAVE".
    '''

    # Read some env variables.
    random_seed = int(os.getenv('RANDOM_SEED', 12345))
    nb_worker = int(os.getenv('NUM_CPU_CORES', 4))
    gpu_count = int(os.getenv('NUM_GPU_DEVICES', 1))

    # Setup training and validation data.
    meta_man = DMMetaManager(exam_tsv=exam_tsv,
                             img_tsv=img_tsv,
                             img_folder=img_folder,
                             img_extension=img_extension)

    if multi_view:
        exam_list = meta_man.get_flatten_exam_list()
        exam_train, exam_val = train_test_split(
            exam_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=meta_man.exam_labs(exam_list))
        val_size_ = len(exam_val) * 2  # L and R.
    else:
        img_list, lab_list = meta_man.get_flatten_img_list()
        img_train, img_val, lab_train, lab_val = train_test_split(
            img_list,
            lab_list,
            test_size=val_size,
            random_state=random_seed,
            stratify=lab_list)
        val_size_ = len(img_val)

    img_gen = DMImageDataGenerator(horizontal_flip=True, vertical_flip=True)
    if do_featurewise_norm:
        img_gen.featurewise_center = True
        img_gen.featurewise_std_normalization = True
        img_gen.mean = featurewise_mean
        img_gen.std = featurewise_std
    else:
        img_gen.samplewise_center = True
        img_gen.samplewise_std_normalization = True

    target_size = (img_size[0], img_size[1])
    if multi_view:
        train_generator = img_gen.flow_from_exam_list(
            exam_train,
            target_size=target_size,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_exam_list(
            exam_val,
            target_size=target_size,
            batch_size=batch_size,
            validation_mode=True,
            class_mode='binary')
    else:
        train_generator = img_gen.flow_from_img_list(
            img_train,
            lab_train,
            target_size=target_size,
            batch_size=batch_size,
            balance_classes=balance_classes,
            all_neg_skip=all_neg_skip,
            shuffle=True,
            seed=random_seed,
            class_mode='binary')
        val_generator = img_gen.flow_from_img_list(
            img_val,
            lab_val,
            target_size=target_size,
            batch_size=batch_size,
            validation_mode=True,
            class_mode='binary')

    # Deep learning model; the output of its second-to-last layer is used as
    # the feature representation.
    dl_model = load_model(dl_state,
                          custom_objects={
                              'sensitivity': DMMetrics.sensitivity,
                              'specificity': DMMetrics.specificity
                          })
    # NOTE: if the saved model was run on multiple GPUs, a dummy compile
    # (dl_model.compile(optimizer='sgd', loss='binary_crossentropy')) may be
    # needed to avoid an "uncompiled" error.
    reprlayer_model = Model(input=dl_model.input,
                            output=dl_model.get_layer(index=-2).output)
    if gpu_count > 1:
        reprlayer_model = make_parallel(reprlayer_model, gpu_count)

    # Setup test data (representations and labels) in RAM.
    X_test, y_test = dlrepr_generator(reprlayer_model, val_generator,
                                      val_size_)

    # Evaluate the DL model on the same test data; rewind the generator
    # first because dlrepr_generator consumed it.
    val_generator.reset()
    dl_test_pred = dl_model.predict_generator(val_generator,
                                              val_samples=val_size_,
                                              nb_worker=1,
                                              pickle_safe=False)
    # Set nb_worker to >1 can cause:
    # either inconsistent result when pickle_safe is False,
    #     or broadcasting error when pickle_safe is True.
    # This seems to be a Keras bug!!
    # Further note: the broadcasting error may only happen when val_size_
    # is not divisible by batch_size.
    try:
        dl_auc = roc_auc_score(y_test, dl_test_pred)
        dl_loss = log_loss(y_test, dl_test_pred)
    except ValueError:
        dl_auc = 0.
        dl_loss = np.inf
    print("\nAUROC by the DL model: %.4f, loss: %.4f" % (dl_auc, dl_loss))

    # Elastic net training.
    target_classes = np.array([0, 1])
    sgd_clf = SGDClassifier(loss='log',
                            penalty='elasticnet',
                            alpha=alpha,
                            l1_ratio=l1_ratio,
                            verbose=0,
                            n_jobs=nb_worker,
                            learning_rate='constant',
                            eta0=init_lr,
                            random_state=random_seed,
                            class_weight={
                                0: 1.0,
                                1: pos_cls_weight
                            })
    curr_lr = init_lr
    best_epoch = 0
    best_auc = 0.
    min_loss = np.inf
    min_loss_epoch = 0
    for epoch in range(nb_epoch):
        samples_seen = 0
        X_list = []
        y_list = []
        epoch_start = time.time()
        while samples_seen < samples_per_epoch:
            X, y = next(train_generator)
            X_repr = reprlayer_model.predict_on_batch(X)
            sgd_clf.partial_fit(X_repr, y, classes=target_classes)
            samples_seen += len(y)
            X_list.append(X_repr)
            y_list.append(y)
        # The training X, y are expected to change for each epoch due to
        # image random sampling and class balancing.
        X_train_epo = np.concatenate(X_list)
        y_train_epo = np.concatenate(y_list)
        # End of epoch summary.
        pred_prob = sgd_clf.predict_proba(X_test)[:, 1]
        train_prob = sgd_clf.predict_proba(X_train_epo)[:, 1]
        try:
            auc = roc_auc_score(y_test, pred_prob)
            crossentropy_loss = log_loss(y_test, pred_prob)
        except ValueError:
            auc = 0.
            crossentropy_loss = np.inf
        try:
            train_loss = log_loss(y_train_epo, train_prob)
        except ValueError:
            train_loss = np.inf
        wei_sparseness = np.mean(sgd_clf.coef_ == 0)
        epoch_span = time.time() - epoch_start
        print(("%ds - Epoch=%d, auc=%.4f, train_loss=%.4f, test_loss=%.4f, "
               "weight sparsity=%.4f") %
              (epoch_span, epoch + 1, auc, train_loss, crossentropy_loss,
               wei_sparseness))
        # Model checkpoint, reducing learning rate and early stopping.
        if auc > best_auc:
            best_epoch = epoch + 1
            best_auc = auc
            if best_model != "NOSAVE":
                # Pickles are binary data: text mode corrupts them on some
                # platforms and is rejected by Python 3.
                with open(best_model, 'wb') as best_state:
                    pickle.dump(sgd_clf, best_state)
        if crossentropy_loss < min_loss:
            min_loss = crossentropy_loss
            min_loss_epoch = epoch + 1
        else:
            if epoch + 1 - min_loss_epoch >= es_patience:
                print('Early stopping criterion has reached. Stop training.')
                break
            # NOTE(review): the stall counter is not reset after a
            # reduction, so once past lr_patience the LR is cut on every
            # further stalled epoch -- confirm this is intended.
            if epoch + 1 - min_loss_epoch >= lr_patience:
                curr_lr *= .1
                sgd_clf.set_params(eta0=curr_lr)
                print("Reducing learning rate to: %s" % (curr_lr))
    # End of training summary
    print(">>> Found best AUROC: %.4f at epoch: %d, saved to: %s <<<" %
          (best_auc, best_epoch, best_model))
    print(">>> Found best val loss: %.4f at epoch: %d. <<<" %
          (min_loss, min_loss_epoch))
    #### Save elastic net model!! ####
    if final_model != "NOSAVE":
        with open(final_model, 'wb') as final_state:
            pickle.dump(sgd_clf, final_state)