示例#1
0
def main(args):
    """Encode each test case, classify it, and save 2-D projection figures.

    For every case path returned by ``read_test_list`` the tiles are loaded
    from a ``.npy`` file, passed through ``transform_fn``, optionally
    subsampled, encoded with ``model.encode_bag`` and classified twice: once
    per instance and once on the mean-pooled bag feature. Two figures are
    written per case, with filenames prefixed by ``args.savebase``.

    Args:
        args: Parsed command-line namespace. Reads ``input_dim``,
            ``timestamp``, ``mil``, ``batch_size``, ``sample`` and
            ``savebase``. ``snapshot_dir`` and ``test_list_dir`` are
            optional; when absent (or falsy) they default to
            ``../experiment/save`` and ``../experiment/test_lists``.
    """
    transform_fn = data_utils.make_transform_fn(128,
                                                128,
                                                args.input_dim,
                                                1.0,
                                                normalize=True)

    # The directories used to be hard-coded; allow overriding them from
    # ``args`` while keeping the original relative paths as defaults.
    snapshot_dir = getattr(args, 'snapshot_dir', None) or '../experiment/save'
    test_list_dir = (getattr(args, 'test_list_dir', None)
                     or '../experiment/test_lists')
    snapshot = os.path.join(snapshot_dir, '{}.h5'.format(args.timestamp))
    test_list_path = os.path.join(test_list_dir,
                                  '{}.txt'.format(args.timestamp))

    # NOTE(review): ``encoder_args`` is not defined in this function; it is
    # presumably a module-level global -- confirm.
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=True)
    # One dummy forward pass before ``load_weights``; presumably needed so the
    # subclassed model has created its variables -- confirm.
    x_dummy = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    model(x_dummy, verbose=True)
    model.load_weights(snapshot, by_name=True)
    model.summary()

    test_list = read_test_list(test_list_path)

    for test_case in test_list:
        case_name = os.path.basename(test_case).replace('.npy', '')
        print(test_case, case_name)
        case_x = np.load(test_case)
        case_x = np.stack([transform_fn(x) for x in case_x], 0)
        print(case_x.shape)

        if args.sample:
            # NOTE(review): np.random.choice samples WITH replacement by
            # default, so the subsample may contain duplicate tiles.
            case_x = case_x[
                np.random.choice(range(case_x.shape[0]), args.sample), ...]
            print(case_x.shape)

        # NOTE(review): encode_bag is called with training=True at test time;
        # confirm this is intentional.
        features = model.encode_bag(case_x,
                                    batch_size=args.batch_size,
                                    training=True,
                                    return_z=True)
        print('features:', features.shape)

        # Bag-level feature: mean over axis 0, keeping the reduced axis as a
        # leading dimension of size 1. ``keepdims`` replaces the deprecated
        # ``keep_dims`` argument, which TF 2 no longer accepts.
        features_att = tf.reduce_mean(features, axis=0, keepdims=True)
        yhat_instances = model.apply_classifier(features, training=False)
        print('yhat instances:', yhat_instances.shape)
        yhat = model.apply_classifier(features_att, training=False)
        print('yhat:', yhat.shape)
        yhat_np = yhat.numpy()

        # Per-instance class-1 scores are passed to get_attention_extremes in
        # place of attention weights to pick the extreme tiles.
        yhat_1 = yhat_instances[:, 1].numpy()
        high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
            yhat_1, case_x, n=5)

        print('Case {}: predicted={} ({})'.format(test_case,
                                                  np.argmax(yhat_np, axis=-1),
                                                  yhat_np))

        # Bag-level class-1 score; converted to a Python float so the
        # ``{:3.2f}`` format works regardless of TF tensor formatting support.
        score = float(yhat_np[0, 1])
        features = features.numpy()

        savepath = '{}_{}_{:3.2f}_yhat.png'.format(args.savebase, case_name,
                                                   score)
        print('Saving figure {}'.format(savepath))
        z = draw_projection(features, features_att, yhat_1, savepath=savepath)

        savepath = '{}_{}_{:3.2f}_yhat_img.png'.format(args.savebase,
                                                       case_name, score)
        print('Saving figure {}'.format(savepath))
        draw_projection_with_images(z,
                                    yhat_1,
                                    high_att_idx,
                                    high_att_imgs,
                                    low_att_idx,
                                    low_att_imgs,
                                    savepath=savepath)
