Example #1
def main():
    model = UNET(in_channels=1, out_channels=1).to(device=DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    train_loader, val_loader = get_loaders(TRAIN_DIR, VAL_DIR, BATCH_SIZE,
                                           NUM_WORKER, PIN_MEMORY)

    if LOAD_MODEL:
        load_checkpoint(torch.load("mycheckpoint.pth.tar"), model)

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # TODO: save model
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_checkpoint(checkpoint)

        # TODO: check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # TODO: print results to a folder
        save_predictions_as_imgs(val_loader, model, folder='saved_imgs/')
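
Every example below calls a train_fn helper that the listing never shows. A minimal sketch of what such a helper typically looks like with mixed-precision training, assuming the loader yields (image, mask) pairs and that DEVICE is defined as in the snippet above:

def train_fn(loader, model, optimizer, loss_fn, scaler):
    # One epoch of training with automatic mixed precision (AMP).
    model.train()
    for data, targets in loader:
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)  # (N, 1, H, W)

        # Forward pass under autocast so ops run in float16 where safe.
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward pass with gradient scaling to avoid float16 underflow.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()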
Example #2
def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    val_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(TRAIN_IMG_DIR, TRAIN_MASK_DIR,
                                           VAL_IMG_DIR, VAL_MASK_DIR,
                                           BATCH_SIZE, train_transform,
                                           val_transforms, NUM_WORKERS,
                                           PIN_MEMORY)

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"))

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some example to a folder
        save_prediction_as_imgs(val_loader,
                                model,
                                folder="saved_images/",
                                device=DEVICE)
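
The checkpoint and accuracy helpers are likewise assumed rather than shown. A plausible minimal version, matching the checkpoint dict format used above; the Dice-score detail is an assumption, not necessarily the original helper:

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])


def check_accuracy(loader, model, device="cuda"):
    # Pixel accuracy and Dice score for binary segmentation masks.
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
    print(f"Accuracy: {num_correct}/{num_pixels} ({num_correct / num_pixels * 100:.2f}%)")
    print(f"Dice score: {dice_score / len(loader):.4f}")
    model.train()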
Example #3
def predict_single():
    device = "cpu"

    image = np.array(Image.open('data/val_images/0.png').convert("RGB"))
    mask = np.array(Image.open('data/val_masks/0.bmp').convert("L"), dtype=np.float32)
    mask[mask == 255.0] = 1
    augmentations = val_transforms(image=image, mask=mask)
    image = augmentations["image"]
    mask = augmentations["mask"]

    plt.imshow(image.squeeze().permute(1, 2, 0))
    plt.show()
    plt.imshow(mask, cmap='gray')
    plt.show()

    # ToTensorV2 already returned a tensor; gradients aren't needed for inference
    image = image.to(device)
    image = image.unsqueeze(0)

    model = UNET(in_channels=3, out_channels=1).to(device)
    load_checkpoint(torch.load("check_Unet_99_95.pth.tar", map_location=torch.device('cpu')), model)
    # image = image.to(device=device)
    model.eval()
    with torch.no_grad():
        preds = torch.sigmoid(model(image))
        preds = (preds > 0.5).float()
    torchvision.utils.save_image(preds, "./pred_100.png")
    model.train()
Example #4
def main():
    model = UNET(in_channels=3, out_channels=1).to(config.DEVICE)
    BCE = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        train_dir=config.TRAIN_IMG_DIR,
        train_mask_dir=config.TRAIN_MASK_DIR,
        val_dir=config.VAL_IMG_DIR,
        val_mask_dir=config.VAL_MASK_DIR,
        batch_size=config.BATCH_SIZE,
        train_transform=config.train_transform,
        val_transform=config.val_transform,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    if config.LOAD_MODEL:
        load_checkpoint(torch.load(config.CHECKPOINT_PTH), model)

    check_accuracy(val_loader, model)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, BCE, scaler, val_loader)

        # save model
        if config.SAVE_MODEL:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint) 

        # check accuracy
        check_accuracy(val_loader, model)

        # print some example
        save_predictions_as_imgs(val_loader, model, folder=config.SAVE_IMAGES)
Example #5
def main():
    train_transform = tr.Compose([
        tr.Resize((160, 240)),
        tr.ToTensor(),
        tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    model = UNET(in_channels=3, out_channels=3).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_BLUR_DIR,
        BATCH_SIZE,
        train_transform=train_transform,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)

    check_accuracy(val_loader, model, device=DEVICE)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_model(train_loader, model, optimizer, loss_fn, scaler)

        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        check_accuracy(val_loader, model, device=DEVICE)
Example #6
def predict():
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=torch.device('cpu')), model)

    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )
Example #7
def generate_masks():
    # https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

    model = UNET(in_channels=3, out_channels=1)

    model.load_state_dict(torch.load(r"model" + r"\blueno_detection.pth"))

    is_cuda = torch.cuda.is_available()

    if is_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)

    test_x = []

    image_names = []

    # get the image names to label the prediction masks appropriately
    for path_img in os.listdir(my_path + "/test_model/test_images"):
        image_names.append(path_img)
        full_path_img = os.path.join(my_path + "/test_model/test_images",
                                     path_img)

        image = np.array(Image.open(full_path_img).convert("RGB"))

        augmentations = transform(image=image)

        test_x.append(augmentations["image"])

    test_x = torch.stack(test_x)

    # the second test_x is a dummy target: TensorDataset pairs inputs with labels,
    # and the labels are simply unused at inference time
    test_data = TensorDataset(test_x, test_x)

    test_loader = DataLoader(
        test_data,
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
    )

    Path(my_path + "/test_model/generated_masks/").mkdir(parents=True,
                                                         exist_ok=True)

    for file in os.listdir(my_path + "/test_model/generated_masks/"):
        os.remove(my_path + "/test_model/generated_masks/" + file)

    save_prediction_masks(test_loader, model, image_names,
                          "test_model/generated_masks/", device)
Example #8
def main():
    # TODO: Might be worth trying the normalization from assignment 2
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    val_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    """
    We're using with logitsLoss because we're not using sigmoid on the,
    final output layer.
    If we wanted to have several output channels, we'd change the loss_fn
    to a cross entropy loss instead.
    """
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)

    # Scales gradients to avoid float16 underflow; requires a CUDA GPU.
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(val_loader,
                                 model,
                                 folder="saved_images/",
                                 device=DEVICE)
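
To make the docstring's point concrete, the two setups differ only in the output channel count and the loss; the five-class count below is an arbitrary illustration:

# Binary segmentation: one output channel, raw logits + BCEWithLogitsLoss
model = UNET(in_channels=3, out_channels=1)
loss_fn = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

# Multi-class segmentation: one channel per class, e.g. 5 classes
model = UNET(in_channels=3, out_channels=5)
loss_fn = nn.CrossEntropyLoss()  # applies the softmax internally
# Targets then become (N, H, W) long tensors of class indices,
# rather than (N, 1, H, W) float masks.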
Example #9
from dataset import CaravanImageDataLoader, CaravanImageDataset


def train_model():
    #dataset = CaravanImageDataset("./caravan_images")
    #model = UNET()
    pass


def image_segmentation_accuracy(target, prediction):
    pass


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    trn_tf, val_tf = CaravanImageDataLoader.get_default_transforms(
        height=1280 // 4, width=1920 // 4)
    data = CaravanImageDataLoader("../caravan_images", 2, trn_tf, val_tf)
    model = UNET().to(device=dev)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    mm = ModelManager(model, data, loss_fn, optimizer, 'WD')
    mm.train_epoch()
    # Size: batch_size, nFeatures, Height, Width
    #tt = torch.rand((2, 1, 572, 572))
    #model.forward(tt)

Example #10
def train():
    # network = EncoderDecoder()
    network = UNET()
    network = nn.DataParallel(network)
    try:
        if pretrained_model_file_path is not None and os.path.isfile(pretrained_model_file_path):
            network.load_state_dict(torch.load(pretrained_model_file_path))
            print('Network weights initialized from file at:', os.path.abspath(pretrained_model_file_path))
    except Exception:
        print('Unable to initialize network weights from file at:', os.path.abspath(pretrained_model_file_path))
    network.to(MODEL['DEVICE'])
    network.train()

    train_dataset = NoiseDataloader(dataset_type=NoiseDataloader.TRAIN,
                                    noisy_per_image=DATASET['NOISY_PER_IMAGE'],
                                    noise_type=DATASET['NOISE_TYPE'])

    train_batcher = DataLoader(dataset=train_dataset,
                               batch_size=MODEL['BATCH_SIZE'],
                               shuffle=True,
                               num_workers=MODEL['NUM_WORKERS'])

    optimizer = optim.Adam(network.parameters(),
                           lr=OPTIMIZER['LR'],
                           betas=OPTIMIZER['BETAS'],
                           eps=OPTIMIZER['EPSILON'])

    instance = 0
    while os.path.isdir(os.path.join(pp.trained_models_folder_path, 'Instance_' + str(instance).zfill(3))):
        instance += 1
    os.mkdir(os.path.join(pp.trained_models_folder_path, 'Instance_' + str(instance).zfill(3)))

    num_batches = math.floor(len(train_dataset) / MODEL['BATCH_SIZE'])
    for epoch in range(MODEL['NUM_EPOCHS']):

        epoch_start_time = time.time()
        print('-' * 80)
        print('Epoch: {} of {}...'.format(epoch + 1, MODEL['NUM_EPOCHS']))

        epoch_loss = 0
        batch_counter = 1

        for batch in train_batcher:  # Get Batch
            print('\tProcessing Batch: {} of {}...'.format(batch_counter, num_batches))
            batch_counter += 1

            input_noisy_patch, output_noisy_patch = batch
            input_noisy_patch = input_noisy_patch.to(MODEL['DEVICE'])
            output_noisy_patch = output_noisy_patch.to(MODEL['DEVICE'])

            denoised_input_patch = network(input_noisy_patch)  # Pass Batch

            loss = OPTIMIZER['LOSS_FUNCTION'](denoised_input_patch, output_noisy_patch)  # Calculate Loss

            epoch_loss += loss.item()  # .item() detaches the value; summing tensors would keep autograd graphs alive
            optimizer.zero_grad()
            loss.backward()  # Calculate Gradients
            optimizer.step()  # Update Weights
            print('\tBatch (Train) Loss:', loss.item())
            print()

        epoch_end_time = time.time()
        torch.save(network.state_dict(),
                   os.path.join(pp.trained_models_folder_path, 'Instance_' + str(instance).zfill(3), 'Model_Epoch_{}.pt'.format(str(epoch).zfill(3))))

        print('Epoch (Train) Loss:', epoch_loss)
        print('Epoch (Train) Time:', epoch_end_time - epoch_start_time)
        print('-' * 80)
Example #11
def denoise_using_noise2noise(noisy_image):
    model_input = torch.unsqueeze(
        torch.as_tensor(NoiseDataloader.convert_image_to_model_input(noisy_image)),
        dim=0)
    denoised_image = network(model_input)[0]
    denoised_image = NoiseDataloader.convert_model_output_to_image(
        denoised_image)

    return denoised_image


pretrained_model_file_path = os.path.join(pp.trained_models_folder_path,
                                          'Instance_000',
                                          'Model_Epoch_012.pt')
network = UNET()
network = nn.DataParallel(network)
network.load_state_dict(torch.load(pretrained_model_file_path))

# Custom Image
noise_type = 'Gaussian'
for image_file_name in os.listdir(os.path.join(pp.real_world_data,
                                               noise_type)):
    image_file_path = os.path.join(pp.real_world_data, noise_type,
                                   image_file_name)
    is_image = os.path.splitext(image_file_name)[1].lower() in IMAGE_EXTENSIONS
    if os.path.isfile(image_file_path) and is_image:
        bgr = cv2.imread(image_file_path, cv2.IMREAD_COLOR)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        noisy_image = np.asarray(rgb / 255, dtype=np.float32)
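        # The snippet ends here; presumably the loop continues by denoising
        # and saving each image. A hedged sketch, assuming
        # convert_model_output_to_image returns float RGB in [0, 1]
        # (the output naming is illustrative, not from the original):
        denoised_image = denoise_using_noise2noise(noisy_image)
        cv2.imwrite(os.path.join(pp.real_world_data, noise_type,
                                 'denoised_' + image_file_name),
                    cv2.cvtColor((denoised_image * 255).astype(np.uint8),
                                 cv2.COLOR_RGB2BGR))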
Example #12
import os
import numpy as np
import torch
import torchvision
from albumentations.pytorch import ToTensorV2
import albumentations as A
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from model import UNET

model = UNET(in_channels=3, out_channels=1)

val_transforms = A.Compose([
    A.Resize(height=480, width=720),
    A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    ),
    ToTensorV2(),
])


class MRZ_Dataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
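
    # The original snippet is cut off after __init__. Below is a minimal
    # completion sketch: it assumes masks share file names with their images
    # and mark the positive class as 255 (both are assumptions).
    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # binarize, as in the prediction snippets
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask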
Example #13
device = "cpu"

image = np.array(Image.open('19.png').convert("RGB"))
mask = np.array(Image.open('19.bmp').convert("L"), dtype=np.float32)
mask[mask == 255.0] = 1
augmentations = val_transforms(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]

plt.imshow(image.squeeze().permute(1, 2, 0))
plt.show()
plt.imshow(mask, cmap='gray')
plt.show()

# ToTensorV2 already returned a tensor; gradients aren't needed for inference
image = image.to(device)
image = image.unsqueeze(0)

model = UNET(in_channels=3, out_channels=1).to(device)
print("=> Loading checkpoint")

model.load_state_dict(
    torch.load("check_Unet_99_95.pth.tar",
               map_location=torch.device('cpu'))["state_dict"])
# image = image.to(device=device)
model.eval()
with torch.no_grad():
    preds = torch.sigmoid(model(image))
    preds = (preds > 0.5).float()
torchvision.utils.save_image(preds, "./pred_100.png")
model.train()
Example #14
        val_dir = VAL_IMG_DIR,
        val_maskdir = VAL_MASK_DIR,
        batch_size = batch,
        num_workers = workers,
        pin_memory = pin_memory,
    )
    
    
    """
    dataset = Microscopy_dataset(image_dir, mask_dir)

    train_loader = DataLoader(trainset, batch_size=batch, num_workers=1, shuffle=True, drop_last=True)
    eval_loader = DataLoader(evalset, batch_size=batch, num_workers=1, shuffle=True, drop_last=True)
    """
    
    model = UNET(n_channels=1, n_classes=1).to(device=device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # NOTE: with n_classes=1, BCEWithLogitsLoss is the usual choice;
    # CrossEntropyLoss expects one output channel per class.
    criterion = nn.CrossEntropyLoss()
    
    #criterion = nn.MSELoss()
    #criterion = nn.BCELoss()
    
    training_loss = []
    eval_loss = []

    for epoch in range(epoch):  # NOTE: shadows the 'epoch' count; range() is evaluated once, so the loop still runs the intended number of times
        train()
        evaluate()
        plot_losses()
        if (epoch % 10) == 0:
            torch.save(model.state_dict(), f"unet_{attempt}_{epoch}.pt")
Example #15
def main():
    model = UNET(in_channels=3, out_channels=1).to(device)
    loss_fun = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_loader, test_loader = get_loaders(batch_size)

    print("Training Model")
    print("==============")

    epoch_count = 1
    for epoch in range(epochs):
        print("Epoch ", epoch_count)
        print("---------")
        train_loading_bar = tqdm(train_loader, position=0, leave=True)
        model.train()

        train_correct_pixels = 0
        train_total_pixels = 0

        count = 0
        # iterate over the train data loader
        for _, (pixel_data, target_masks) in enumerate(train_loading_bar):
            count += 1
            pixel_data = pixel_data.to(device=device)
            target_masks_unsqueezed = target_masks.float().unsqueeze(1).to(
                device=device)

            model.zero_grad()

            predictions = model(pixel_data)

            loss = loss_fun(predictions, target_masks_unsqueezed)
            loss.backward()

            # get and accumulate the train accuracy
            (correct_pixels,
             total_pixels) = get_accuracy(predictions, target_masks, device)

            train_correct_pixels = train_correct_pixels + correct_pixels
            train_total_pixels = train_total_pixels + total_pixels

            optimizer.step()

            train_loading_bar.set_postfix(loss=loss.item())

        print(
            f"\nTrain Accuracy: {train_correct_pixels/train_total_pixels*100:.2f}%"
        )

        model.eval()

        epoch_count += 1

    # save model upon training
    print("Training Complete!")

    Path(my_path + "/model").mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), r"model" + r"\blueno_detection.pth")

    test_loading_bar = tqdm(test_loader)

    test_correct_pixels = 0
    test_total_pixels = 0

    print("Testing Model")
    print("=============")
    count = 0
    # iterate over the test data loader
    for _, (pixel_data, target_masks) in enumerate(test_loading_bar):
        count += 1
        pixel_data = pixel_data.to(device=device)
        target_masks_unsqueezed = target_masks.float().unsqueeze(1).to(
            device=device)

        with torch.no_grad():
            predictions = model(pixel_data)
        loss = loss_fun(predictions, target_masks_unsqueezed)  # recompute, otherwise the bar below shows the stale training loss

        # get and accumulate the test accuracy
        (correct_pixels, total_pixels) = get_accuracy(predictions,
                                                      target_masks, device)

        test_correct_pixels = test_correct_pixels + correct_pixels
        test_total_pixels = test_total_pixels + total_pixels

        test_loading_bar.set_postfix(loss=loss.item())

    print(f"\nTest Accuracy: {test_correct_pixels/test_total_pixels*100:.2f}%")
Example #16
def main():
    train_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
    ])

    val_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
    ])

    trainset = CarvanaDataset(img_path='train',
                              mask_path='train_masks',
                              dataframe=train_df,
                              transform=train_transforms)
    valset = CarvanaDataset(img_path='train',
                            mask_path='train_masks',
                            dataframe=val_df,
                            transform=val_transforms)

    trainloader = DataLoader(trainset,
                             batch_size=BATCH_SIZE,
                             collate_fn=trainset.collate_fn,
                             num_workers=NUM_WORKERS,
                             pin_memory=PIN_MEMORY,
                             shuffle=True)
    valloader = DataLoader(valset,
                           batch_size=BATCH_SIZE,
                           collate_fn=valset.collate_fn,
                           num_workers=NUM_WORKERS,
                           pin_memory=PIN_MEMORY)

    print(f"Number of training samples: {len(trainset)}")
    print(f"Number of validating samples: {len(valset)}")

    # if out_channels > 1 then use cross entropy loss for multiple classes
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # not doing Sigmoid at output layer
    criterion = nn.BCEWithLogitsLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler()  # prevent underflow

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT, model, optimizer)
        check_accuracy(valloader, model, device=DEVICE)

    else:
        for epoch in range(NUM_EPOCHS):
            train(trainloader, model, optimizer, criterion, scaler)

            checkpoint = {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            save_checkpoint(checkpoint, filename=CHECKPOINT)

            check_accuracy(valloader, model, device=DEVICE)

            save_predictions_as_imgs(valloader,
                                     model,
                                     folder='saved_images',
                                     device=DEVICE)
Example #17
import argparse
import json

import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.transform

from model import UNET

argparser = argparse.ArgumentParser(description='U-net predictions')

argparser.add_argument('-c', '--conf', help='path to config file')

argparser.add_argument('-i', '--input', help='path to input image')

argparser.add_argument('-w', '--weight', help='path to input weight')

with open(argparser.parse_args().conf) as raw:
    config = json.load(raw)

img_path = argparser.parse_args().input
wei_path = argparser.parse_args().weight

unet = UNET(config)
unet.load_weights(wei_path)

I = skimage.io.imread(img_path)
I = skimage.transform.resize(I, (388, 388))
Ip = np.pad(I, ((92, 92), (92, 92), (0, 0)), 'constant')  # 388 + 2*92 = 572, the original U-Net input size; skimage.util.pad was removed in favour of numpy.pad
out = unet.model.predict(Ip.reshape(1, *Ip.shape)).argmax(axis=-1).squeeze()

plt.figure(figsize=(10, 10))
plt.imshow(I)
plt.imshow(1 - out, alpha=0.2)
plt.axis('off')
plt.show()
Example #18
import argparse
import json

from model import UNET

argparser = argparse.ArgumentParser(description='end-to-end U-net training')

argparser.add_argument('-c', '--conf', help='path to config file')

with open(argparser.parse_args().conf) as raw:
    config = json.load(raw)

unet = UNET(config)
unet.train()