Example #1
0
def for_pixel_input():
    """Inspect the first conv layer of the pixel-patch ``conv_hybrid`` model.

    Loads the test images and builds a patch dataset from them. It then
    restores the ``conv_hybrid`` weights for ``data_class`` from
    ``model_path`` and prints every parameter's shape. Next it reconstructs
    the feature maps produced by ``model.residual[0]``. Finally it min-max
    scales that layer's kernels into [0, 1].

    Relies on ``data_path``, ``model_path`` and ``data_class``, which are not
    defined in this function.
    """
    patch_size = 3

    # load test data
    print('Load test data..')
    test_imgs = load_raw_images_data(data_path,
                                     rescale_ratio=0.25,
                                     preserve_range_after_rescale=True)
    img_height, img_width = test_imgs[0].shape
    test_dataset, channel_len = load_patch_dataset_from_imgs(
        test_imgs, patch_size=patch_size)

    model = models.conv_hybrid(channel_len, 3)
    model.load_state_dict(
        torch.load(f'{model_path}/conv_hybrid_on_{data_class}.pth',
                   map_location='cpu'))
    model.eval()

    for name, param in model.named_parameters():
        print(name, param.shape)

    # Visualize feature maps: the hook registered here fills
    # activation['conv1'] while the model runs over the dataset.
    activation = {}
    model.residual[0].register_forward_hook(get_activation(
        'conv1', activation))
    # The return value is not needed; the call is made only to fire the hook.
    models.get_model_output(test_dataset, model)
    act = activation['conv1']
    filter_num = act.size(1)
    # The map is smaller than the image by (patch_size - 1) in each dimension
    # (previously hard-coded as "- 2" for patch_size=3); presumably patches
    # are only taken where a full window fits.
    border = patch_size - 1
    reconstruct_feature_map(act, filter_num, img_height - border,
                            img_width - border)

    # Visualize conv filter
    kernels = model.residual[0].weight.detach()
    # normalize filter values to 0-1 so we can visualize them; skip the
    # division when all weights are equal (zero range) to avoid NaNs
    kernels = kernels - kernels.min()
    value_range = kernels.max()
    if value_range > 0:
        kernels = kernels / value_range
    # NOTE(review): `kernels` is not used after this point; for_img_input
    # passes its normalized kernels to view_filter — confirm whether a
    # similar call is missing here.
Example #2
0
def for_img_input():
    """Inspect the first conv layer of the whole-image ``conv2d_net`` model.

    Loads the rescaled images and a sample image (for the spatial size). It
    restores the model weights and prints every parameter's shape. It then
    shows the feature maps of ``model.conv[0]`` via ``view_feature_map``.
    Finally it shows that layer's kernels, min-max scaled into [0, 1], via
    ``view_filter``.

    Relies on ``data_path``, ``model_path``, ``conv_nd``, ``data_class`` and
    ``model_data_id``, which are not defined in this function.
    """
    # load test data
    print('Load test data..')
    imgs_norm = load_raw_images_data(data_path, rescale_ratio=0.25)
    sample_img = get_sample_image(data_path, rescale_ratio=0.25)
    img_height, img_width = sample_img.shape

    # NOTE(review): 23 is a hard-coded channel count — confirm it matches the
    # data. The weights file name uses conv{conv_nd}d, but a 2D net is always
    # built; confirm conv_nd is 2 wherever this runs.
    model = models.conv2d_net(23, img_width, img_height, 3)
    model.load_state_dict(
        torch.load(
            f'{model_path}/conv{conv_nd}d_on_{data_class}_{model_data_id}.pth',
            map_location='cpu'))
    model.eval()

    for name, param in model.named_parameters():
        print(name, param.shape)

    # Visualize feature maps
    activation = {}
    model.conv[0].register_forward_hook(get_activation('conv1', activation))
    # Inference only: no autograd graph is needed, and it keeps the captured
    # activation from requiring grad when handed to the viewer. The output
    # itself is unused; the call exists to fire the hook.
    with torch.no_grad():
        model(torch.FloatTensor(imgs_norm))
    act = activation['conv1'].squeeze()
    view_feature_map(act)

    # Visualize conv filter
    kernels = model.conv[0].weight.detach()
    # normalize filter values to 0-1 so we can visualize them; skip the
    # division when all weights are equal (zero range) to avoid NaNs
    kernels = kernels - kernels.min()
    value_range = kernels.max()
    if value_range > 0:
        kernels = kernels / value_range
    view_filter(kernels)
