Ejemplo n.º 1
0
    parser.add_argument("-ver", metavar="V", type=int, default=1, dest="ver")
    #{1: default 5 encoder unet.py, 4: 4 encoders unet_4.py}
    args = parser.parse_args()

    test_imgs, test_masks = np.load("test_imgs_1.npy"), np.load(
        "test_masks_1.npy")
    testset = Covid(imgs=test_imgs, masks=test_masks)

    testloader = data.DataLoader(testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=12)

    a = "cuda:" + str(args.cuda)
    device = torch.device(a if torch.cuda.is_available() else "cpu")
    net = unet.run_cnn() if args.ver == 1 else unet_6.run_cnn()
    checkpoint = torch.load("models_6/" + args.pre + "/best.pt")
    net.load_state_dict(checkpoint["net"])
    net.to(device)
    net = net.eval()
    tot_val = 0.0
    tot_rand = 0.0
    countt = 0
    try:
        os.mkdir("mask_pred/" + args.pre)
    except:
        pass
    with torch.no_grad():
        for img, mask in testloader:
            mask_type = torch.float32
            img, mask = (img.to(device), mask.to(device, dtype=mask_type))
Ejemplo n.º 2
0
    return 1 - c


def resolve_checkpoint_path(aug, ver, pretrain):
    """Return the path of the ``best.pt`` checkpoint for a pretrained run.

    Args:
        aug: Values > 0 select the ``*_aug`` checkpoint directories
            (presumably runs trained with augmentation); any other value
            selects the plain directories.
        ver: ``1`` selects the directories used for ``unet`` (``models``);
            any other value selects those used for ``unet_6`` (``models_6``).
        pretrain: Name of the run sub-directory that holds ``best.pt``.

    Returns:
        One of ``models/``, ``models_6/``, ``models_aug/`` or
        ``models_6_aug/`` followed by ``<pretrain>/best.pt``.
    """
    root = "models" if ver == 1 else "models_6"
    if aug > 0:
        root += "_aug"
    return root + "/" + pretrain + "/best.pt"


if __name__ == "__main__":
    test_imgs, test_masks = np.load("test_imgs_1.npy"), np.load(
        "test_masks_1.npy")
    testset = Covid(imgs=test_imgs, masks=test_masks)
    testloader = data.DataLoader(testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=12)

    # Preferred GPU; fall back to the CPU when CUDA is not available.
    a = "cuda:6"
    device = torch.device(a if torch.cuda.is_available() else "cpu")

    aug = int(input("1 or 0? "))
    ver = int(input("version 1/6: "))
    # Build only the selected architecture: ver == 1 -> unet, else unet_6.
    net = unet.run_cnn() if ver == 1 else unet_6.run_cnn()
    pretrain = input("File path of pretrained model: ")
    # Map the stored tensors onto the device chosen above instead of a
    # hard-coded "cuda:6", so the CPU fallback also works when loading.
    checkpoint = torch.load(resolve_checkpoint_path(aug, ver, pretrain),
                            map_location=device)
    net.load_state_dict(checkpoint["net"])