Example #1
IMG_NORMALIZE_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMG_NORM_MEAN,
                                     std=IMG_NORM_STDDEV,
                                     inplace=True)])

AUGMENTATION_TRANSFORM = SeededCompose([
    torchvision.transforms.ToPILImage(mode="F"),
    torchvision.transforms.RandomHorizontalFlip(p=HORIZONTAL_FLIP_P),
    torchvision.transforms.RandomAffine(
        degrees=(-ROTATION_MAX_DEGREES, +ROTATION_MAX_DEGREES),
        translate=TRANSLATION_MAX_RATIO, scale=SCALE_RANGE),
    torchvision.transforms.RandomCrop(size=TRAIN_HW, pad_if_needed=True),
    torchvision.transforms.ToTensor()])
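
SeededCompose is not part of torchvision; it is presumably a project-specific wrapper that reseeds the RNG before running the chain, so that an image and its target heatmaps can be augmented identically. A minimal sketch of that idea, assuming such behaviour (this is not the project's actual class):

import random

import torch


class SeededComposeSketch:
    """Hypothetical illustration: apply a transform chain under a caller-supplied
    seed, so two calls with the same seed yield the same random augmentation."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, seed=0):
        random.seed(seed)        # older torchvision transforms draw from random
        torch.manual_seed(seed)  # newer ones draw from torch's RNG
        for t in self.transforms:
            x = t(x)
        return x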

minival_dl = torch.utils.data.DataLoader(
    CocoDistillationDatasetAugmented(COCO_DIR, "val2017",
                                     img_transform=IMG_NORMALIZE_TRANSFORM,
                                     remove_images_without_annotations=False,
                                     gt_stddevs_pix=MINIVAL_GT_STDDEVS,
                                     whitelist_ids=MINIVAL_IDS),
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True)

val_dl = torch.utils.data.DataLoader(
    CocoDistillationDatasetAugmented(COCO_DIR, "val2017",
                                     img_transform=IMG_NORMALIZE_TRANSFORM,
                                     remove_images_without_annotations=False,
                                     gt_stddevs_pix=VAL_GT_STDDEVS),
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True)
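
Both loaders are set up for deterministic, single-image evaluation (batch_size=1, shuffle=False). What each batch contains depends on CocoDistillationDatasetAugmented, so the unpacking in the sketch below is an assumption, as is the student_model being evaluated:

# Hypothetical evaluation pass over the minival loader (student_model, DEVICE
# and the batch layout are assumptions, not taken from the original snippet).
student_model.eval()
with torch.no_grad():
    for batch in minival_dl:
        img = batch[0].to(DEVICE, non_blocking=True)  # assumed: image tensor first
        preds = student_model(img)
        # ... compare preds against the remaining (target) elements of the batch
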
Example #2
SAVE_EVERY = 100  # set to None to disable periodic saving
SAVE_DIR = "/tmp"


# #############################################################################
# # MAIN ROUTINE
# #############################################################################

# dataloaders
IMG_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMG_NORM_MEAN,
                                     std=IMG_NORM_STDDEV,
                                     inplace=True)])

val_ds = CocoDistillationDatasetAugmented(
    COCO_DIR, "val2017", remove_images_without_annotations=False)
img_paths = [os.path.join(COCO_DIR, "images", "val2017",
                          "{:012d}.jpg".format(x)) for x in val_ds.ids]

# model
hhrnet = get_hrnet_w48_teacher(MODEL_PATH).to(DEVICE)
hhrnet.eval()

hm_parser = HeatmapParser(num_joints=NUM_HEATMAPS,
                          **HM_PARSER_PARAMS)

# main loop
all_preds = []
all_scores = []
for ii, imgpath in enumerate(img_paths):
    img = Image.open(imgpath).convert("RGB")
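    # The snippet is truncated at this point. A hedged sketch of how the
    # per-image inference step could continue (the model's output layout and
    # the grouping step are assumptions, not the original code):
    t = IMG_TRANSFORM(img).unsqueeze(0).to(DEVICE)  # normalize, add batch dim
    with torch.no_grad():
        out = hhrnet(t)  # raw heatmaps/tags; exact layout is model-specific
    # hm_parser would then group the heatmap peaks into per-person keypoints,
    # which get appended to all_preds and all_scores.
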
Example #3
OVERALL_HHRNET_TRANSFORM = SeededCompose([
    torchvision.transforms.ToPILImage(mode="F"),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomAffine(
        degrees=(-45, +45), translate=(0.1, 0.1), scale=(0.7, 1.3)),
    torchvision.transforms.RandomCrop(size=(480, 480), pad_if_needed=True),
    torchvision.transforms.ToTensor()])


# #############################################################################
# # MAIN ROUTINE
# #############################################################################
minival_dataset = CocoDistillationDatasetAugmented(
    COCO_DIR, "val2017",
    os.path.join(COCO_DIR, "hrnet_predictions", "val2017"),
    gt_stddevs_pix=[2.0],
    img_transform=IMG_NORMALIZE_TRANSFORM,
    whitelist_ids=MINIVAL_IDS)

val_augm_dataset = CocoDistillationDatasetAugmented(
    COCO_DIR, "val2017",
    os.path.join(COCO_DIR, "hrnet_predictions", "val2017"),
    gt_stddevs_pix=[20.0, 9.0, 2.0],
    img_transform=IMG_NORMALIZE_TRANSFORM,
    overall_transform=OVERALL_HHRNET_TRANSFORM)
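
The gt_stddevs_pix values presumably control the width, in pixels, of the Gaussian blobs rendered around each ground-truth keypoint at several scales. A hedged illustration of that idea (an assumption about the dataset's rendering, not its actual code):

import numpy as np

def gaussian_heatmap(h, w, cx, cy, stddev_pix):
    """Render a single keypoint at (cx, cy) as an isotropic Gaussian blob."""
    ys, xs = np.mgrid[0:h, 0:w]
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * stddev_pix ** 2))

# Broad blobs (stddev 20) give coarse supervision, sharp ones (stddev 2) fine.
coarse = gaussian_heatmap(480, 480, cx=240, cy=120, stddev_pix=20.0)
fine = gaussian_heatmap(480, 480, cx=240, cy=120, stddev_pix=2.0)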


stem = network_to_half(StemHRNet())
stem[1].load_pretrained(MODEL_PATH, device="cuda")
stem.to("cuda")
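
network_to_half looks like the Apex-style fp16 utility (an assumption), which would wrap the model in an nn.Sequential whose first element casts inputs to half precision; that would also explain the stem[1] indexing above. A hedged smoke test under that assumption:

# Hypothetical smoke test; the input resolution is an assumption, and the
# wrapper is assumed to handle the float32 -> float16 cast of the input.
with torch.no_grad():
    dummy = torch.randn(1, 3, 480, 480, device="cuda")
    features = stem(dummy)
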
Example #4
IMG_NORMALIZE_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMG_NORM_MEAN,
                                     std=IMG_NORM_STDDEV,
                                     inplace=True)
])

AUGMENTATION_TRANSFORM = SeededCompose([
    torchvision.transforms.ToPILImage(mode="F"),
    # torchvision.transforms.RandomHorizontalFlip(p=HORIZONTAL_FLIP_P),
    # torchvision.transforms.RandomAffine(
    #     degrees=(-ROTATION_MAX_DEGREES, +ROTATION_MAX_DEGREES),
    #     translate=TRANSLATION_MAX_RATIO, scale=SCALE_RANGE),
    torchvision.transforms.RandomCrop(size=TRAIN_HW, pad_if_needed=True),
    torchvision.transforms.ToTensor()
])

minival_dl = torch.utils.data.DataLoader(CocoDistillationDatasetAugmented(
    COCO_DIR,
    "val2017",
    img_transform=IMG_NORMALIZE_TRANSFORM,
    remove_images_without_annotations=False,
    gt_stddevs_pix=MINIVAL_GT_STDDEVS,
    whitelist_ids=MINIVAL_IDS),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0,
                                         pin_memory=True)

val_dl = torch.utils.data.DataLoader(CocoDistillationDatasetAugmented(
    COCO_DIR,
    "val2017",
    img_transform=IMG_NORMALIZE_TRANSFORM,
    remove_images_without_annotations=False,
    gt_stddevs_pix=VAL_GT_STDDEVS),
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=0,
                                     pin_memory=True)
Example #5

IMG_NORMALIZE_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMG_NORM_MEAN,
                                     std=IMG_NORM_STDDEV,
                                     inplace=True)
])

OVERALL_HHRNET_TRANSFORM = SeededCompose([
    torchvision.transforms.ToPILImage(mode="F"),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomAffine(degrees=(-45, +45),
                                        translate=(0.1, 0.1),
                                        scale=(0.7, 1.3)),
    # torchvision.transforms.RandomCrop(size=(480, 480), pad_if_needed=True),
    torchvision.transforms.ToTensor()
])

minival_dataset = CocoDistillationDatasetAugmented(
    "/home/a9fb1e/datasets/coco",
    "val2017",
    "/home/a9fb1e/datasets/coco/hrnet_predictions/val2017",
    # img_transform=IMG_NORMALIZE_TRANSFORM,
    # overall_transform=OVERALL_HHRNET_TRANSFORM,
    whitelist_ids=MINIVAL_IDS)

loss = DistillationLoss()
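
DistillationLoss is project-specific; below is a minimal sketch of a heatmap distillation objective, assuming it compares student heatmaps against the HRNet teacher's predictions (this is not the project's actual loss):

import torch

def heatmap_distillation_mse(student_hm, teacher_hm, mask=None):
    """Hypothetical distillation term: per-pixel MSE between student and
    teacher heatmaps, optionally restricted to a validity mask."""
    diff = (student_hm - teacher_hm) ** 2
    if mask is not None:
        diff = diff * mask
    return diff.mean()

# Example with dummy 17-channel heatmaps:
loss_value = heatmap_distillation_mse(torch.rand(1, 17, 120, 120),
                                      torch.rand(1, 17, 120, 120))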

hm_parser = HeatmapParser(num_joints=17,
                          max_num_people=30,
                          detection_threshold=0.1,
                          tag_threshold=1.0,
                          use_detection_val=True,
                          ignore_too_much=False,
                          tag_per_joint=True,
                          nms_ksize=5,
                          nms_padding=2)
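
The nms_ksize / nms_padding / detection_threshold parameters suggest the usual max-pooling style heatmap NMS used by bottom-up pose parsers; a hedged illustration of that mechanism (an assumption about this parser's internals, not its actual code):

import torch
import torch.nn.functional as F

def heatmap_nms(heatmaps, ksize=5, padding=2, detection_threshold=0.1):
    """Keep only local maxima above the detection threshold."""
    maxima = F.max_pool2d(heatmaps, kernel_size=ksize, stride=1, padding=padding)
    keep = (heatmaps == maxima) & (heatmaps > detection_threshold)
    return heatmaps * keep

peaks = heatmap_nms(torch.rand(1, 17, 120, 120))  # (batch, joints, H, W)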