def enhance_roi(data_id):
    """Classify one ROI with the ``conv_resid`` patch model and save the result.

    Args:
        data_id: ROI identifier. It selects the input images under
            ``networks/data/sgp/<data_id>/cropped_roi/`` and names the output
            file.

    Writes ``networks/reconstructed_roi/conv_resid/<data_id>_conv_resid.png``,
    creating the directory if needed.
    """
    data_class = 'allClass'

    # file paths
    data_path = f'networks/data/sgp/{data_id}/cropped_roi/*'
    model_path = 'networks/model'
    img_save_path = 'networks/reconstructed_roi/conv_resid'
    # mkdir if not exists
    Path(img_save_path).mkdir(parents=True, exist_ok=True)

    # load test data
    print('Load test data..')
    test_imgs = load_raw_images_data(data_path,
                                     rescale_ratio=0.25,
                                     preserve_range_after_rescale=True)
    sample_img = test_imgs[0]
    test_dataset, channel_len = load_patch_dataset_from_imgs(test_imgs,
                                                             patch_size=3)

    # load model
    print('Load model..')
    model = models.conv_resid(channel_len, 3)
    model.load_state_dict(
        torch.load(f'{model_path}/conv_resid_on_{data_class}.pth',
                   map_location='cpu'))
    model.eval()

    print('Model predict..')
    predictions = models.predict_class(test_dataset, model)
    # Presumably the patch predictions cover fewer pixels than the image;
    # pad them back to the sample image using the dataset's patch size.
    predictions = pad_prediction(predictions, sample_img,
                                 test_dataset.patch_size)

    print('Reconstruct..')
    enhanced = reconstruct_image(sample_img,
                                 predictions,
                                 enhance_intensity=20,
                                 count_note=True)
    imsave(f'{img_save_path}/{data_id}_conv_resid.png', enhanced)
Example #4
0
# Script fragment: set up paths, load the ROI images, and restore the trained
# autoencoder-based classifier (ae_model) for label prediction.
# NOTE(review): data_id, conv_nd and data_class are not defined in this
# fragment; they must be set earlier in the script.
data_path = f'networks/data/sgp/{data_id}/cropped_roi/*'
model_path = 'networks/model'
img_save_path = 'networks/reconstructed_roi'
log_path = f'networks/training_log/conv{conv_nd}ds/sgd'
# mkdir if not exists
Path(f'{model_path}').mkdir(parents=True, exist_ok=True)
Path(f'{log_path}').mkdir(parents=True, exist_ok=True)
Path(f'{img_save_path}').mkdir(parents=True, exist_ok=True)

# load training data
# NOTE(review): the message says "training", but the dataset loaded below is
# named test_dataset — confirm which role it plays.
print('Load training data..')

# load images
print('-load images..')

# imgs_norm and pxl_num are not used in the visible part of this fragment;
# presumably consumed by the prediction step that follows.
imgs_norm = load_raw_images_data(data_path, rescale_ratio=0.25)
test_dataset, sample_img = load_images_data(data_path, rescale_ratio=0.25)
img_height, img_width = sample_img.shape

# hard-coded channel count — TODO confirm it matches the number of images
channel_len = 23
pxl_num = img_width * img_height  # number of pixels in one rescaled image

