Example #1
def keras_model_fn(hyperparameters):
    '''
    hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your
        TensorFlow training script; use them to configure and compile the model below.
    '''
    # Logic to do the following:
    # 1. Instantiate the Keras model
    # 2. Compile the Keras model

    config = hyperparameters
    model = get_model_and_load_weights(config)
    compile_model(model, config)

    return model
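
This snippet assumes two helpers, get_model_and_load_weights and compile_model, which are not shown. A minimal sketch of what they might look like, assuming the hyperparameters dict carries hypothetical keys such as 'input_shape', 'num_classes', 'weights_path', and 'lr':

from tensorflow import keras

def get_model_and_load_weights(config):
    # Build a small Keras model; all config keys here are hypothetical.
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=config.get('input_shape', (28, 28))),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(config.get('num_classes', 10), activation='softmax'),
    ])
    if config.get('weights_path'):
        model.load_weights(config['weights_path'])
    return model

def compile_model(model, config):
    # Compile in place; 'lr' is a hypothetical hyperparameter key.
    model.compile(optimizer=keras.optimizers.Adam(config.get('lr', 1e-3)),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])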
Example #2
def manip(args, test_list, u_model):
    if args.test_weights_path == '':
        weights_path = os.path.join(args.check_dir, args.output_name + '_model_' + args.time + '.hdf5')
    else:
        weights_path = os.path.join(args.data_root_dir, args.test_weights_path)

    output_dir = os.path.join(args.data_root_dir, 'results', args.net)
    manip_out_dir = os.path.join(output_dir, 'manip_output')
    try:
        safe_mkdir(manip_out_dir)
    except OSError:
        pass  # output directory may already exist

    # Compile the loaded model
    manip_model = compile_model(args=args, uncomp_model=u_model)

    try:
        manip_model.load_weights(weights_path)
    except Exception as e:
        print(e)
        raise OSError('Unable to load weights from {}.'.format(weights_path))

    # Manipulating capsule vectors
    print('Testing... This will take some time...')

    for i, img in enumerate(tqdm(test_list)):
        sitk_img = sitk.ReadImage(os.path.join(args.data_root_dir, 'imgs', img[0]))
        img_data = sitk.GetArrayFromImage(sitk_img)
        num_slices = img_data.shape[0]
        sitk_mask = sitk.ReadImage(os.path.join(args.data_root_dir, 'masks', img[0]))
        gt_data = sitk.GetArrayFromImage(sitk_mask)

        x, y = img_data[num_slices//2, :, :], gt_data[num_slices//2, :, :]
        x, y = np.expand_dims(np.expand_dims(x, -1), 0), np.expand_dims(np.expand_dims(y, -1), 0)

        noise = np.zeros([1, 512, 512, 1, 16])  # a 16-D capsule vector per pixel of a 512x512 slice
        x_recons = []
        for dim in trange(16):
            for r in [-0.25, -0.125, 0, 0.125, 0.25]:
                tmp = np.copy(noise)
                tmp[:, :, :, :, dim] = r
                x_recon = manip_model.predict([x, y, tmp])
                x_recons.append(x_recon)

        x_recons = np.concatenate(x_recons)

        out_img = combine_images(x_recons, height=16)
        # Scale, clip, and rescale intensities into the 8-bit range for saving
        out_image = out_img * 4096
        out_image[out_image > 574] = 574
        out_image = out_image / 574 * 255

        Image.fromarray(out_image.astype(np.uint8)).save(os.path.join(manip_out_dir, img[0][:-4] + '_manip_output.png'))

    print('Done.')
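
combine_images comes from the surrounding repository and is not shown here. A plausible sketch, inferred from the calls above, that tiles a batch of single-channel images into a grid with the requested number of rows:

import numpy as np

def combine_images(generated_images, height=None, width=None):
    # Tile an (N, H, W, 1) batch into a single (height*H, width*W) image.
    num = generated_images.shape[0]
    if height is None and width is None:
        height = int(np.sqrt(num))
        width = int(np.ceil(num / float(height)))
    elif height is None:
        height = int(np.ceil(num / float(width)))
    elif width is None:
        width = int(np.ceil(num / float(height)))
    h, w = generated_images.shape[1:3]
    image = np.zeros((height * h, width * w), dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i, j = index // width, index % width
        image[i * h:(i + 1) * h, j * w:(j + 1) * w] = img[:, :, 0]
    return image

In Example 2 this yields a 16 x 5 grid: 16 capsule dimensions as rows and the five perturbation values as columns.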
Example #3
def test(args, u_model, val_samples, val_shape, test_samples, test_shape):
    out_dir = os.path.join(args.root_dir, 'predictions', args.exp_name,
                           args.net)
    try:
        safe_mkdir(out_dir)
    except OSError:
        pass  # output directory may already exist

    # Compile the loaded model
    model = compile_model(args=args, uncomp_model=u_model)

    # Load testing weights
    if args.test_weights_path != '':
        try:
            model.load_weights(args.test_weights_path)
            output_filename = os.path.join(
                out_dir,
                os.path.basename(args.test_weights_path)[:-5] + '.csv')
        except Exception as e:
            print(e)
            raise OSError('Failed to load weights file in test.py')
    else:
        try:
            model.load_weights(
                os.path.join(
                    args.check_dir,
                    args.output_name + '_model_' + args.time + '.hdf5'))
            output_filename = os.path.join(
                out_dir, args.output_name + '_model_' + args.time + '.csv')
        except Exception as e:
            print(e)
            raise OSError('Failed to load weights from training.')

    test_datagen = ImageDataGenerator(samplewise_center=True,
                                      samplewise_std_normalization=True,
                                      rescale=1. / 255)

    if args.thresh_level == 0.:
        thresh = 0.5
    else:
        thresh = args.thresh_level

    # TESTING SECTION
    y_true_test = []

    # Wrap the Keras iterator: record the labels and rescale each image to [-1, 1]
    def test_data_gen(gen):
        while True:
            batch = gen.next()
            y_true_test.append(batch[1][0])
            x_batch = np.copy(batch[0])
            for i, x in enumerate(batch[0]):
                x2 = np.copy(x)
                x2 = x2 + abs(np.min(x2))
                x2 /= (np.max(x2) + 1e-7)
                x2 = (x2 - 0.5) * 2.
                x_batch[i, ...] = x2
            yield x_batch

    test_flow_gen = test_datagen.flow_from_directory(os.path.join(
        args.img_dir, 'test'),
                                                     target_size=test_shape,
                                                     class_mode='binary',
                                                     batch_size=1,
                                                     seed=12,
                                                     shuffle=False)

    filenames = np.asarray(test_flow_gen.filenames)
    test_flow_gen.reset()
    test_gen = test_data_gen(test_flow_gen)
    results = model.predict_generator(test_gen,
                                      max_queue_size=1,
                                      workers=1,
                                      use_multiprocessing=False,
                                      steps=test_samples,
                                      verbose=args.verbose)
    if args.net.find('caps') != -1:
        test_scores = results[0]  # capsule variants also return reconstructions
        reconstructions = results[1]
    else:
        test_scores = results

    polyp_ids = []
    for f in tqdm(filenames, desc='Loading filenames'):
        temp = os.path.basename(f).split('_')
        try:
            polyp_ids.append('m_{}_{}'.format(
                os.path.dirname(f)[2:], temp[1][:6]))
        except IndexError:
            polyp_ids.append('m_{}_{}'.format(
                os.path.dirname(f)[2:], temp[0][:6]))

    unique_polyp_results_ALL = []
    unique_polyp_results_NBI = []
    unique_polyp_results_NBIF = []
    unique_polyp_results_NBIN = []
    unique_polyp_results_WL = []
    unique_polyp_results_WLF = []
    unique_polyp_results_WLN = []
    unique_polyp_results_NEAR = []
    unique_polyp_results_FAR = []
    unique_polyp_labels = []
    unique_polyp_names = []
    counts = Counter(polyp_ids)
    for s, num in tqdm(counts.items(), desc='Computing Scores'):
        current_polyp_results_ALL = []
        current_polyp_results_NBI = []
        current_polyp_results_NBIF = []
        current_polyp_results_NBIN = []
        current_polyp_results_WL = []
        current_polyp_results_WLF = []
        current_polyp_results_WLN = []
        current_polyp_results_NEAR = []
        current_polyp_results_FAR = []
        current_polyp_name = s
        for _ in range(1, num + 1):  # loop over all images of same polyp
            pos = polyp_ids.index(s)
            current_image_score = test_scores[pos][0]
            current_polyp_results_ALL.append(current_image_score)
            current_filename = os.path.basename(filenames[pos])
            split_name = current_filename.split('-')
            if len(split_name) < 4:
                print('Encountered improperly named image. Please fix: {}.'.
                      format(current_filename))
                continue
            if split_name[3] == 'NBI':
                current_polyp_results_NBI.append(current_image_score)
            elif split_name[3] == 'NBIF':
                current_polyp_results_NBIF.append(current_image_score)
                current_polyp_results_NBI.append(current_image_score)
                current_polyp_results_FAR.append(current_image_score)
            elif split_name[3] == 'NBIN':
                current_polyp_results_NBIN.append(current_image_score)
                current_polyp_results_NBI.append(current_image_score)
                current_polyp_results_NEAR.append(current_image_score)
            elif split_name[3] == 'WL':
                current_polyp_results_WL.append(current_image_score)
            elif split_name[3] == 'WLF':
                current_polyp_results_WLF.append(current_image_score)
                current_polyp_results_WL.append(current_image_score)
                current_polyp_results_FAR.append(current_image_score)
            elif split_name[3] == 'WLN':
                current_polyp_results_WLN.append(current_image_score)
                current_polyp_results_WL.append(current_image_score)
                current_polyp_results_NEAR.append(current_image_score)
            else:
                warnings.warn('Encountered unexpected imaging type: {}.'.format(
                    split_name[3]))
            polyp_ids[pos] = s + '_c'  # mark the image as seen

        unique_polyp_names.append(current_polyp_name)
        unique_polyp_results_ALL.append(
            np.mean(np.asarray(current_polyp_results_ALL)))

        if current_polyp_results_NBI:
            unique_polyp_results_NBI.append(
                np.mean(np.asarray(current_polyp_results_NBI)))
        else:
            unique_polyp_results_NBI.append(np.nan)

        if current_polyp_results_NBIF:
            unique_polyp_results_NBIF.append(
                np.mean(np.asarray(current_polyp_results_NBIF)))
        else:
            unique_polyp_results_NBIF.append(np.nan)

        if current_polyp_results_NBIN:
            unique_polyp_results_NBIN.append(
                np.mean(np.asarray(current_polyp_results_NBIN)))
        else:
            unique_polyp_results_NBIN.append(np.nan)

        if current_polyp_results_WL:
            unique_polyp_results_WL.append(
                np.mean(np.asarray(current_polyp_results_WL)))
        else:
            unique_polyp_results_WL.append(np.nan)

        if current_polyp_results_WLF:
            unique_polyp_results_WLF.append(
                np.mean(np.asarray(current_polyp_results_WLF)))
        else:
            unique_polyp_results_WLF.append(np.nan)

        if current_polyp_results_WLN:
            unique_polyp_results_WLN.append(
                np.mean(np.asarray(current_polyp_results_WLN)))
        else:
            unique_polyp_results_WLN.append(np.nan)

        if current_polyp_results_NEAR:
            unique_polyp_results_NEAR.append(
                np.mean(np.asarray(current_polyp_results_NEAR)))
        else:
            unique_polyp_results_NEAR.append(np.nan)

        if current_polyp_results_FAR:
            unique_polyp_results_FAR.append(
                np.mean(np.asarray(current_polyp_results_FAR)))
        else:
            unique_polyp_results_FAR.append(np.nan)

    unique_polyp_labels = np.asarray(unique_polyp_labels)
    warnings.filterwarnings("ignore")  # comparisons against NaN emit RuntimeWarnings

    def threshold_scores(scores):
        # Binarize per-polyp mean scores; keep NaN where an imaging type is absent.
        scores = np.asarray(scores, dtype=np.float64)
        preds = np.where(scores > thresh, 1., 0.)
        preds[np.isnan(scores)] = np.nan
        return preds

    predictions_IMAGES = np.where(test_scores > thresh, 1., 0.)
    predictions_ALL = threshold_scores(unique_polyp_results_ALL)
    predictions_NBI = threshold_scores(unique_polyp_results_NBI)
    predictions_NBIF = threshold_scores(unique_polyp_results_NBIF)
    predictions_NBIN = threshold_scores(unique_polyp_results_NBIN)
    predictions_WL = threshold_scores(unique_polyp_results_WL)
    predictions_WLF = threshold_scores(unique_polyp_results_WLF)
    predictions_WLN = threshold_scores(unique_polyp_results_WLN)
    predictions_NEAR = threshold_scores(unique_polyp_results_NEAR)
    predictions_FAR = threshold_scores(unique_polyp_results_FAR)
    warnings.resetwarnings()

    np.savetxt(output_filename,
               np.stack([
                   predictions_IMAGES,
                   np.squeeze(predictions_ALL[np.argwhere(
                       np.isfinite(unique_polyp_results_ALL))]),
                   np.squeeze(predictions_NBI[np.argwhere(
                       np.isfinite(unique_polyp_results_NBI))]),
                   np.squeeze(predictions_NBIF[np.argwhere(
                       np.isfinite(unique_polyp_results_NBIF))]),
                   np.squeeze(predictions_NBIN[np.argwhere(
                       np.isfinite(unique_polyp_results_NBIN))]),
                   np.squeeze(predictions_WL[np.argwhere(
                       np.isfinite(unique_polyp_results_WL))]),
                   np.squeeze(predictions_WLF[np.argwhere(
                       np.isfinite(unique_polyp_results_WLF))]),
                   np.squeeze(predictions_WLN[np.argwhere(
                       np.isfinite(unique_polyp_results_WLN))]),
                   np.squeeze(predictions_NEAR[np.argwhere(
                       np.isfinite(unique_polyp_results_NEAR))]),
                   np.squeeze(predictions_FAR[np.argwhere(
                       np.isfinite(unique_polyp_results_FAR))])
               ],
                        axis=0),
               delimiter=',')
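
The per-polyp aggregation in this example visits each image of a polyp exactly once by combining Counter with list.index and renaming seen entries. A toy run of the same pattern:

from collections import Counter

polyp_ids = ['m_01_abc', 'm_01_abc', 'm_02_def']
scores = [0.9, 0.7, 0.2]

for s, num in Counter(polyp_ids).items():
    group = []
    for _ in range(num):
        pos = polyp_ids.index(s)   # first not-yet-seen image of this polyp
        group.append(scores[pos])
        polyp_ids[pos] = s + '_c'  # mark as seen so index() moves on
    print(s, sum(group) / len(group))  # m_01_abc 0.8, then m_02_def 0.2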
Example #4
def manip(args, u_model, test_samples):
    out_dir = join(args.data_root_dir, 'results', args.exp_name, args.net)
    try:
        makedirs(out_dir)
    except OSError:
        pass  # output directory may already exist
    out_img_dir = join(out_dir, 'manip_output')
    try:
        makedirs(out_img_dir)
    except OSError:
        pass  # output directory may already exist

    # Compile the loaded model
    model = compile_model(args=args, uncomp_model=u_model)

    # Load testing weights
    if args.test_weights_path != '':
        try:
            model.load_weights(args.test_weights_path)
            out_name = basename(args.test_weights_path)[:-5]
        except Exception as e:
            print(e)
            raise Exception('Failed to load weights from the specified test weights path.')
    else:
        try:
            model.load_weights(
                join(args.check_dir,
                     args.output_name + '_model_' + args.time + '.hdf5'))
            out_name = args.output_name + '_model_' + args.time
        except Exception as e:
            print(e)
            raise Exception('Failed to load weights from training.')

    x_test = normalize_img(
        np.expand_dims(test_samples[0], axis=-1).astype(np.float32))
    if args.num_classes == 1:
        y_test = np.expand_dims(test_samples[2][:, 25],
                                axis=-1)  # 25 should be avg mal score
        y_test[y_test < 3.] = 0.
        y_test[y_test >= 3.] = 1.
    else:
        y_test = to_categorical(np.rint(test_samples[2][:, -2]) - 1)

    print('Creating manipulated outputs.')
    for mal_val in trange(y_test.shape[1]):
        index = np.argmax(y_test, 1) == mal_val
        number = np.random.randint(low=0, high=sum(index) - 1)
        x, y = x_test[index][number], y_test[index][number]
        x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
        if args.net.find('xcaps') != -1:
            noise = np.zeros([1, 6, 16])
        elif args.net == 'capsnet':
            noise = np.zeros([1, y_test.shape[1], 16])
        else:
            raise NotImplementedError(
                'Specified network is not implemented in manip.py.')
        x_recons = []
        for attr in range(noise.shape[0]):  # note: shape[0] is the batch dim, so this loop runs once
            for dim in range(16):
                for r in [
                        -0.5, -0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1,
                        0.15, 0.2, 0.25, 0.5
                ]:
                    tmp = np.copy(noise)
                    tmp[attr, :, dim] = r
                    if args.net.find('xcaps') != -1:
                        x_recon = model.predict([x, tmp])
                    elif args.net == 'capsnet':
                        x_recon = model.predict([x, y, tmp])
                    else:
                        raise NotImplementedError(
                            'Specified network is not implemented in manip.py.')
                    x_recons.append(x_recon[-1])

            x_recons = np.concatenate(x_recons)

            img = combine_images(x_recons, height=16)
            pil_img = Image.fromarray(255 * img).convert('L')
            pil_img.save(
                join(out_img_dir,
                     out_name + '_{}_{}.png'.format(attr, mal_val + 1)))
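
normalize_img (also used in Examples 5 and 7) is not defined in these snippets. A reasonable sketch, assuming simple per-volume min-max scaling to [0, 1]:

import numpy as np

def normalize_img(img):
    # Min-max scale to [0, 1]; the epsilon guards against constant images.
    img = img - np.min(img)
    return img / (np.max(img) + 1e-7)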
Example #5
def test(args, u_model, test_samples):
    out_dir = os.path.join(args.data_root_dir, 'results', args.exp_name,
                           args.net)
    safe_mkdir(out_dir)
    out_img_dir = os.path.join(out_dir, 'recons')
    safe_mkdir(out_img_dir)

    # Compile the loaded model
    model = compile_model(args=args, uncomp_model=u_model)

    # Load testing weights
    if args.test_weights_path != '':
        output_filename = os.path.join(
            out_dir, 'results_' +
            os.path.basename(args.test_weights_path)[:-5] + '.csv')
        try:
            model.load_weights(args.test_weights_path)
        except Exception as e:
            print(e)
            raise Exception('Failed to load weights from the specified test weights path.')
    else:
        output_filename = os.path.join(
            out_dir,
            'results_' + args.output_name + '_model_' + args.time + '.csv')
        try:
            model.load_weights(
                os.path.join(
                    args.check_dir,
                    args.output_name + '_model_' + args.time + '.hdf5'))
        except Exception as e:
            print(e)
            raise Exception('Failed to load weights from training.')

    test_datagen = ImageDataGenerator(samplewise_center=False,
                                      samplewise_std_normalization=False,
                                      rescale=None)

    # TESTING SECTION
    def data_gen(gen):
        while True:
            x, y = gen.next()
            yield x, y

    x_test = normalize_img(
        np.expand_dims(test_samples[0], axis=-1).astype(np.float32))

    if args.num_classes == 1:
        y_test = np.rint(test_samples[2][:, -2])
    else:
        y_test = get_pseudo_label([1., 2., 3., 4., 5.], test_samples[2][:, -2],
                                  test_samples[2][:, -1])

    test_gen = data_gen(
        test_datagen.flow(x=x_test,
                          y=y_test,
                          batch_size=1,
                          shuffle=False,
                          seed=12))

    results = model.predict_generator(test_gen,
                                      max_queue_size=1,
                                      workers=1,
                                      use_multiprocessing=False,
                                      steps=len(x_test),
                                      verbose=1)

    if args.net.find('caps') != -1:
        y_pred = results[0]
        x_recon = results[-1]

        img = combine_images(np.concatenate([x_test[:250:5], x_recon[:250:5]]))
        pil_img = Image.fromarray(255 * img).convert('L')
        if args.test_weights_path != '':
            img_filename = os.path.join(
                out_img_dir,
                os.path.basename(args.test_weights_path)[:-5] +
                '_real_and_recon.png')
        else:
            img_filename = os.path.join(
                out_img_dir, args.output_name + '_model_' + args.time +
                '_real_and_recon.png')
        pil_img.save(img_filename)  # img_filename already includes out_img_dir
    else:
        y_pred = results

    if args.num_classes == 1:
        gt = y_test
        pred = np.squeeze(np.rint(y_pred * 4 + 1))
    else:
        gt = np.argmax(y_test, axis=1) + 1
        pred = np.argmax(y_pred, axis=1) + 1

    if args.num_classes == 1:
        cmat = confusion_matrix(gt, pred, labels=[1, 2, 3, 4, 5])
        test_acc_cat, test_acc_all = compute_within_one_acc(cmat)
        test_acc_cat_weighted, test_acc_all_weighted = np.zeros_like(
            test_acc_cat), np.zeros_like(test_acc_all)
    else:
        cmat = confusion_matrix(gt, pred, labels=[1, 2, 3, 4, 5])
        test_acc_cat, test_acc_all = compute_within_one_acc(cmat)
        cmat_weighted = confusion_matrix(
            gt,
            pred,
            labels=[1, 2, 3, 4, 5],
            sample_weight=1. / np.var([1., 2., 3., 4., 5.] * y_pred, axis=1))
        test_acc_cat_weighted, test_acc_all_weighted = compute_within_one_acc(
            cmat_weighted)

    with open(output_filename, 'w', newline='') as f:
        fw = csv.writer(f, delimiter=',')
        fw.writerow(
            ['Malignancy Accuracy', 'Malignancy Accuracy Confidence Weighted'])
        fw.writerow([
            '{:05f}'.format(test_acc_all),
            '{:05f}'.format(test_acc_all_weighted)
        ])
        fw.writerow(['Malignancy Accuracy by Score:'])
        fw.writerow(['{:05f}'.format(num) for num in test_acc_cat])
        fw.writerow(['Malignancy Accuracy by Score Confidence Weighted:'])
        fw.writerow(['{:05f}'.format(num) for num in test_acc_cat_weighted])
        fw.writerow(['Malignancy Confusion Matrix:'])
        for row in cmat:
            fw.writerow(['{:05f}'.format(num) for num in row])
        if args.net.find('dcaps') != -1 or args.net == 'xcapsnet':
            if args.net.find('simple') != -1 or args.net == 'xcapsnet':
                attr_pred = np.rint(
                    np.swapaxes(np.asarray(results[1:-1]), 0, -1) * 4 + 1)
                y_attr = np.rint(
                    np.concatenate(
                        (np.expand_dims(test_samples[2][:, -18], axis=-1),
                         np.expand_dims(test_samples[2][:, -12], axis=-1),
                         np.expand_dims(test_samples[2][:, -10], axis=-1),
                         np.expand_dims(test_samples[2][:, -8], axis=-1),
                         np.expand_dims(test_samples[2][:, -6], axis=-1),
                         np.expand_dims(test_samples[2][:, -4], axis=-1)),
                        axis=1)).astype(np.int64)
                for i in range(y_attr.shape[1]):
                    attr_cmat = confusion_matrix(y_attr[:, i],
                                                 attr_pred[:, i],
                                                 labels=[1, 2, 3, 4, 5])
                    class_acc, total_acc = compute_within_one_acc(attr_cmat)
                    fw.writerow(['Attribute {} Accuracy'.format(i)])
                    fw.writerow(['{:05f}'.format(total_acc)])
                    fw.writerow(['Attribute {} Accuracy by Score:'.format(i)])
                    fw.writerow(['{:05f}'.format(num) for num in class_acc])
                    fw.writerow(['Attribute {} Confusion Matrix:'.format(i)])
                    for row in attr_cmat:
                        fw.writerow(['{:05f}'.format(num) for num in row])
            else:
                attr_pred = np.argmax(
                    np.rollaxis(np.asarray(results[1:-1]), 0, -1), axis=-1) + 1
                y_attr = np.concatenate(
                    (np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -18],
                                                     test_samples[2][:, -17]),
                                    axis=1),
                     np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -12],
                                                     test_samples[2][:, -11]),
                                    axis=1),
                     np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -10],
                                                     test_samples[2][:, -9]),
                                    axis=1),
                     np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -8],
                                                     test_samples[2][:, -7]),
                                    axis=1),
                     np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -6],
                                                     test_samples[2][:, -5]),
                                    axis=1),
                     np.expand_dims(get_pseudo_label([1., 2., 3., 4., 5.],
                                                     test_samples[2][:, -4],
                                                     test_samples[2][:, -3]),
                                    axis=1)),
                    axis=1)
                gt_attr = np.argmax(y_attr, axis=2) + 1

                for i in range(gt_attr.shape[1]):
                    attr_cmat = confusion_matrix(gt_attr[:, i],
                                                 attr_pred[:, i],
                                                 labels=[1, 2, 3, 4, 5])
                    attr_cmat_weighted = confusion_matrix(
                        gt_attr[:, i],
                        attr_pred[:, i],
                        labels=[1, 2, 3, 4, 5],
                        sample_weight=1. /
                        np.var([1., 2., 3., 4., 5.] * np.rollaxis(
                            np.asarray(results[1:-1]), 0, -1)[:, i],
                               axis=1))
                    class_acc, total_acc = compute_within_one_acc(attr_cmat)
                    class_acc_weighted, total_acc_weighted = compute_within_one_acc(
                        attr_cmat_weighted)
                    fw.writerow([
                        'Attribute {} Accuracy'.format(i),
                        'Attribute {} Accuracy Confidence Weighted'.format(i)
                    ])
                    fw.writerow([
                        '{:05f}'.format(total_acc),
                        '{:05f}'.format(total_acc_weighted)
                    ])
                    fw.writerow(['Attribute {} Accuracy by Score:'.format(i)])
                    fw.writerow(['{:05f}'.format(num) for num in class_acc])
                    fw.writerow([
                        'Attribute {} Accuracy by Score Confidence Weighted:'.
                        format(i)
                    ])
                    fw.writerow(
                        ['{:05f}'.format(num) for num in class_acc_weighted])
                    fw.writerow(['Attribute {} Confusion Matrix:'.format(i)])
                    for row in attr_cmat:
                        fw.writerow(['{:05f}'.format(num) for num in row])
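
compute_within_one_acc is not shown. A plausible sketch under one reading of the name: a prediction counts as correct when it falls within one score bin of the truth, computed per class and overall from the 5x5 confusion matrix:

import numpy as np

def compute_within_one_acc(cmat):
    # cmat[i, j] counts samples with true score i+1 predicted as score j+1.
    cmat = np.asarray(cmat, dtype=np.float64)
    n = cmat.shape[0]
    class_acc = np.zeros(n)
    within_one = 0.
    for i in range(n):
        near = cmat[i, max(0, i - 1):min(n, i + 2)].sum()
        row_sum = cmat[i].sum()
        class_acc[i] = near / row_sum if row_sum else 0.
        within_one += near
    total = cmat.sum()
    total_acc = within_one / total if total else 0.
    return class_acc, total_acc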
Example #6
def test(args, u_model, val_samples, val_shape, test_samples, test_shape):
    out_dir = os.path.join(args.root_dir, 'results', args.exp_name, args.net)
    try:
        safe_mkdir(out_dir)
    except OSError:
        pass  # output directory may already exist

    # Compile the loaded model
    model = compile_model(args=args, uncomp_model=u_model)

    # Load testing weights
    if args.test_weights_path != '':
        try:
            model.load_weights(args.test_weights_path)
            output_filename = os.path.join(
                out_dir,
                os.path.basename(args.test_weights_path)[:-5] + '.csv')
        except Exception as e:
            print(e)
            raise OSError('Failed to load weights file in test.py')
    else:
        try:
            model.load_weights(
                os.path.join(
                    args.check_dir,
                    args.output_name + '_model_' + args.time + '.hdf5'))
            output_filename = os.path.join(
                out_dir, args.output_name + '_model_' + args.time + '.csv')
        except Exception as e:
            print(e)
            raise OSError('Failed to load weights from training.')

    test_datagen = ImageDataGenerator(samplewise_center=True,
                                      samplewise_std_normalization=True,
                                      rescale=1. / 255)

    # VALIDATION SECTION
    if args.thresh_level == 0.:
        # Choose a threshold that maximizes the harmonic mean of sensitivity and
        # specificity (a sketch of find_thresh_level follows this example).
        y_true_val = []

        # Wrap the Keras iterator: record the labels and rescale each image to [-1, 1]
        def val_data_gen(gen):
            while True:
                batch = gen.next()
                y_true_val.append(batch[1][0])
                x_batch = np.copy(batch[0])
                for i, x in enumerate(batch[0]):
                    x2 = np.copy(x)
                    x2 = x2 + abs(np.min(x2))
                    x2 /= (np.max(x2) + 1e-7)
                    x2 = (x2 - 0.5) * 2.
                    x_batch[i, ...] = x2
                yield x_batch

        val_flow_gen = test_datagen.flow_from_directory(os.path.join(
            args.img_dir, 'val'),
                                                        target_size=val_shape,
                                                        class_mode='binary',
                                                        batch_size=1,
                                                        seed=12,
                                                        shuffle=False)

        val_flow_gen.reset()
        val_gen = val_data_gen(val_flow_gen)
        val_results = model.predict_generator(val_gen,
                                              max_queue_size=1,
                                              workers=1,
                                              use_multiprocessing=False,
                                              steps=val_samples,
                                              verbose=args.verbose)
        if args.net.find('caps') != -1:
            val_scores = val_results[0]
            val_reconstructions = val_results[1]
        else:
            val_scores = val_results
        # Truncate extra labels queued ahead by the generator; slicing by length
        # avoids the empty-list result of a negative-zero slice
        val_y_true = np.asarray(y_true_val[:len(val_flow_gen.filenames)])
        thresh, [val_acc, val_sen,
                 val_spec] = find_thresh_level(val_scores, val_y_true,
                                               'pseudof1')
    else:
        thresh = args.thresh_level

    # TESTING SECTION
    y_true_test = []

    def test_data_gen(gen):
        while True:
            batch = gen.next()
            y_true_test.append(batch[1][0])
            x_batch = np.copy(batch[0])
            for i, x in enumerate(batch[0]):
                x2 = np.copy(x)
                x2 = x2 + abs(np.min(x2))
                x2 /= (np.max(x2) + 1e-7)
                x2 = (x2 - 0.5) * 2.
                x_batch[i, ...] = x2
            yield x_batch

    test_flow_gen = test_datagen.flow_from_directory(os.path.join(
        args.img_dir, 'test'),
                                                     target_size=test_shape,
                                                     class_mode='binary',
                                                     batch_size=1,
                                                     seed=12,
                                                     shuffle=False)

    filenames = np.asarray(test_flow_gen.filenames)
    test_flow_gen.reset()
    test_gen = test_data_gen(test_flow_gen)
    results = model.predict_generator(test_gen,
                                      max_queue_size=1,
                                      workers=1,
                                      use_multiprocessing=False,
                                      steps=test_samples,
                                      verbose=args.verbose)
    if args.net.find('caps') != -1:
        test_scores = results[0]
        reconstructions = results[1]
    else:
        test_scores = results
    # Truncate extra labels queued ahead by the generator
    test_y_true = np.asarray(y_true_test[:len(test_flow_gen.filenames)])

    y_true_check = []
    polyp_ids = []
    for f in tqdm(filenames, desc='Loading filenames'):
        y_true_check.append(f[0])  # first path character is the class subdirectory
        temp = os.path.basename(f).split('_')
        try:
            polyp_ids.append('m_{}_{}'.format(
                os.path.dirname(f)[2:], temp[1][:6]))
        except IndexError:
            polyp_ids.append('m_{}_{}'.format(
                os.path.dirname(f)[2:], temp[0][:6]))

    y_true_check = np.asarray(y_true_check, dtype=np.float32)
    assert np.array_equal(test_y_true, y_true_check), 'Error: Order of images and labels not preserved! ' \
                                                      'Cannot match images to labels.'

    unique_polyp_results_ALL = []
    unique_polyp_results_NBI = []
    unique_polyp_results_NBIF = []
    unique_polyp_results_NBIN = []
    unique_polyp_results_WL = []
    unique_polyp_results_WLF = []
    unique_polyp_results_WLN = []
    unique_polyp_results_NEAR = []
    unique_polyp_results_FAR = []
    unique_polyp_labels = []
    unique_polyp_names = []
    counts = Counter(polyp_ids)
    for s, num in tqdm(counts.items(), desc='Computing Scores'):
        current_polyp_results_ALL = []
        current_polyp_results_NBI = []
        current_polyp_results_NBIF = []
        current_polyp_results_NBIN = []
        current_polyp_results_WL = []
        current_polyp_results_WLF = []
        current_polyp_results_WLN = []
        current_polyp_results_NEAR = []
        current_polyp_results_FAR = []
        current_polyp_name = s
        current_polyp_label = test_y_true[polyp_ids.index(s)]
        for _ in range(1, num + 1):  # loop over all images of same polyp
            pos = polyp_ids.index(s)
            current_image_score = test_scores[pos][0]
            current_polyp_results_ALL.append(current_image_score)
            current_filename = os.path.basename(filenames[pos])
            split_name = current_filename.split('-')
            if len(split_name) < 4:
                print('Encountered improperly named image. Please fix: {}.'.
                      format(current_filename))
                continue
            if split_name[3] == 'NBI':
                current_polyp_results_NBI.append(current_image_score)
            elif split_name[3] == 'NBIF':
                current_polyp_results_NBIF.append(current_image_score)
                current_polyp_results_NBI.append(current_image_score)
                current_polyp_results_FAR.append(current_image_score)
            elif split_name[3] == 'NBIN':
                current_polyp_results_NBIN.append(current_image_score)
                current_polyp_results_NBI.append(current_image_score)
                current_polyp_results_NEAR.append(current_image_score)
            elif split_name[3] == 'WL':
                current_polyp_results_WL.append(current_image_score)
            elif split_name[3] == 'WLF':
                current_polyp_results_WLF.append(current_image_score)
                current_polyp_results_WL.append(current_image_score)
                current_polyp_results_FAR.append(current_image_score)
            elif split_name[3] == 'WLN':
                current_polyp_results_WLN.append(current_image_score)
                current_polyp_results_WL.append(current_image_score)
                current_polyp_results_NEAR.append(current_image_score)
            else:
                warnings.warn('Encountered unexpected imaging type: {}.'.format(
                    split_name[3]))
            polyp_ids[pos] = s + '_c'  # mark the image as seen

        unique_polyp_names.append(current_polyp_name)
        unique_polyp_labels.append(current_polyp_label)
        unique_polyp_results_ALL.append(
            np.mean(np.asarray(current_polyp_results_ALL)))

        if current_polyp_results_NBI:
            unique_polyp_results_NBI.append(
                np.mean(np.asarray(current_polyp_results_NBI)))
        else:
            unique_polyp_results_NBI.append(np.nan)

        if current_polyp_results_NBIF:
            unique_polyp_results_NBIF.append(
                np.mean(np.asarray(current_polyp_results_NBIF)))
        else:
            unique_polyp_results_NBIF.append(np.nan)

        if current_polyp_results_NBIN:
            unique_polyp_results_NBIN.append(
                np.mean(np.asarray(current_polyp_results_NBIN)))
        else:
            unique_polyp_results_NBIN.append(np.nan)

        if current_polyp_results_WL:
            unique_polyp_results_WL.append(
                np.mean(np.asarray(current_polyp_results_WL)))
        else:
            unique_polyp_results_WL.append(np.nan)

        if current_polyp_results_WLF:
            unique_polyp_results_WLF.append(
                np.mean(np.asarray(current_polyp_results_WLF)))
        else:
            unique_polyp_results_WLF.append(np.nan)

        if current_polyp_results_WLN:
            unique_polyp_results_WLN.append(
                np.mean(np.asarray(current_polyp_results_WLN)))
        else:
            unique_polyp_results_WLN.append(np.nan)

        if current_polyp_results_NEAR:
            unique_polyp_results_NEAR.append(
                np.mean(np.asarray(current_polyp_results_NEAR)))
        else:
            unique_polyp_results_NEAR.append(np.nan)

        if current_polyp_results_FAR:
            unique_polyp_results_FAR.append(
                np.mean(np.asarray(current_polyp_results_FAR)))
        else:
            unique_polyp_results_FAR.append(np.nan)

    unique_polyp_labels = np.asarray(unique_polyp_labels)
    warnings.filterwarnings("ignore")  # comparisons against NaN emit RuntimeWarnings

    def threshold_scores(scores):
        # Binarize per-polyp mean scores; keep NaN where an imaging type is absent.
        scores = np.asarray(scores, dtype=np.float64)
        preds = np.where(scores > thresh, 1., 0.)
        preds[np.isnan(scores)] = np.nan
        return preds

    predictions_IMAGES = np.where(test_scores > thresh, 1., 0.)
    predictions_ALL = threshold_scores(unique_polyp_results_ALL)
    predictions_NBI = threshold_scores(unique_polyp_results_NBI)
    predictions_NBIF = threshold_scores(unique_polyp_results_NBIF)
    predictions_NBIN = threshold_scores(unique_polyp_results_NBIN)
    predictions_WL = threshold_scores(unique_polyp_results_WL)
    predictions_WLF = threshold_scores(unique_polyp_results_WLF)
    predictions_WLN = threshold_scores(unique_polyp_results_WLN)
    predictions_NEAR = threshold_scores(unique_polyp_results_NEAR)
    predictions_FAR = threshold_scores(unique_polyp_results_FAR)
    warnings.resetwarnings()

    scores_IMAGEWISE = compute_scores(y_true=test_y_true,
                                      y_pred=predictions_IMAGES)
    scores_ALL = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_ALL))]),
        y_pred=np.squeeze(predictions_ALL[np.argwhere(
            np.isfinite(unique_polyp_results_ALL))]))
    scores_NBI = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_NBI))]),
        y_pred=np.squeeze(predictions_NBI[np.argwhere(
            np.isfinite(unique_polyp_results_NBI))]))
    scores_NBIF = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_NBIF))]),
        y_pred=np.squeeze(predictions_NBIF[np.argwhere(
            np.isfinite(unique_polyp_results_NBIF))]))
    scores_NBIN = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_NBIN))]),
        y_pred=np.squeeze(predictions_NBIN[np.argwhere(
            np.isfinite(unique_polyp_results_NBIN))]))
    scores_WL = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_WL))]),
        y_pred=np.squeeze(predictions_WL[np.argwhere(
            np.isfinite(unique_polyp_results_WL))]))
    scores_WLF = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_WLF))]),
        y_pred=np.squeeze(predictions_WLF[np.argwhere(
            np.isfinite(unique_polyp_results_WLF))]))
    scores_WLN = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_WLN))]),
        y_pred=np.squeeze(predictions_WLN[np.argwhere(
            np.isfinite(unique_polyp_results_WLN))]))
    scores_NEAR = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_NEAR))]),
        y_pred=np.squeeze(predictions_NEAR[np.argwhere(
            np.isfinite(unique_polyp_results_NEAR))]))
    scores_FAR = compute_scores(
        y_true=np.squeeze(unique_polyp_labels[np.argwhere(
            np.isfinite(unique_polyp_results_FAR))]),
        y_pred=np.squeeze(predictions_FAR[np.argwhere(
            np.isfinite(unique_polyp_results_FAR))]))

    np.savetxt(
        output_filename,
        np.stack([
            scores_IMAGEWISE, scores_ALL, scores_NBI, scores_NBIF, scores_NBIN,
            scores_WL, scores_WLF, scores_WLN, scores_NEAR, scores_FAR
        ],
                 axis=0),
        delimiter=',')
    print(
        '- Testing Complete! Results on ALL Polyps -\nAccuracy: {}\nSensitivity: {}\nSpecificity: {}'
        .format(scores_ALL[0], scores_ALL[1], scores_ALL[2]))
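
The validation section above explains the intent of find_thresh_level: pick the threshold that maximizes the harmonic mean of sensitivity and specificity ('pseudof1'). A sketch under that reading; the sweep grid and epsilon are assumptions:

import numpy as np

def find_thresh_level(scores, y_true, mode='pseudof1'):
    # Sweep candidate thresholds and keep the one with the best harmonic
    # mean of sensitivity and specificity (only 'pseudof1' is sketched).
    best = (-1., 0.5, [0., 0., 0.])
    scores = np.squeeze(scores)
    for thresh in np.linspace(0.01, 0.99, 99):
        preds = (scores > thresh).astype(np.float32)
        tp = np.sum((preds == 1) & (y_true == 1))
        tn = np.sum((preds == 0) & (y_true == 0))
        fp = np.sum((preds == 1) & (y_true == 0))
        fn = np.sum((preds == 0) & (y_true == 1))
        sen = tp / (tp + fn + 1e-7)
        spec = tn / (tn + fp + 1e-7)
        acc = (tp + tn) / (tp + tn + fp + fn + 1e-7)
        hmean = 2. * sen * spec / (sen + spec + 1e-7)
        if hmean > best[0]:
            best = (hmean, thresh, [acc, sen, spec])
    return best[1], best[2]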
Example #7
def train(args, u_model, train_samples, val_samples):
    # Compile the loaded model
    model = compile_model(args=args, uncomp_model=u_model)

    # Load pre-trained weights
    if args.finetune_weights_path != '':
        try:
            model.load_weights(args.finetune_weights_path)
        except Exception as e:
            print(e)
            print(
                '!!! Failed to load custom weights file. Training without pre-trained weights. !!!'
            )

    # Set the callbacks
    callbacks = get_callbacks(args)

    if args.aug_data:
        train_datagen = ImageDataGenerator(
            samplewise_center=False,
            samplewise_std_normalization=False,
            rotation_range=45,
            width_shift_range=0.1,
            height_shift_range=0.1,
            shear_range=0.1,
            zoom_range=0.1,
            fill_mode='nearest',
            horizontal_flip=True,
            vertical_flip=True,
            rescale=None,
            preprocessing_function=custom_train_data_augmentation)

        val_datagen = ImageDataGenerator(samplewise_center=False,
                                         samplewise_std_normalization=False,
                                         rescale=None)
    else:
        train_datagen = ImageDataGenerator(samplewise_center=False,
                                           samplewise_std_normalization=False,
                                           rotation_range=0,
                                           width_shift_range=0.,
                                           height_shift_range=0.,
                                           shear_range=0.,
                                           zoom_range=0.,
                                           fill_mode='nearest',
                                           horizontal_flip=False,
                                           vertical_flip=False,
                                           rescale=None)

        val_datagen = ImageDataGenerator(samplewise_center=False,
                                         samplewise_std_normalization=False,
                                         rescale=None)

    if debug:  # `debug` is presumably a module-level flag
        save_dir = args.img_aug_dir
    else:
        save_dir = None

    # Unpack the packed per-sample label object into per-attribute targets plus
    # an (optionally masked) reconstruction target
    def xcaps_data_gen(gen):
        while True:
            x, y = gen.next()
            if args.num_classes == 1:
                mal = np.array([y[i][0][6, 0] for i in range(y.shape[0])])
            else:
                mal = np.array([y[i][0][6, 1:] for i in range(y.shape[0])])
            yield x, [
                mal,
                np.array([y[i][0][0, 0] for i in range(y.shape[0])]),
                np.array([y[i][0][1, 0] for i in range(y.shape[0])]),
                np.array([y[i][0][2, 0] for i in range(y.shape[0])]),
                np.array([y[i][0][3, 0] for i in range(y.shape[0])]),
                np.array([y[i][0][4, 0] for i in range(y.shape[0])]),
                np.array([y[i][0][5, 0] for i in range(y.shape[0])]),
                x * np.expand_dims(
                    np.array([y[i][1] for i in range(y.shape[0])]), axis=-1)
            ]

    # CapsNet takes the label as a second input (for masking) and also
    # reconstructs the input image
    def capsnet_data_gen(gen):
        while True:
            x, y = gen.next()
            if args.num_classes == 1:
                y = np.array([y[i][0][6, 0] for i in range(y.shape[0])])
            else:
                y = np.array([y[i][0][6, 1:] for i in range(y.shape[0])])
            yield [x, y], [y, x]

    # Prepare images and labels for training
    train_imgs = normalize_img(
        np.expand_dims(train_samples[0], axis=-1).astype(np.float32))
    val_imgs = normalize_img(
        np.expand_dims(val_samples[0], axis=-1).astype(np.float32))

    train_labels = []
    val_labels = []
    n_attr = 9  # 8 attributes + malignancy score
    skip_attr_list = [1, 2]  # attributes to leave out of the targets
    for i in range(n_attr):
        skip = False
        if i in skip_attr_list:
            skip_attr_list.remove(i)
            skip = True
        if args.num_classes == 1 and i == n_attr - 1:
            tlab = np.repeat(np.expand_dims(train_samples[2][:,
                                                             2 * i + n_attr],
                                            axis=-1),
                             6,
                             axis=1)
            tlab[tlab < 3.] = 0.
            tlab[tlab >= 3.] = 1.
            train_labels.append(tlab)
            vlab = np.repeat(np.expand_dims(val_samples[2][:, 2 * i + n_attr],
                                            axis=-1),
                             6,
                             axis=1)
            vlab[vlab < 3.] = 0.
            vlab[vlab >= 3.] = 1.
            val_labels.append(vlab)
            skip = True
        if not skip:
            train_labels.append(
                np.hstack(
                    (np.expand_dims(
                        (train_samples[2][:, 2 * i + n_attr] - 1) / 4.,
                        axis=-1),
                     get_pseudo_label([1., 2., 3., 4., 5.],
                                      train_samples[2][:, 2 * i + n_attr],
                                      train_samples[2][:,
                                                       2 * i + 1 + n_attr]))))
            val_labels.append(
                np.hstack(
                    (np.expand_dims(
                        (val_samples[2][:, 2 * i + n_attr] - 1) / 4., axis=-1),
                     get_pseudo_label([1., 2., 3., 4., 5.],
                                      val_samples[2][:, 2 * i + n_attr],
                                      val_samples[2][:, 2 * i + 1 + n_attr]))))

    train_labels = np.rollaxis(np.asarray(train_labels), 0, 2)
    val_labels = np.rollaxis(np.asarray(val_labels), 0, 2)

    new_labels = np.empty((len(train_labels), 2), dtype=object)  # np.object was removed from NumPy
    for i in range(len(train_labels)):
        new_labels[i, 0] = train_labels[i]
        if args.masked_recon:
            new_labels[i, 1] = train_samples[1][i]
        else:
            new_labels[i, 1] = np.ones_like(train_samples[1][i])
    train_labels = new_labels

    new_labels = np.empty((len(val_labels), 2), dtype=object)
    for i in range(len(val_labels)):
        new_labels[i, 0] = val_labels[i]
        if args.masked_recon:
            new_labels[i, 1] = val_samples[1][i]
        else:
            new_labels[i, 1] = np.ones_like(val_samples[1][i])
    val_labels = new_labels

    train_flow_gen = train_datagen.flow(x=train_imgs,
                                        y=train_labels,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        seed=12,
                                        save_to_dir=save_dir)

    val_flow_gen = val_datagen.flow(x=val_imgs,
                                    y=val_labels,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    seed=12,
                                    save_to_dir=save_dir)

    if args.net.find('xcaps') != -1:
        train_gen = xcaps_data_gen(train_flow_gen)
        val_gen = xcaps_data_gen(val_flow_gen)
    elif args.net.find('capsnet') != -1:
        train_gen = capsnet_data_gen(train_flow_gen)
        val_gen = capsnet_data_gen(val_flow_gen)
    else:
        raise NotImplementedError(
            'Data generator not found for the specified network. Please check train.py.')

    # Settings
    train_steps = len(train_samples[0]) // args.batch_size
    val_steps = len(val_samples[0]) // args.batch_size
    workers = 4
    multiproc = True

    # Run training
    history = model.fit_generator(train_gen,
                                  max_queue_size=40,
                                  workers=workers,
                                  use_multiprocessing=multiproc,
                                  steps_per_epoch=train_steps,
                                  validation_data=val_gen,
                                  validation_steps=val_steps,
                                  epochs=args.epochs,
                                  class_weight=None,
                                  callbacks=callbacks,
                                  verbose=args.verbose,
                                  shuffle=True)

    # Plot the training data collected
    plot_training(history, args)
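
get_pseudo_label (used in Examples 5 and 7) maps a mean score and a per-sample spread onto a soft distribution over the five score values. One way to realize it, assuming a Gaussian weighting; the exact form in the source repository may differ:

import numpy as np

def get_pseudo_label(classes, mean_scores, spreads):
    # Soft label: a normalized Gaussian over the class values, centered on
    # each sample's mean score with its per-sample spread.
    classes = np.asarray(classes)                          # e.g. [1., 2., 3., 4., 5.]
    mean_scores = np.asarray(mean_scores).reshape(-1, 1)
    spreads = np.clip(np.asarray(spreads).reshape(-1, 1), 1e-2, None)
    weights = np.exp(-0.5 * ((classes - mean_scores) / spreads) ** 2)
    return weights / weights.sum(axis=1, keepdims=True)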