示例#2
0
def main(args):
    """Evaluate a MIL attention model on a test list and save diagnostics.

    Loads weights ``<args.snapshot_dir>/<timestamp>.h5`` and the case list
    ``<args.test_list_dir>/<timestamp>.txt``. For each case it does the
    following:

    * encodes the (optionally subsampled) tiles;
    * pools them with ``model.mil_attention``;
    * classifies the pooled feature;
    * writes two projection figures plus ``<case>_atns.npy`` and
      ``<case>_feat.npy`` into ``<args.odir>/<timestamp>``.

    Finally it prints the accuracy and the confusion matrix over all cases.

    WARNING: the output directory ``<args.odir>/<timestamp>`` is deleted
    (``shutil.rmtree``) and recreated on every run.

    Args:
        args: Parsed command-line namespace. Reads ``input_dim``,
            ``snapshot_dir``, ``test_list_dir``, ``timestamp``, ``encoder``,
            ``batch_size``, ``temperature``, ``cls_normalize``, ``odir`` and
            ``sample``.
    """
    transform_fn = data_utils.make_transform_fn(128,
                                                128,
                                                args.input_dim,
                                                1.0,
                                                normalize=True)

    snapshot = os.path.join(args.snapshot_dir, '{}.h5'.format(args.timestamp))
    test_list_path = os.path.join(args.test_list_dir,
                                  '{}.txt'.format(args.timestamp))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      deep_classifier=True,
                      batch_size=args.batch_size,
                      temperature=args.temperature,
                      cls_normalize=args.cls_normalize)

    # One dummy forward pass before ``load_weights``; presumably needed so the
    # subclassed model has created its variables -- confirm.
    x_dummy = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    model(x_dummy, verbose=True)
    model.load_weights(snapshot, by_name=True)
    model.summary()

    test_list = read_test_list(test_list_path)

    # Start from an empty output directory (destroys previous results).
    savebase = os.path.join(args.odir, args.timestamp)
    if os.path.exists(savebase):
        shutil.rmtree(savebase)
    os.makedirs(savebase)

    yhats, ytrues = [], []
    for test_case in test_list:
        case_name = os.path.basename(test_case).replace('.npy', '')
        print(test_case, case_name)
        case_x = np.load(test_case)
        case_x = np.stack([transform_fn(x) for x in case_x], 0)
        # NOTE(review): ``case_dict`` (case name -> label) is not defined in
        # this function; presumably a module-level global -- confirm.
        ytrue = case_dict[case_name]
        print(case_x.shape, ytrue)
        ytrues.append(ytrue)

        if args.sample:
            # NOTE(review): np.random.choice samples WITH replacement by
            # default, so the subsample may contain duplicate tiles.
            case_x = case_x[
                np.random.choice(range(case_x.shape[0]), args.sample), ...]
            print(case_x.shape)

        # NOTE(review): encode_bag is called with training=True at test time;
        # confirm this is intentional.
        features = model.encode_bag(case_x, training=True, return_z=True)
        print('features:', features.shape)
        features_att, attention = model.mil_attention(features,
                                                      return_att=True,
                                                      training=False)
        print('features:', features_att.shape, 'attention:', attention.shape)

        # Unweighted mean over axis 0, drawn next to the attention-pooled one.
        features_avg = np.mean(features, axis=0, keepdims=True)

        yhat_instances = model.apply_classifier(features, training=False)
        print('yhat instances:', yhat_instances.shape)
        yhat = model.apply_classifier(features_att, training=False)
        print('yhat:', yhat.shape)
        yhat_np = yhat.numpy()
        yhats.append(yhat_np)

        high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
            attention, case_x, n=5)

        print('Case {}: predicted={} ({})'.format(test_case,
                                                  np.argmax(yhat_np, axis=-1),
                                                  yhat_np))

        # Bag-level class-1 score; converted to a Python float so the
        # ``{:3.2f}`` format works regardless of TF tensor formatting support.
        score = float(yhat_np[0, 1])
        attention = np.squeeze(attention.numpy())
        features = features.numpy()

        savepath = os.path.join(savebase,
                                '{}_{:3.2f}.png'.format(case_name, score))
        print('Saving figure {}'.format(savepath))
        z = draw_projection(features,
                            features_avg,
                            features_att,
                            attention,
                            savepath=savepath)

        savepath = os.path.join(savebase,
                                '{}_{:3.2f}_imgs.png'.format(case_name, score))
        print('Saving figure {}'.format(savepath))
        draw_projection_with_images(z,
                                    attention,
                                    high_att_idx,
                                    high_att_imgs,
                                    low_att_idx,
                                    low_att_imgs,
                                    savepath=savepath)

        savepath = os.path.join(savebase, '{}_atns.npy'.format(case_name))
        np.save(savepath, attention)
        savepath = os.path.join(savebase, '{}_feat.npy'.format(case_name))
        np.save(savepath, features)

    if not yhats:
        # np.concatenate raises ValueError on an empty list; nothing to score.
        print('No test cases found; skipping accuracy and confusion matrix.')
        return

    # Predicted class per case = argmax over the class axis.
    yhats = np.argmax(np.concatenate(yhats, axis=0), axis=1)
    ytrues = np.array(ytrues)
    acc = (yhats == ytrues).mean()
    print(acc)
    cm = confusion_matrix(y_true=ytrues, y_pred=yhats)
    print(cm)