Example #1
def main(args):
    print(torch.cuda.device_count(), 'GPUs available')
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ("image", )),
    ])

    print("Reading data...")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                             train_transforms,
                                             split="train")
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           train_transforms,
                                           split="val")
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'),
                                            train_transforms,
                                            split="test")

    torch.save(
        {
            'train_dataset': train_dataset,
            'val_dataset': val_dataset,
            'test_dataset': test_dataset,
        }, os.path.join(args.data, 'datasets.pth'))
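A later run can restore these cached datasets without re-reading the images, since torch.save pickles the dataset objects whole. A minimal sketch of the reload, assuming the same data directory ('data' here stands in for args.data):

import torch

# reload the cache written above; on newer PyTorch, pass weights_only=False
# because these are full pickled objects, not plain tensors
cached = torch.load('data/datasets.pth')
train_dataset = cached['train_dataset']
val_dataset = cached['val_dataset']
test_dataset = cached['test_dataset']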
Example #2
def main(args):
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image",)),
        TransformByKeys(transforms.ToTensor(), ("image",)),
        TransformByKeys(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ("image",)),
    ])

    print("Reading data...")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'), train_transforms, split="train")
    train_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                       shuffle=True, drop_last=True)
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'), train_transforms, split="val")
    val_dataloader = data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                     shuffle=False, drop_last=False)

    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)
    loss_fn = fnn.mse_loss

    # 2. train & validate
    print("Ready for training...")
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        train_loss = train(model, train_dataloader, loss_fn, optimizer, device=device)
        val_loss = validate(model, val_dataloader, loss_fn, device=device)
        print("Epoch #{:2}:\ttrain loss: {:5.2}\tval loss: {:5.2}".format(epoch, train_loss, val_loss))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(f"{args.name}_best.pth", "wb") as fp:
                torch.save(model.state_dict(), fp)

    # 3. predict
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'), train_transforms, split="test")
    test_dataloader = data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                      shuffle=False, drop_last=False)

    with open(f"{args.name}_best.pth", "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    test_predictions = predict(model, test_dataloader, device)
    with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
        pickle.dump({"image_names": test_dataset.image_names,
                     "landmarks": test_predictions}, fp)

    create_submission(args.data, test_predictions, f"{args.name}_submit.csv")
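The replaced fc layer emits a flat vector of 2 * NUM_PTS values per image, which predict() presumably pairs back up into (x, y) coordinates. A minimal sketch of that convention (the "image" batch key and the interleaved x, y ordering are assumptions, not read from predict itself):

# hedged sketch: turning the flat regression output into landmark pairs
batch = next(iter(test_dataloader))
with torch.no_grad():
    out = model(batch["image"].to(device))  # shape (B, 2 * NUM_PTS)
landmarks = out.view(out.shape[0], -1, 2)   # shape (B, NUM_PTS, 2)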
Example #3
def main(args):
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ("image", )),
    ])

    print("Reading data...")
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           train_transforms,
                                           split="data")
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    model = models.resnext50_32x4d(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)
    model.to(device)

    MODEL_FILENAME = "./rexnext300_best.pth"
    with open(MODEL_FILENAME, "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    loss_fn = fnn.mse_loss

    # 2. predict for train
    print("Ready for training...")
    print(len(val_dataloader.dataset))

    val_loss = validate(model, val_dataloader, loss_fn, device=device)
    print("validation loss: {:5.2}".format(val_loss))
Example #4
def main(args):
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        AffineAugmenter(min_scale=0.9, max_offset=0.1, rotate=True),
        BrightnessContrastAugmenter(brightness=0.3, contrast=0.3),
        BlurAugmenter(max_kernel=5),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ("image", )),
    ])
    test_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ("image", )),
    ])

    print("Reading data...")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                             train_transforms,
                                             split="train")
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       num_workers=4,
                                       pin_memory=True,
                                       shuffle=True,
                                       drop_last=True)
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           test_transforms,
                                           split="val")
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    model = models.resnext50_32x4d(pretrained=True)
    # for param in model.parameters():
    #     param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)

    # model.fc = nn.Sequential(
    #     # nn.BatchNorm1d(model.fc.in_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    #     # nn.Linear(model.fc.in_features, model.fc.in_features, bias=True),
    #     # nn.ReLU(),
    #     nn.BatchNorm1d(model.fc.in_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
    #     nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True))

    model.to(device)

    # optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=0.01, amsgrad=True)
    optimizer = RAdam(model.parameters(),
                      lr=args.learning_rate)  # , weight_decay=0.01)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.2,
                                                     patience=3)
    # optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)

    loss_fn = fnn.mse_loss

    # 2. train & validate
    print("Ready for training...")
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        train_loss = train(model,
                           train_dataloader,
                           loss_fn,
                           optimizer,
                           device=device)
        val_loss = validate(model, val_dataloader, loss_fn, device=device)
        scheduler.step(val_loss)  # reduce lr when the val loss plateaus
        print("Epoch #{:2}:\ttrain loss: {:5.2}\tval loss: {:5.2}".format(
            epoch, train_loss, val_loss))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(f"{args.name}_best.pth", "wb") as fp:
                torch.save(model.state_dict(), fp)
        # with open(f"{args.name}_{epoch}_{train_loss:7.4}_{val_loss:7.4}.pth", "wb") as fp:
        #     torch.save(model.state_dict(), fp)

    # 3. predict
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'),
                                            test_transforms,
                                            split="test")
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=args.batch_size,
                                      num_workers=4,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

    with open(f"{args.name}_best.pth", "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    test_predictions = predict(model, test_dataloader, device)
    with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
        pickle.dump(
            {
                "image_names": test_dataset.image_names,
                "landmarks": test_predictions
            }, fp)

    create_submission(args.data, test_predictions, f"{args.name}_submit.csv")

    if args.draw:
        print("Drawing landmarks...")
        directory = os.path.join("result",
                                 test_dataset.image_names[0].split('.')[0])
        if not os.path.exists(directory):
            os.makedirs(directory)
        random_idxs = np.random.choice(len(test_dataset.image_names),
                                       size=1000,
                                       replace=False)
        for i, idx in enumerate(random_idxs, 1):
            image = cv2.imread(test_dataset.image_names[idx])
            image = draw_landmarks(image, test_predictions[idx])
            cv2.imwrite(os.path.join("result", test_dataset.image_names[idx]),
                        image)
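AffineAugmenter, BrightnessContrastAugmenter and BlurAugmenter are project-specific transforms, not torchvision ones. A rough sketch of what the brightness/contrast augmenter plausibly looks like, assuming (as the neighbouring transforms imply) it is called on a sample dict whose "image" entry is an HxWxC uint8 array:

import numpy as np

# hedged sketch of a brightness/contrast augmenter; the interface is assumed
class BrightnessContrastAugmenter:
    def __init__(self, brightness=0.3, contrast=0.3):
        self.brightness, self.contrast = brightness, contrast

    def __call__(self, sample):
        alpha = 1.0 + np.random.uniform(-self.contrast, self.contrast)
        beta = 255.0 * np.random.uniform(-self.brightness, self.brightness)
        image = sample["image"].astype(np.float32) * alpha + beta
        sample["image"] = np.clip(image, 0, 255).astype(np.uint8)
        return sample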
Example #5
def main(args):
    # 1. prepare data & models
    # applying additional transforms did not improve the results;
    # the only change is the normalization parameters
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(
                # mean and std taken from the PyTorch documentation:
                # the same values the network was pretrained with on ImageNet
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ("image", ),
        ),
    ])
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    print("Creating model...")
    model = models.resnext50_32x4d(pretrained=True)
    in_features = model.fc.in_features
    fc = nn.Sequential(nn.Linear(in_features, 2 * NUM_PTS))  # new regression "head"
    model.fc = fc
    state_dict = None
    #  if a network fine-tuned on the contest dataset is available
    if args.pretrained_model:
        print(f"Load best_state_dict {args.pretrained_model}")
        state_dict = torch.load(args.pretrained_model, map_location="cpu")
        model.load_state_dict(state_dict)
        del state_dict

    model.to(device)
    print(model)

    factor = 0.1 ** (1 / 2)  # lr reduction factor: sqrt(0.1) ≈ 0.316, so two plateau steps cut the lr tenfold
    # AdamW is chosen as the optimizer, with mild weight decay
    optimizer = optim.AdamW(model.parameters(),
                            lr=args.learning_rate,
                            amsgrad=True,
                            weight_decay=0.05)
    loss_fn = fnn.mse_loss
    # the lr is reduced on plateau via ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=1,
        factor=factor,
    )

    print(loss_fn)
    print(optimizer)
    print(scheduler)

    print("Reading data...")
    print("Read train landmark dataset")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                             train_transforms,
                                             split="train")
    print("Create picture loader for test dataset")
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       num_workers=0,
                                       pin_memory=True,
                                       shuffle=True,
                                       drop_last=True)
    print("Read val landmark dataset")
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           train_transforms,
                                           split="val")
    print("Create picture loader for val dataset")
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=0,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    # 2. train & validate
    print("Ready for training...")
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        train_loss = train(model,
                           train_dataloader,
                           loss_fn,
                           optimizer,
                           device=device)
        val_loss = validate(model, val_dataloader, loss_fn, device=device)
        print("Epoch #{:2}:\ttrain loss: {:5.5}\tval loss: {:5.5}".format(
            epoch + 1, train_loss, val_loss))
        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(f"{args.name}_best.pth", "wb") as fp:
                torch.save(model.state_dict(), fp)

    # 3. predict
    print("Start predict")
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'),
                                            train_transforms,
                                            split="test")
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=args.batch_size,
                                      num_workers=0,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

    with open(f"{args.name}_best.pth", "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    test_predictions = predict(model, test_dataloader, device)
    with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
        pickle.dump(
            {
                "image_names": test_dataset.image_names,
                "landmarks": test_predictions
            }, fp)

    create_submission(args.data, test_predictions, f"{args.name}_submit.csv")
Example #6
def main(args):

    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.39963884, 0.31994772, 0.28253724],
                                 std=[0.33419772, 0.2864468, 0.26987]),
            ("image", )),
    ])

    print("Reading data...")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                             train_transforms,
                                             split="train")
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       num_workers=4,
                                       pin_memory=True,
                                       shuffle=True,
                                       drop_last=True)
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           train_transforms,
                                           split="val")
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")

    #    model = models.wide_resnet101_2(pretrained=True)
    #     fc_layers = nn.Sequential(
    #                 nn.Linear(model.fc.in_features, model.fc.in_features),
    #                 nn.ReLU(inplace=True),
    #                 nn.Dropout(p=0.1),
    #                 nn.Linear(model.fc.in_features,  2 * NUM_PTS),
    #                 nn.ReLU(inplace=True),
    #                 nn.Dropout(p=0.1))
    #     model.fc = fc_layers

    model = models.resnext101_32x8d(pretrained=True)

    #   Uncomment to train with a frozen feature extractor
    #     for param in model.parameters():
    #         param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)

    if args.checkpoint is not None:
        model.load_state_dict(torch.load(args.checkpoint))
        print('Pretrained checkpoint loaded')

    model.to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           amsgrad=True)
    loss_fn = fnn.mse_loss
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # 2. train & validate
    writer = SummaryWriter()

    print("Ready for training...")
    best_val_loss = np.inf
    for epoch in range(args.epochs):

        train_loss = train(model,
                           train_dataloader,
                           loss_fn,
                           optimizer,
                           device=device,
                           epoch=epoch,
                           writer=writer)

        val_loss = validate(model,
                            val_dataloader,
                            loss_fn,
                            device=device,
                            epoch=epoch,
                            writer=writer)
        # if epoch > 0:
        scheduler.step()
        print(f"EPOCH {epoch} \n")

        print("Epoch #{:2}:\ttrain loss: {:5.2}\tval loss: {:5.2}".format(
            epoch, train_loss, val_loss))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(f"{args.name}_best.pth", "wb") as fp:
                torch.save(model.state_dict(), fp)

    # 3. predict
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'),
                                            train_transforms,
                                            split="test")
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=args.batch_size,
                                      num_workers=4,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

    with open(f"{args.name}_best.pth", "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    test_predictions = predict(model, test_dataloader, device)
    with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
        pickle.dump(
            {
                "image_names": test_dataset.image_names,
                "landmarks": test_predictions
            }, fp)

    create_submission(args.data, test_predictions, f"{args.name}_submit.csv")
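The mean/std passed to Normalize in this example are dataset-specific statistics rather than the ImageNet defaults used elsewhere. A sketch of how such per-channel values can be estimated, assuming batch["image"] is a float tensor in [0, 1] (run it with Normalize removed from the transform pipeline):

import numpy as np

# hedged sketch: per-channel mean/std over the training set
sums, sqs, n = np.zeros(3), np.zeros(3), 0
for batch in train_dataloader:
    imgs = batch["image"].numpy()  # (B, 3, H, W)
    sums += imgs.sum(axis=(0, 2, 3))
    sqs += (imgs ** 2).sum(axis=(0, 2, 3))
    n += imgs.shape[0] * imgs.shape[2] * imgs.shape[3]
mean = sums / n
std = np.sqrt(sqs / n - mean ** 2)
print(mean, std)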
Example #7
def main(args):
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]), ("image", )),
    ])

    test_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image", )),
        TransformByKeys(transforms.ToTensor(), ("image", )),
        TransformByKeys(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]), ("image", )),
    ])

    print(datetime.datetime.now())
    print("Reading data...")
    train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                             train_transforms,
                                             split="train")
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       num_workers=4,
                                       pin_memory=True,
                                       shuffle=True,
                                       drop_last=True)
    val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'),
                                           test_transforms,
                                           split="val")
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=4,
                                     pin_memory=True,
                                     shuffle=False,
                                     drop_last=False)

    print(datetime.datetime.now())
    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    model = torch.hub.load(
        'facebookresearch/WSL-Images', 'resnext101_32x8d_wsl'
    )  # models.resnext50_32x4d(pretrained=True) # resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)

    if os.path.isfile(f"{args.name}_best.pth"):
        print("Loading saved model " + f"{args.name}_best.pth")
        with open(f"{args.name}_best.pth", "rb") as fp:
            best_state_dict = torch.load(fp, map_location="cpu")
            model.load_state_dict(best_state_dict)

    model.to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           amsgrad=True)
    loss_fn = fnn.l1_loss  # alternatives tried: WingLoss(), fnn.mse_loss
    loss_val = fnn.mse_loss

    # 2. train & validate
    print("Ready for training...")
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        train_loss = train(model,
                           train_dataloader,
                           loss_fn,
                           optimizer,
                           device=device)
        val_loss = validate(model, val_dataloader, loss_fn, device=device)
        val_loss_mse = validate(model, val_dataloader, loss_val, device=device)
        print(
            "Epoch #{:2}:\ttrain loss: {:5.4}\tval loss: {:5.4}\tval mse: {:5.4}"
            .format(epoch, train_loss, val_loss, val_loss_mse))

        if True:  # instead of val_loss < best_val_loss: save every epoch so several submissions can be checked on Kaggle
            best_val_loss = val_loss
            with open(f"{args.name}_best.pth", "wb") as fp:
                torch.save(model.state_dict(), fp)

            with open(f"{args.name}_" + str(epoch) + ".pth", "wb") as fp:
                torch.save(model.state_dict(), fp)

            # 3. predict
            print('Predict')
            test_dataset = ThousandLandmarksDataset(os.path.join(
                args.data, 'test'),
                                                    test_transforms,
                                                    split="test")
            test_dataloader = data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=4,
                                              pin_memory=True,
                                              shuffle=False,
                                              drop_last=False)

            #with open(f"{args.name}_best.pth", "rb") as fp:
            #    best_state_dict = torch.load(fp, map_location="cpu")
            #    model.load_state_dict(best_state_dict)

            test_predictions = predict(model, test_dataloader, device)
            with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
                pickle.dump(
                    {
                        "image_names": test_dataset.image_names,
                        "landmarks": test_predictions
                    }, fp)

            create_submission(args.data, test_predictions,
                              f"{args.name}_submit_" + str(epoch) + ".csv")
Example #8
def main(args):
    # 1. prepare data & models
    train_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        Cutout(10),
        RandomBlur(),
        TransformByKeys(transforms.ToPILImage(), ("image",)),
        TransformByKeys(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02), ("image",)),
        TransformByKeys(transforms.ToTensor(), ("image",)),
        TransformByKeys(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ("image",)),
    ])

    val_transforms = transforms.Compose([
        ScaleMinSideToSize((CROP_SIZE, CROP_SIZE)),
        CropCenter(CROP_SIZE),
        TransformByKeys(transforms.ToPILImage(), ("image",)),
        TransformByKeys(transforms.ToTensor(), ("image",)),
        TransformByKeys(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ("image",)),
    ])

    print("Creating model...")
    device = torch.device("cuda: 0") if args.gpu else torch.device("cpu")
    model = models.resnet50(pretrained=True)

    if args.freeze > 0:
        # freeze the conv stem (conv1, bn1, relu, maxpool) plus the first
        # args.freeze residual stages of resnet50
        ct = 0
        for child in model.children():
            ct += 1
            if ct <= args.freeze + 4:
                for param in child.parameters():
                    param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, 2 * NUM_PTS, bias=True)

    startEpoch = args.cont
    if startEpoch > 0:
        with open(f"{args.name}_best_{startEpoch}.pth", "rb") as fp:
            best_state_dict = torch.load(fp, map_location="cpu")
            model.load_state_dict(best_state_dict)

    model.to(device)

    if args.test:
        val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'), val_transforms, split="train")
        val_dataloader = data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                         shuffle=False, drop_last=False)
        val_loss_fn = fnn.mse_loss

        val_full = validate_full(model, val_dataloader, val_loss_fn, device=device)

        res = dict(sorted(val_full.items(), key=lambda x: x[1], reverse=True)[:100])  # keep the 100 worst samples by loss
        js = json.dumps(res)
        with open(f"{args.name}.json", "w") as f:
            f.write(js)
        print(res)
        return

    if not args.predict:
        print("Reading data...")
        train_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'), train_transforms, split="train")
        train_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                           shuffle=True, drop_last=True)
        val_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'train'), val_transforms, split="val")
        val_dataloader = data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                         shuffle=False, drop_last=False)

        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001, nesterov=True)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, verbose=True)
        train_loss_fn = nn.SmoothL1Loss(reduction="mean")
        val_loss_fn = fnn.mse_loss

        # 2. train & validate
        print("Ready for training...")
        best_val_loss = np.inf
        for epoch in range(startEpoch, args.epochs):
            train_loss = train(model, train_dataloader, train_loss_fn, optimizer, device=device)
            val_loss = validate(model, val_dataloader, val_loss_fn, device=device)
            scheduler.step(val_loss)
            print("Epoch #{:2}:\ttrain loss: {:.5f}\tval loss: {:.5f}".format(epoch, train_loss, val_loss))
            with open(f"{args.name}_res.txt", 'a+') as file:
                file.write("Epoch #{:2}:\ttrain loss: {:.5f}\tval loss: {:.5f}\n".format(epoch, train_loss, val_loss))

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                with open(f"{args.name}_best.pth", "wb") as fp:
                    torch.save(model.state_dict(), fp)

            # periodic checkpoint every 5 epochs (best_val_loss should not be
            # reset here; that was overwriting the best-loss tracker)
            if epoch > startEpoch and epoch % 5 == 0:
                with open(f"{args.name}_best_{epoch}.pth", "wb") as fp:
                    torch.save(model.state_dict(), fp)

    # 3. predict
    test_dataset = ThousandLandmarksDataset(os.path.join(args.data, 'test'), val_transforms, split="test")
    test_dataloader = data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
                                      shuffle=False, drop_last=False)

    with open(f"{args.name}_best.pth", "rb") as fp:
        best_state_dict = torch.load(fp, map_location="cpu")
        model.load_state_dict(best_state_dict)

    model.eval()  # switches all submodules to eval mode before predicting

    test_predictions = predict(model, test_dataloader, device)
    with open(f"{args.name}_test_predictions.pkl", "wb") as fp:
        pickle.dump({"image_names": test_dataset.image_names,
                     "landmarks": test_predictions}, fp)

    create_submission(args.data, test_predictions, f"{args.name}_submit.csv")
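Each run also leaves a {args.name}_test_predictions.pkl next to the submission file. A quick sketch for sanity-checking it (the file name below is a hypothetical instance of that pattern):

import pickle

# hedged sketch: inspect the saved predictions pickle
with open("resnet50_test_predictions.pkl", "rb") as fp:  # hypothetical name
    saved = pickle.load(fp)
print(len(saved["image_names"]), "images")
print(saved["landmarks"].shape)  # layout depends on predict(); e.g. (N, NUM_PTS, 2)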