Esempio n. 1
0
def main():
    """Visually inspect the heatmaps ConfidenceMap produces for one batch.

    Pulls a single batch of 10 samples from the training loader, builds every
    heatmap variant while timing the whole generation step, then shows each
    channel of the concatenated first-point / last-point / spine-mask maps in
    OpenCV windows, blocking on a key press after every channel.
    """
    import load_utils
    import cv2
    import time

    train_data_loader = load_utils.train_loader(10)
    train_imgs, train_labels = next(train_data_loader)

    ts = time.time()
    cm = ConfidenceMap()
    heat_scale = 1
    # Axes 1 and 2 of the image batch -- presumably (height, width); TODO confirm.
    heat_hw = np.asarray(train_imgs).shape[1:3]

    # The next five maps are not displayed below. They are still generated so
    # the reported duration (and any error they raise) keeps covering them.
    NCHW_corner_gau = cm.batch_gaussian_split_corner(train_imgs, train_labels,
                                                     heat_scale)
    NCHW_center_gau = cm.batch_gaussian_LRCenter(train_imgs, train_labels,
                                                 heat_scale)
    NCHW_c_lines = cm.batch_lines_LRCenter(heat_hw, train_labels, heat_scale)
    NCHW_t_lines = cm.batch_lines_LRTop(heat_hw, train_labels)
    NCHW_b_lines = cm.batch_lines_LRBottom(heat_hw, train_labels)

    # Maps that are actually visualised, joined along axis 1.
    NCHW_spine_mask = cm.batch_spine_mask(heat_hw, train_labels)
    NCHW_first_lrpt = cm.batch_gaussian_first_lrpt(train_imgs, train_labels)
    NCHW_last_lrpt = cm.batch_gaussian_last_lrpt(train_imgs, train_labels)
    NCHW_gaussian = np.concatenate(
        (NCHW_first_lrpt, NCHW_last_lrpt, NCHW_spine_mask), axis=1)
    te = time.time()
    # Time spent generating all of the maps above.
    print(f"Duration for gaussians: {te - ts:f}")

    num_samples, num_channels = NCHW_gaussian.shape[:2]
    # The range check does not depend on the sample/channel being shown, so
    # run it once instead of rescanning the whole array for every channel.
    # The guard keeps the original condition: it only ran when there was at
    # least one channel to display.
    if num_samples and num_channels:
        assert NCHW_gaussian.max() < 1.5, "expect normalized values"

    for n in range(num_samples):
        for c in range(num_channels):
            cv2.imshow("Image", train_imgs[n])
            g = NCHW_gaussian[n, c]
            g = cv2.resize(g, dsize=None, fx=heat_scale, fy=heat_scale)
            # Overlay: element-wise maximum of the image (divided by 255)
            # and the heatmap.
            cv2.imshow(
                "Image Heat",
                np.amax([train_imgs[n].astype(np.float32) / 255, g], axis=0))
            cv2.imshow("Heat Only", g)
            # Block until a key is pressed before moving to the next channel.
            cv2.waitKey()
Esempio n. 2
0
import numpy as np
import spine_augmentation as aug

if __name__ == "__main__":
    # Command line: "-s" sets the batch size (default 5); "--trainval"
    # switches to final training on the combined train and val sets.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s", default=5, type=int, required=False, help="batch size")
    parser.add_argument("--trainval", action='store_true', default=False)
    args = parser.parse_args()
    batch_size = args.s

    # Both modes load angles; only final training additionally passes
    # use_trainval=True to the training loader.
    train_loader_kwargs = {"load_angle": True}
    if args.trainval:
        train_loader_kwargs["use_trainval"] = True
    train_data_loader = load_utils.train_loader(batch_size,
                                                **train_loader_kwargs)
    if args.trainval:
        print("--- Using [train, val] set as training set!")
    test_data_loader = load_utils.test_loader(batch_size, load_angle=True)

    # Heatmap network: .cuda() then .eval() are called on it.
    net_heat = part_affinity_field_net.SpineModelPAF()
    net_heat.cuda()
    net_heat.eval()
    # Angle network: only .cuda() is called here.
    net_angle = part_affinity_field_net.CobbAngleModel()
    net_angle.cuda()

    # Pick the heatmap checkpoint path that matches the training mode.
    if args.trainval:
        save_path_heat = f.checkpoint_heat_trainval_path
    else:
        save_path_heat = f.checkpoint_heat_path
Esempio n. 3
0
            image_before = keypoints_before.draw_on_image(batch_img[i])
            image_after = keypoints_after.draw_on_image(aug_b_imgs[i])
            ia.imshow(np.hstack([image_before, image_after]))
    return aug_b_imgs, aug_b_pts


def augment_batch_img_for_angle(batch_img, batch_pts, plot=False):
    """
    Image augmentation, used when training the angle network.

    Applies a random horizontal crop/pad followed by a brightness shift to
    the images, transforming the keypoints with the same pipeline.

    :param batch_img: [B,H,W,C]
    :param batch_pts: [B,number,xy]
    :param plot: accepted for signature parity with the other augmentation
        helpers; this function currently ignores it
    :return: aug_b_img, aug_b_pts
    """
    seq = iaa.Sequential([
        # Four-entry percent is (top, right, bottom, left) in imgaug: top and
        # bottom stay fixed, right and left are each cropped/padded by up to
        # 10%.
        iaa.CropAndPad(percent=((0., 0.), (-0.1, 0.1), (0., 0.), (-0.1, 0.1))),
        iaa.Add((-25, 25))  # change brightness
    ])
    aug_b_imgs, aug_b_pts = seq(images=batch_img, keypoints=batch_pts)

    return aug_b_imgs, aug_b_pts


if __name__ == "__main__":
    # Manual visual check: running this file directly feeds training batches
    # of 5 through the box augmentation with plotting switched on.
    import load_utils

    for imgs, labels in load_utils.train_loader(5):
        # Swap in augment_batch_img(imgs, labels, plot=True) to inspect the
        # other augmentation pipeline instead.
        augment_batch_img_for_box(imgs, labels, plot=True)