# load estimator
print('-load estimator..')
# models.sdae built with dimensions [channel_len, 10, 10, 20, 3] and wrapped
# by models.sdae_lr; the trailing 3 presumably is the number of classes
autoencoder = models.sdae(dimensions=[channel_len, 10, 10, 20, 3])
ae_model = models.sdae_lr(autoencoder)
# weights are loaded onto the CPU regardless of where they were saved
ae_model.load_state_dict(
    torch.load(f'{model_path}/ae_on_{data_class}.pth', map_location='cpu'))
ae_model.eval()

# predict labels
Example #5
0
# Script fragment: fit an LDA classifier on the labeled channel data and
# report training-set metrics. Then predict labels for the flattened test
# images and save the reconstructed first test image.
# NOTE(review): test_data_path, img_save_path and data_id are defined outside
# this fragment; channel_len is unpacked but not used here.
channel_train, y_true, channel_len = load_raw_labeled_data()

# fit lda
print('Model training..')
classifier = lda()
classifier.fit(channel_train, y_true)
# score() on the training data, printed below as "train-accuracy". If lda is
# sklearn's LinearDiscriminantAnalysis this is mean accuracy, not precision,
# despite the variable name.
precision_clf = classifier.score(channel_train, y_true)
prediction = classifier.predict(channel_train)
balanced_acc = balanced_accuracy_score(y_true, prediction)
kappa = cohen_kappa_score(y_true, prediction)
# plot learning curve
plot_learning_curve(classifier, 'learning curve of LDA', channel_train, y_true)

print('train-accuracy: ', precision_clf)
print('balanced-accuracy: ', balanced_acc)
print('kappa: ', kappa)

# prepare test data
print('Prepare test data..')
imgs = load_raw_images_data(test_data_path,
                            rescale_ratio=0.25,
                            preserve_range_after_rescale=True)
# flatten_images returns a pair; only the first element (presumably one
# feature row per pixel) is used
channel_test, _ = flatten_images(imgs)

print('Model predict..')
predictions = classifier.predict(channel_test)

print('Reconstruct..')
# draw the predicted labels onto the first test image
sample_img = imgs[0]
sample_img = reconstruct_image(sample_img, predictions)
imsave(f'{img_save_path}/{data_id}_lda.png', sample_img)
Example #6
0
# Script fragment: give each annotated center pixel an id. Then, folio by
# folio, read the intensity of every neighbour pixel across all channel
# images.
# NOTE(review): center_pxls (a DataFrame with folio_name, class_name, x_loc and
# y_loc columns), get_filenames and get_window come from outside this
# fragment; the fragment also ends mid-loop, so label, center_id and
# channel_data are consumed by code not shown here.
center_pxl_num = len(center_pxls)
# add new column: center_pxl_id
# (sequential ids starting at 1, aligned with the existing index)
center_pxls['center_pxl_id'] = pd.Series(range(1, center_pxl_num + 1),
                                         index=center_pxls.index)

target_folio = center_pxls['folio_name'].unique()

# empty frame with the same columns as center_pxls; presumably filled with
# neighbour rows later (not visible here)
extend_pxls = pd.DataFrame(columns=center_pxls.columns)

for folio_name in target_folio:
    print('process on', folio_name, '..')
    # load images
    # NOTE(review): '~' is not expanded automatically by open()/glob/pathlib;
    # confirm get_filenames calls expanduser.
    file_path = f'~/Desktop/sgp-imgs/{folio_name}/tif/'
    filenames = get_filenames(file_path)
    imgs = load_raw_images_data(filenames, preserve_range_after_rescale=True)
    # one loaded image per channel
    channel_len = len(imgs)
    # find neighbors: [(5x5): radius-2, (3x3): radius-1]
    radius = 1
    pxls_index = center_pxls.loc[center_pxls['folio_name'] == folio_name].index
    for i in pxls_index:
        label = center_pxls.loc[i]['class_name']
        # x_loc / y_loc appear to be 1-based; shift to 0-based array indices
        x_l = center_pxls.loc[i]['x_loc'] - 1
        y_l = center_pxls.loc[i]['y_loc'] - 1
        center_id = center_pxls.loc[i]['center_pxl_id']
        neighbors_locs = get_window(radius, x=x_l, y=y_l)
        for (neigh_x, neigh_y) in neighbors_locs:
            # intensity data
            # one value per channel; images are indexed [y][x] (row, column)
            channel_data = []
            for c in range(channel_len):
                channel_data.append(imgs[c][neigh_y][neigh_x])
