Example #1
def RetModel() -> List[Model]:
    """Build every training model.

    Returns:
        A list of VGG, ResNet, Inception-ResNet and DenseNet models, each
        configured for a (256, 256, 7) input and 2 output classes.
    """
    return [VGG16(input_shape=(256, 256, 7), classes=2),
            VGG19(input_shape=(256, 256, 7), classes=2),
            ResNet50(input_shape=(256, 256, 7), classes=2),
            InceptionResNetV2(input_shape=(256, 256, 7), classes=2),
            DenseNet121(input_shape=(256, 256, 7), classes=2),
            DenseNet169(input_shape=(256, 256, 7), classes=2),
            DenseNet201(input_shape=(256, 256, 7), classes=2)]
Example #2
def RetModel():
    return [
        VGG16(input_shape=(256, 256, 7), classes=2),
        VGG19(input_shape=(256, 256, 7), classes=2),
        ResNet50(input_shape=(256, 256, 7), classes=2),
        InceptionResNetV2(input_shape=(256, 256, 7), classes=2),
        DenseNet121(input_shape=(256, 256, 7), classes=2),
        DenseNet169(input_shape=(256, 256, 7), classes=2),
        DenseNet201(input_shape=(256, 256, 7), classes=2)
    ]
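Assuming the factory calls above return standard Keras Model objects (the original example does not show their definitions), a minimal usage sketch for training each returned model in turn follows; the optimizer, loss, and the train_x / train_y arrays are illustrative placeholders, not part of the original:

# Usage sketch (assumption: RetModel() returns Keras Model objects and
# train_x / train_y are one-hot-labelled training arrays defined elsewhere).
for model in RetModel():
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(train_x, train_y, batch_size=8, epochs=1, validation_split=0.1)
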
Example #3
def main():
    """Create the model and evaluate it on the validation studies."""
    args = get_arguments()


    # Create queue coordinator.
    coord = tf.train.Coordinator()

    train, valid, valid_studies = load_dataframes(DATA_DIR=args.data_dir)
    _, _, valid_studies_df = get_body_part_dataframes(train, valid, valid_studies, args.bpart)

    valid_studies_path, valid_studies_label = read_labeled_image_list(valid_studies_df)
    number_of_validation_studies = len(valid_studies_label)

    print("\nNumber of validation studies for %s dataset:"%args.bpart, number_of_validation_studies)

    image = tf.placeholder(tf.float32, [None, 320, 320, 3])

    # Create network  with weights initialized from densenet_169 pretrained on ImageNet
    net = DenseNet169(args.weights_path)

    # Predictions
    prob = net.build(inputs=image, is_training=False)
    prob = tf.reshape(prob, [-1])

    all_variables = tf.global_variables()  # variables to restore (tf.all_variables is deprecated)
    #config = tf.ConfigProto(intra_op_parallelism_threads=8, inter_op_parallelism_threads=8)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)

    # Load weights
    saver = tf.train.Saver(var_list=all_variables)
    #load(saver, sess, args.restore_from)

    # Start queue threads
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    probabilities = np.zeros(number_of_validation_studies)
    predictions = np.zeros(number_of_validation_studies, dtype=int)
    previous_study_type = ""

    for i in tqdm(range(number_of_validation_studies),  desc='Evaluation'):
        img_list = [f for f in os.listdir(valid_studies_path[i]) if not f.startswith(".")]
        num_img = len(img_list)
        pred_study = np.zeros(num_img)
        for j in range(num_img):
            img_path = valid_studies_path[i]+img_list[j] # eg. of path: 'MURA-v1.1/valid/XR_ELBOW/patient99999/study1_positive/image1.png'
            study_type = img_path.split("XR_")[1]  # Extract the study type in path between "XR_" and "/patient"
            study_type = study_type.split("/patient")[0]
            restore_from = args.restore_from+MODELS[study_type]
            if study_type != previous_study_type:
                load(saver, sess, restore_from)
            img_contents = tf.read_file(img_path)
            img = tf.image.decode_png(img_contents, channels=3)
            img = valid_transforms(img, study_type.lower()) # Normalize each model's inputs with the same statistics it has been trained on.
            img = tf.expand_dims(img, axis=0)
            img_arr = sess.run(img)
            feed_dict = {image: img_arr}
            pred_img = sess.run(prob, feed_dict=feed_dict)
            pred_study[j] = pred_img[0]
            previous_study_type = study_type
            #print('{:.4f}'.format(pred_study[j]))
        pred_study_mean = np.mean(pred_study)
        predictions[i] = 1 if pred_study_mean > 0.5 else 0
        probabilities[i] = pred_study_mean

    labels = tf.convert_to_tensor(valid_studies_label, dtype=tf.int32)
    predictions = tf.convert_to_tensor(predictions, dtype=tf.int32)
    probabilities = tf.convert_to_tensor(probabilities, dtype=tf.float32)

    # Define metrics
    confusion_matrix = tf.confusion_matrix(labels=labels, predictions=predictions)
    accuracy = tf.contrib.metrics.accuracy(labels=labels, predictions=predictions)
    auc, auc_update_op = tf.metrics.auc(labels=labels, predictions=probabilities)
    recall, recall_update_op = tf.metrics.recall(labels=labels, predictions=predictions)
    kappa, kappa_op = tf.contrib.metrics.cohen_kappa(labels=labels, predictions_idx=predictions, num_classes=2)

    #config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    (kappa_val, kappa_op_val, probs, testy, confusion_matrix_val, accuracy_val,
     auc_val, auc_op, recall_val, recall_op) = sess.run(
         [kappa, kappa_op, probabilities, labels, confusion_matrix, accuracy,
          auc, auc_update_op, recall, recall_update_op])

    print('\nConfusion matrix:\n', confusion_matrix_val)
    print('\nArea under the ROC curve:', auc_op)
    print("\nRecall:", recall_op)
    print("\nAccuracy:", accuracy_val)
    print("\nCohen's kappa:", kappa_op_val)


    # Plot the roc curve
    fpr, tpr, thresholds = roc_curve(testy, probs)
    plt.plot([0, 1], [0, 1], linestyle='--')
    plt.plot(fpr, tpr, marker='.')
    plt.show()

    coord.request_stop()
    coord.join(threads)
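The study-level decision above is just the mean of the per-image sigmoid outputs thresholded at 0.5. A self-contained NumPy sketch of that aggregation step (the helper name aggregate_study is illustrative, not from the original):

import numpy as np

def aggregate_study(image_probs):
    """Average per-image abnormality probabilities and threshold at 0.5."""
    prob = float(np.mean(image_probs))
    return prob, int(prob > 0.5)

# Three images in one study, two of them flagged as likely abnormal.
print(aggregate_study([0.9, 0.7, 0.4]))  # -> (0.666..., 1)
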
Example #4
def main():
    """Create the model and start the training."""
    args = get_arguments()

    # Create queue coordinator.
    coord = tf.train.Coordinator()

    train, valid, valid_studies = load_dataframes(DATA_DIR=args.data_dir)
    train_df, valid_df, valid_studies_df = get_body_part_dataframes(
        train, valid, valid_studies, args.bpart)

    train_df_list = read_labeled_image_list(
        train_df)  # Returns a tuple (train_path_list, train_label_list)
    valid_df_list = read_labeled_image_list(valid_df)
    number_of_training_images = len(train_df_list[1])  # Number of labels
    number_of_validation_images = len(valid_df_list[1])
    NUM_STEPS = args.num_epochs * number_of_training_images // args.batch_size
    VALIDATION_STEPS = 5  #number_of_validation_images #// args.batch_size
    EVALUATE_EVERY = 10  #number_of_training_images // args.batch_size # Evaluate every epoch
    A_train = sum(train_df_list[1])  # Number of abnormal examples in the training dataset
    N_train = number_of_training_images - A_train  # Number of normal examples in the training dataset
    wT1 = N_train / (A_train + N_train)
    wT0 = A_train / (A_train + N_train)
    A_valid = sum(valid_df_list[1])  # Number of abnormal examples in the validation dataset
    N_valid = number_of_validation_images - A_valid  # Number of normal examples in the validation dataset
    df = pd.DataFrame([[A_train, A_valid, wT1], [N_train, N_valid, wT0],
                       [A_train + N_train, A_valid + N_valid, wT0 + wT1]],
                      index=["Abnormal", "Normal", "Total"],
                      columns=["Train", "Valid", "Loss weights"])
    print("\n%s dataset summary: \n " % args.bpart)
    print(df)
    print("\n")

    # Load reader.
    with tf.name_scope("Inputs"):
        reader = ImageReader(train_df, valid_df, args.bpart)
        image_batch, label_batch = reader.dequeue_train(args.batch_size)
        val_image_batch, val_label_batch = reader.dequeue_val(1)

    # Create network  with weights initialized from DenseNet169 pretrained on ImageNet
    net = DenseNet169(args.weights_path)

    # Define loss and accuracy
    loss = net.weighted_cross_entropy_loss(image_batch,
                                           label_batch,
                                           w0=wT0,
                                           w1=wT1,
                                           is_training=True,
                                           scope='train_loss')
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    acc = net.accuracy(image_batch,
                       label_batch,
                       is_training=True,
                       scope='train_accuracy')

    # Define summaries for TensorBoard visualization
    loss_summary = tf.summary.scalar('training loss', loss)
    val_image_summary = tf.summary.image('validation input', val_image_batch)

    # Optimization ops. The learning rate is fed through a placeholder so that
    # dividing it by 10 on a validation-loss plateau actually changes the
    # optimizer's step size (rebinding a tf.constant would not).
    learning_rate = args.learning_rate  # Python float, reduced on plateau
    lr_placeholder = tf.placeholder(tf.float32, shape=[], name="learning_rate")
    optimiser = tf.train.AdamOptimizer(learning_rate=lr_placeholder)
    trainable_variables = tf.trainable_variables()
    all_variables = tf.global_variables()  # tf.all_variables is deprecated
    with tf.control_dependencies(update_ops):
        optim = optimiser.minimize(loss, var_list=trainable_variables)

    # Track performance on the validation set during training
    val_loss = net.weighted_cross_entropy_loss(val_image_batch,
                                               val_label_batch,
                                               w0=wT0,
                                               w1=wT1,
                                               is_training=False,
                                               scope='Validation_loss')
    val_acc = net.accuracy(val_image_batch,
                           val_label_batch,
                           is_training=False,
                           scope='Validation_accuracy')

    #config = tf.ConfigProto()
    config = tf.ConfigProto(intra_op_parallelism_threads=1,
                            inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    if os.path.exists(args.summaries_dir + "%s" % args.bpart):
        shutil.rmtree(args.summaries_dir + "%s" % args.bpart)

    train_writer = tf.summary.FileWriter(
        args.summaries_dir + "%s" % args.bpart + "/train", sess.graph)
    val_writer = tf.summary.FileWriter(args.summaries_dir + "%s" % args.bpart +
                                       "/val")

    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    sess.run([init_g, init_l])

    # Saver for storing the last 40 checkpoints of the model.
    saver = tf.train.Saver(var_list=all_variables, max_to_keep=40)
    if args.restore_from is not None:
        load(saver, sess, args.restore_from)

    # Start queue threads.
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Iterate over training steps.
    starting_time = time.asctime(time.localtime())

    for step in range(NUM_STEPS + 1):
        start_time = time.time()

        if step % EVALUATE_EVERY == 0:
            # Calculate the validation loss and accuracy over the whole validation set
            val_loss_list = []
            val_acc_list = []

            for i in tqdm(range(VALIDATION_STEPS), desc="Validation"):
                val_image_summary_value, val_loss_i, val_acc_i = sess.run(
                    [val_image_summary, val_loss, val_acc])
                val_loss_list.append(val_loss_i)
                val_acc_list.append(val_acc_i)
                val_writer.add_summary(val_image_summary_value, step)
            val_loss_mean = np.mean(
                val_loss_list)  # validation loss of the whole validation data
            val_acc_mean = np.mean(val_acc_list)

            # Reduce the learning rate if the validation loss plateaus after one epoch
            if step > EVALUATE_EVERY:
                if val_loss_mean >= previous_val_loss:
                    learning_rate = learning_rate / 10.0
                    print("Reducing the learning rate\n")
            previous_val_loss = val_loss_mean

            save(saver, sess, args.snapshot_dir + "%s" % args.bpart, step)

            summary = tf.Summary()
            summary.value.add(tag='validation loss',
                              simple_value=val_loss_mean)
            summary.value.add(tag='validation accuracy',
                              simple_value=val_acc_mean)
            val_writer.add_summary(summary, step)

            duration = time.time() - start_time
            print(
                "\nSTEP {:d}/{:d} VALIDATION LOSS = {:.4f}, \t ACC = {:.4f},  \t ({:.3f} sec/step)"
                .format(step, NUM_STEPS, val_loss_mean, val_acc_mean,
                        duration))
        else:
            loss_summary_value, loss_value, acc_value, _ = sess.run(
                [loss_summary, loss, acc, optim],
                feed_dict={lr_placeholder: learning_rate})
            duration = time.time() - start_time
            train_writer.add_summary(loss_summary_value, step)
            print(
                "step {:d}/{:d} \t loss = {:.4f}, \t acc = {:.4f},\t lr = {:.1e},  \t ({:.3f} sec/step)"
                .format(step, NUM_STEPS, loss_value, acc_value,
                        learning_rate, duration))

    end_time = time.asctime(time.localtime())

    coord.request_stop()
    coord.join(threads)
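Example #4 weights its loss by class frequency: abnormal (label 1) examples are scaled by wT1 = N_train / (A_train + N_train) and normal ones by wT0 = A_train / (A_train + N_train), so the rarer class contributes more per example. A minimal NumPy sketch of such a weighted binary cross-entropy (the helper name and the use of probabilities rather than logits are assumptions; the original computes the loss inside net.weighted_cross_entropy_loss, whose internals are not shown):

import numpy as np

def weighted_bce(probs, labels, w0, w1, eps=1e-7):
    """Binary cross-entropy with positives scaled by w1 and negatives by w0."""
    probs = np.clip(probs, eps, 1.0 - eps)
    per_example = -(w1 * labels * np.log(probs)
                    + w0 * (1.0 - labels) * np.log(1.0 - probs))
    return per_example.mean()

# Toy imbalance: 3 abnormal vs 7 normal images -> w1 = 0.7, w0 = 0.3.
labels = np.array([1.0, 0.0, 1.0])
probs = np.array([0.8, 0.2, 0.6])
print(weighted_bce(probs, labels, w0=0.3, w1=0.7))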