Example 1
import torch
import numpy as np
from configure import Config
from model import WNet
from DataLoader import DataLoader
from Ncuts import NCutsLoss
import time
import os

config = Config()
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_dev_list
if __name__ == '__main__':
    dataset = DataLoader(config.pascal, "train")
    dataloader = dataset.torch_loader()
    #model = torch.nn.DataParallel(Net(True))
    model = torch.nn.DataParallel(WNet())
    model.cuda()
    #model.to(device)
    model.train()
    #optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=config.init_lr)
    #reconstr = torch.nn.MSELoss().cuda(config.cuda_dev)
    Ncuts = NCutsLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config.lr_decay_iter,
                                                gamma=config.lr_decay)
    for epoch in range(config.max_iter):
        print("Epoch: " + str(epoch + 1))
        scheduler.step()
        Ave_Ncuts = 0.0
        #Ave_Rec = 0.0
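
The snippet breaks off as the epoch loop begins. One detail worth flagging: since PyTorch 1.1.0, scheduler.step() is expected after the optimizer updates, so on current versions the loop would be reordered roughly as follows (a sketch keeping the example's names; the per-batch body is elided):

    for epoch in range(config.max_iter):
        print("Epoch: " + str(epoch + 1))
        Ave_Ncuts = 0.0
        # ... per-batch forward/backward passes and optimizer.step() calls ...
        scheduler.step()  # after the epoch's updates (PyTorch >= 1.1.0)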
Example 2
import torch
import numpy as np
from configure import Config
from model import WNet
from Ncuts import NCutsLoss
from DataLoader import DataLoader
import time
import os
import torchvision
import pdb
from PIL import Image

config = Config()
if __name__ == '__main__':
    dataset = DataLoader(config.datapath, "test")
    dataloader = dataset.torch_loader()
    model = WNet()
    model.cuda(config.cuda_dev)
    optimizer = torch.optim.SGD(model.parameters(), lr=config.init_lr)
    #optimizer
    with open(config.model_tested, 'rb') as f:
        para = torch.load(f, "cuda:0")
        # pdb.set_trace()  # leftover debugging breakpoint
        model.load_state_dict(para['state_dict'], strict=False)
    for step, [x] in enumerate(dataloader):
        print('Step' + str(step + 1))
        #NCuts Loss

        x = x.cuda(config.cuda_dev)
        pred, rec_image = model(x)
        seg = (pred.argmax(dim=1).to(torch.float) / 3 *
               255).cpu().detach().numpy()
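
The example stops right after converting the argmax segmentation to a NumPy array. A hedged continuation that writes each map to disk with the already-imported PIL; the filenames are illustrative, not from the source:

        for i in range(seg.shape[0]):
            # Each seg[i] is an (H, W) array already scaled to 0..255
            Image.fromarray(seg[i].astype("uint8"), mode="L").save(
                "seg_%d_%d.png" % (step + 1, i))  # hypothetical output name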
Example 3
import torch
import torch.nn.functional as F
from torchvision import transforms
from datetime import datetime
import matplotlib.pyplot as plt
# Local modules; the exact module paths for these imports are assumed
# from the surrounding repository:
from configure import Config
from model import WNet
from dataset import AutoencoderDataset
from loss import soft_n_cut_loss
import util


def main():
    print("PyTorch Version: ", torch.__version__)
    if torch.cuda.is_available():
        print("Cuda is available. Using GPU")

    config = Config()

    ###################################
    # Image loading and preprocessing #
    ###################################

    # TODO: Maybe we should crop a large square, then resize that down to our patch size?
    # For now, data augmentation must not introduce any missing pixels.
    # TODO: Add data augmentation noise.
    train_xform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.Resize(128),
        # For now, cropping down to 224
        transforms.RandomCrop(config.input_size + config.variationalTranslation),
        transforms.RandomHorizontalFlip(),  # TODO: Add colorjitter, random erasing
        transforms.ToTensor()
    ])
    val_xform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.Resize(128),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor()
    ])

    #TODO: Load validation segmentation maps too  (for evaluation purposes)
    train_dataset = AutoencoderDataset("train", train_xform)
    val_dataset = AutoencoderDataset("val", val_xform)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=4,
        shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=4,
                                                 num_workers=4,
                                                 shuffle=False)

    util.clear_progress_dir()

    ###################################
    #          Model Setup            #
    ###################################

    autoencoder = WNet()
    if torch.cuda.is_available():
        autoencoder = autoencoder.cuda()
    optimizer = torch.optim.Adam(autoencoder.parameters())
    if config.debug:
        print(autoencoder)
    util.enumerate_params([autoencoder])

    # Use the current time to save the model at end of each epoch
    modelName = str(datetime.now())

    ###################################
    #          Loss Criterion         #
    ###################################

    def reconstruction_loss(x, x_prime):
        binary_cross_entropy = F.binary_cross_entropy(x_prime,
                                                      x,
                                                      reduction='sum')
        return binary_cross_entropy

    ###################################
    #          Training Loop          #
    ###################################

    autoencoder.train()

    progress_images, progress_expected = next(iter(val_dataloader))

    for epoch in range(config.num_epochs):
        running_loss = 0.0
        for i, [inputs, outputs] in enumerate(train_dataloader, 0):

            if config.showdata:
                print(inputs.shape)
                print(outputs.shape)
                print(inputs[0])
                plt.imshow(inputs[0].permute(1, 2, 0))
                plt.show()

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                outputs = outputs.cuda()

            optimizer.zero_grad()

            segmentations, reconstructions = autoencoder(inputs)

            l_soft_n_cut = soft_n_cut_loss(inputs, segmentations)
            l_reconstruction = reconstruction_loss(
                inputs if config.variationalTranslation == 0 else outputs,
                reconstructions)

            loss = (l_reconstruction + l_soft_n_cut)
            # retain_graph=True would only be needed when backpropagating
            # from multiple heads
            loss.backward()
            optimizer.step()

            if config.debug and (i % 50) == 0:
                print(i)

            # print statistics
            running_loss += loss.item()

            if config.showSegmentationProgress and i == 0:  # If first batch in epoch
                util.save_progress_image(autoencoder, progress_images, epoch)
                optimizer.zero_grad()  # Don't change gradient on validation

        epoch_loss = running_loss / len(train_dataloader.dataset)
        print(f"Epoch {epoch} loss: {epoch_loss:.6f}")

        if config.saveModel:
            util.save_model(autoencoder, modelName)
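
As in the other examples, the script would invoke this function behind an entry-point guard:

if __name__ == '__main__':
    main()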
Example 4
import torch
import numpy as np
from configure import Config
from model import WNet
from Ncuts import NCutsLoss
from DataLoader import DataLoader
import time
import os
import pdb

config = Config()
cuda_device = torch.cuda.device(config.cuda_dev)  # note: a context manager; has no effect unless used in a `with` block
if __name__ == '__main__':
    dataset = DataLoader(config.datapath, "train")
    dataloader = dataset.torch_loader()
    model = torch.nn.DataParallel(WNet(), config.cuda_dev_list)
    model.cuda(config.cuda_dev)
    # optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=config.init_lr)
    reconstr = torch.nn.MSELoss().cuda(config.cuda_dev)
    Ncuts = torch.nn.DataParallel(NCutsLoss(),
                                  config.cuda_dev_list).cuda(config.cuda_dev)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config.lr_decay_iter,
                                                gamma=config.lr_decay)
    for epoch in range(config.max_iter):
        print("Epoch: " + str(epoch + 1))
        scheduler.step()
        Ave_Ncuts = 0.0
        Ave_Rec = 0.0
        for step, [x, w] in enumerate(dataloader):
            # print('Step' + str(step + 1))
            # NCuts Loss
            timer = time.time()
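
The snippet breaks off right after starting the step timer. A minimal sketch of how one training step could continue under the same names; the Ncuts(pred, w) call signature is an assumption not shown by the fragment, while the two-headed model output matches Example 2:

            x = x.cuda(config.cuda_dev)
            w = w.cuda(config.cuda_dev)
            optimizer.zero_grad()
            pred, rec_image = model(x)          # segmentation + reconstruction
            ncuts_loss = Ncuts(pred, w).mean()  # assumed call signature
            rec_loss = reconstr(rec_image, x)
            (ncuts_loss + rec_loss).backward()
            optimizer.step()
            # running averages over the epoch
            Ave_Ncuts = (Ave_Ncuts * step + ncuts_loss.item()) / (step + 1)
            Ave_Rec = (Ave_Rec * step + rec_loss.item()) / (step + 1)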
Example 5
# Fragment: assumes torch, the AbdomenDS dataset class, WNet, config, and
# torch.utils.data.DataLoader are imported earlier in the file.
if __name__ == '__main__':
    is_cuda = torch.cuda.is_available()
    #ds_test = AbdomenDS("/Users/dhanunjayamitta/Desktop/single_val","test", config.datasetMode, config.interpFactor)
    #ds_val = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRI_Val","train",(0.5,0.5,0.5))
    # Bound to ds_test (the original bound it to ds_train), since the
    # DataLoader below expects ds_test.
    ds_test = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRITTemp",
                        "test", config.interpFactor)
    # ds_val = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRITTemp","train",(1,0.5,0.5))

    checkname = None  #'/raid/scratch/schatter/Dataset/dhanun/checkpoints/checkpoint_9_14_13_52_epoch_15' #Set None if no need to load

    result_upsample = True
    #upunterp_fact = None
    upinterp_fact = (1, 1, 1)
    #model_downscale = False
    dataloader = DataLoader(ds_test, batch_size=config.BatchSize, shuffle=True)
    model = WNet(is_cuda)
    #model = WNet()
    if is_cuda:
        device = torch.device("cuda:1")
    else:
        device = torch.device("cpu")
    model.to(device)
    #model.cuda()
    model.eval()
    #model_downscale = False
    mode = 'test'
    optimizer = torch.optim.Adam(model.parameters(), lr=config.init_lr)
    #optimizer
    with open(config.model_tested, 'rb') as f:
        para = torch.load(f, "cpu")
        #para = torch.load(f,"cuda:0")
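
The fragment is cut off inside the checkpoint block. A minimal sketch of how it could continue, assuming the checkpoint stores its weights under 'state_dict' as in the other examples; the batch structure yielded by AbdomenDS is likewise an assumption:

        model.load_state_dict(para['state_dict'])  # assumed checkpoint key
    with torch.no_grad():  # inference only
        for step, [x] in enumerate(dataloader):  # assumed batch structure
            x = x.to(device)
            pred = model(x)  # output layout depends on this WNet variant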
Example 6
    # Fragment: assumes torch, AbdomenDS, WNet, config, and DataLoader are
    # imported earlier in the file. The ds_train line is uncommented here
    # because the DataLoader below needs it.
    ds_train = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRITTemp",
                         "train", config.interpFactor)
    #ds_val = AbdomenDS("/Users/dhanunjayamitta/Desktop/single_train","train", config.datasetMode, config.interpFactor)
    ds_val = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRI_Val",
                       "train", config.datasetMode, config.interpFactor)
    # ds_val = AbdomenDS("/raid/scratch/schatter/Dataset/dhanun/MRI/MRITTemp","train",config.datasetMode, config.interpFactor)

    checkname = None  #"/Users/dhanunjayamitta/Desktop/Archive 8 fresh/checkpoints/checkpoint_9_16_17_8_epoch_990" #Set None if no need to load

    dataloader = DataLoader(ds_train,
                            batch_size=config.BatchSize,
                            shuffle=True)
    dataloader1 = DataLoader(ds_val, batch_size=config.BatchSize, shuffle=True)
    #eval_set = DataLoader("MRI/new_test","train")

    #eval_loader = eval_set.torch_loader()
    model = WNet(is_cuda)
    #model = torch.nn.DataParallel(WNet())
    if is_cuda:
        device1 = torch.device("cuda:1")
        device2 = torch.device("cuda:0")
    else:
        device1 = torch.device("cpu")
        device2 = torch.device("cpu")
    model.to(device1)
    #model_eval = torch.nn.DataParallel(WNet())
    #model.cuda()
    #model_eval.cuda()
    #model_eval.eval()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.init_lr)
    #reconstr = torch.nn.MSELoss().cuda(config.cuda_dev)
    if config.useSSIMLoss:
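        # The fragment is cut off at this branch. A minimal sketch of what it
        # implies; the SSIM criterion is an assumption (e.g. the pytorch_msssim
        # package), not something the fragment names.
        from pytorch_msssim import SSIM
        ssim = SSIM(data_range=1.0, channel=1)    # hypothetical channel count
        reconstr = lambda a, b: 1.0 - ssim(a, b)  # SSIM is a similarity, so invert it
    else:
        reconstr = torch.nn.MSELoss()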
Example 7
import torch
from configure import Config
from model import WNet
from Ncuts import NCutsLoss
from DataLoader import DataLoader
import time
import os
import torchvision
import pdb
from PIL import Image

config = Config()
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_dev_list
if __name__ == '__main__':
    dataset = DataLoader(config.bsds, "test")
    dataloader = dataset.torch_loader()
    model = WNet()
    model.cuda()
    model.eval()
    optimizer = torch.optim.SGD(model.parameters(), lr=config.init_lr)
    #optimizer
    with open(config.model_tested, 'rb') as f:
        para = torch.load(f, "cuda:0")
        model.load_state_dict(para['state_dict'])
    for step, [x] in enumerate(dataloader):
        print('Step' + str(step + 1))
        #NCuts Loss

        x = x.cuda()
        pred, pad_pred = model(x)
        seg = (pred.argmax(dim=1)).cpu().detach().numpy()
        x = x.cpu().detach().numpy() * 255
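
The example stops after pulling both tensors back to NumPy. A hedged continuation that writes them to disk with the already-imported PIL, assuming 3-channel inputs; the filenames are illustrative:

        for i in range(x.shape[0]):
            rgb = x[i].transpose(1, 2, 0).astype("uint8")  # CHW -> HWC
            Image.fromarray(rgb).save("input_%d_%d.png" % (step + 1, i))
            labels = seg[i].astype("uint8")  # raw label ids; rescale for a visible mask
            Image.fromarray(labels, mode="L").save("seg_%d_%d.png" % (step + 1, i))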