Example #7
0
def enhanced_roi(data_id):
    """Run a whole-image conv model on one ROI and save the predicted labels.

    The architecture is chosen by the hard-coded ``net_style`` (0 normal,
    1 fully-convolutional, 2 hybrid) and ``conv_nd`` (2 or 3). Weights are
    loaded from ``networks/model``. Two images are written: the original
    sample image, and the reconstruction, which goes into a per-model
    subdirectory of ``networks/reconstructed_roi``.

    Args:
        data_id: ROI identifier; selects the input images under
            ``networks/data/sgp/<data_id>/cropped_roi/`` and names the outputs.

    Raises:
        ValueError: if ``net_style``/``conv_nd`` is not a supported pair.
    """
    data_class = 'allClass'
    folio_ids = ['024r_029v', '102v_107r', '214v_221r']
    model_data_id = folio_ids[0]
    conv_nd = 2
    # net_style {'normal': 0, 'fconv': 1, 'hybrid': 2}
    net_style = 2

    # file paths
    data_path = f'networks/data/sgp/{data_id}/cropped_roi/*'
    model_path = 'networks/model'
    img_save_path = 'networks/reconstructed_roi'
    # mkdir if not exists
    Path(img_save_path).mkdir(parents=True, exist_ok=True)

    # load images
    print('-load images..')

    imgs_norm = load_raw_images_data(data_path, rescale_ratio=0.25)
    sample_img = get_sample_image(data_path, rescale_ratio=0.25)
    img_height, img_width = sample_img.shape

    channel_len = 23

    # net_style -> (models constructor name, file tag). The tag is shared by
    # the weights file name, the output subdirectory and the output image name.
    style_specs = {
        0: ('conv{nd}d_net', 'conv{nd}d'),
        1: ('fconv{nd}d_net', 'fconv{nd}d'),
        2: ('conv{nd}d_hyb_net', 'conv{nd}d_hyb'),
    }
    if net_style not in style_specs or conv_nd not in (2, 3):
        raise ValueError(
            f'unsupported net_style={net_style}, conv_nd={conv_nd}')
    ctor_template, tag_template = style_specs[net_style]
    ctor_name = ctor_template.format(nd=conv_nd)
    tag = tag_template.format(nd=conv_nd)

    # conv model
    conv_model = getattr(models, ctor_name)(channel_len, img_width,
                                            img_height, 3)
    model_name = f'{model_path}/{tag}_on_{data_class}_{model_data_id}.pth'
    conv_model.load_state_dict(torch.load(model_name, map_location='cpu'))
    conv_model.eval()

    print('Reconstruct..')
    with torch.no_grad():
        output = conv_model(torch.FloatTensor(imgs_norm))
        # predicted class = index of the max score along dim 1
        _, conv_pred = torch.max(output.data, 1)

        print('-max in conv_pred: ', torch.max(conv_pred.data).item())
        print('-min in conv_pred: ', torch.min(conv_pred.data).item())

    imsave(f'{img_save_path}/{data_id}_orig_eval.png', sample_img)

    sample_img_conv = reconstruct_image(sample_img, conv_pred, count_note=True)
    # The per-model subdirectory is not covered by the mkdir above; create it
    # so imsave does not fail on a fresh output tree.
    img_dir = f'{img_save_path}/{tag}'
    Path(img_dir).mkdir(parents=True, exist_ok=True)
    img_name = f'{img_dir}/{data_id}_{tag}_eval_model_{model_data_id}.png'
    imsave(img_name, sample_img_conv)