Exemplo n.º 1
0
from transforms import *
import time
from datetime import datetime

# Explicit imports: these names were previously only available because
# `from transforms import *` happened to re-export them. Placed here (after the
# star import) so they cannot be shadowed by it.
import os

import torch

# --- Dataset loading --------------------------------------------------------
# Preprocessing/augmentation pipeline shared by the train and val sets:
# convert to 3 channels (presumably from a single-channel IR image, per the
# transform's name -- confirm), resize to 224, random mirror, random affine
# with parameter (-60, 60) (presumably a rotation range in degrees -- confirm),
# convert to tensor, then normalize with the standard ImageNet channel
# mean/std values.
# NOTE(review): the validation set below reuses this same pipeline, so it is
# also passed through the Random* steps; if validation metrics should be
# reproducible, consider a separate transform without them.
transform = ComposeKeyPoints([
    To3ChannelsIRKeyPoints(),
    ResizeKeypoints(224),
    RandomMirrorKeyPoints(),
    RandomAffineKeyPoints((-60, 60)),
    ToTensorKeyPoints(),
    NormalizeKeyPoints((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_set = GaneratedHandsDataset(
    "/disk1/ofirbartal/Projects/Dataset/GANeratedHands_Release/dataset_csv/train_dataset.csv",
    transform)
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=16,
                                           shuffle=True)

val_set = GaneratedHandsDataset(
    "/disk1/ofirbartal/Projects/Dataset/GANeratedHands_Release/dataset_csv/val_dataset.csv",
    transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=16, shuffle=True)

# One timestamped directory (dd_mm_yy__HH_MM_SS, local time) per run. The
# trailing "/" is kept because later code may build checkpoint file paths by
# plain string concatenation.
checkpoints_path = "/disk1/ofirbartal/Projects/LandmarksExtraction/Checkpoints/Autoencoder_checkpoints/{}/".format(
    datetime.now().strftime("%d_%m_%y__%H_%M_%S"))

# exist_ok=True replaces the racy exists()-then-makedirs() pair: creation is a
# single call and does not fail if the directory already exists.
os.makedirs(checkpoints_path, exist_ok=True)
Exemplo n.º 2
0
    return np.array(dist)


# --- Test-set loading -------------------------------------------------------
# Evaluation pipeline: convert to 3 channels (presumably from grayscale, per
# the transform's name), resize to 224, random mirror, random affine with
# parameter (-60, 60), convert to tensor, then normalize with the standard
# ImageNet channel statistics.
# NOTE(review): the Random* steps presumably make the test inputs differ from
# pass to pass; confirm this augmentation is intended at evaluation time.
transform = ComposeKeyPoints(
    [
        To3ChannelsGrayscaleKeyPoints(),
        ResizeKeypoints(224),
        RandomMirrorKeyPoints(),
        RandomAffineKeyPoints((-60, 60)),
        ToTensorKeyPoints(),
        NormalizeKeyPoints(
            (0.485, 0.456, 0.406),  # per-channel mean
            (0.229, 0.224, 0.225),  # per-channel std
        ),
    ]
)

_test_csv_path = (
    "/disk1/ofirbartal/Projects/Dataset/GANeratedHands_Release/dataset_csv_small/test_dataset.csv"
)
test_set = GaneratedHandsDataset(_test_csv_path, transform)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=16)

# Error thresholds for close / mid / far buckets (presumably in pixels, per the
# names). The original (misspelled) names are kept because code further down
# the file may refer to them.
(
    pixel_error_threshhold_close,
    pixel_error_threshhold_mid,
    pixel_error_threshhold_far,
) = (5, 10, 15)

# Build the 21-location regression network on the GPU and wrap it so it runs
# data-parallel on devices 0 and 1, with outputs gathered on device 0.
model = torch.nn.DataParallel(
    CoordRegressionNetwork(n_locations=21).cuda(),
    device_ids=[0, 1],
    output_device=0,
)
model.load_state_dict(
    torch.load(
        "/disk1/ofirbartal/Projects/LandmarksExtraction/Checkpoints/FCN_checkpoints/gray_26_11_19__14_36_37//epoch_80.pth"