import threading
import Datastore as data
obj = data.Datastore()
threads = []
# Create the key from a worker thread
for _ in range(1):
    t = threading.Thread(target=obj.create,
                         args=['tirumala', '{name:tirumala,age:21}', 1000])
    t.start()
    threads.append(t)
for thread in threads:
    thread.join()
threads = []
# Read the value back after the create threads have joined
for _ in range(1):
    t = threading.Thread(target=obj.read, args=['tirumala'])
    t.start()
    threads.append(t)
for thread in threads:
    thread.join()

threads = []
# Issue three concurrent deletes for the same key to exercise thread safety
for _ in range(3):
    t = threading.Thread(target=obj.delete, args=['tirumala'])
    t.start()
    threads.append(t)
for thread in threads:
    thread.join()
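
The example above only shows the calling side: a Datastore class (imported from the Datastore module) whose create, read, and delete methods are hit from worker threads, with the third argument to create assumed here to be a time-to-live in seconds. The class itself is not part of the snippet, so the following is a minimal, hypothetical sketch (a dictionary guarded by a lock) that the threading code could run against; it is not the original Datastore implementation.

import threading
import time


class Datastore:
    """Hypothetical thread-safe key-value store matching the calls above."""

    def __init__(self):
        self._lock = threading.Lock()
        self._store = {}  # key -> (value, expiry timestamp or None)

    def create(self, key, value, ttl=None):
        # Reject duplicate keys; ttl (seconds) marks the entry for expiry
        with self._lock:
            if key in self._store:
                raise KeyError(f"{key} already exists")
            expiry = time.time() + ttl if ttl else None
            self._store[key] = (value, expiry)

    def read(self, key):
        # Return the value, treating expired entries as missing
        with self._lock:
            value, expiry = self._store[key]
            if expiry is not None and time.time() > expiry:
                del self._store[key]
                raise KeyError(f"{key} has expired")
            return value

    def delete(self, key):
        # pop with a default so concurrent deletes of the same key don't raise
        with self._lock:
            self._store.pop(key, None)

Because every method takes the same lock, the three concurrent delete calls above serialize cleanly, and pop with a default keeps the later arrivals from raising once the key is gone.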
Example #2
File: main.py    Project: sfhbarnett/UNet
def main(mainpath, load=False, training=True, weights=False, rgb=0):
    """Train a U-Net on the data under mainpath, or run prediction with a saved model."""
    torch.cuda.device(0)  # has no effect unless used as a 'with' context manager
    plt.ion()

    # Choose transforms and network input channels for multi- or single-channel data
    if rgb:
        tforms = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        net = UNet(n_channels=3, n_classes=1)
    else:
        tforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
        net = UNet(n_channels=1, n_classes=1)

    if training:
        trainpath = os.path.join(mainpath, 'image')
        filelist = os.listdir(trainpath)
        trainmasks = os.path.join(mainpath, 'label')
        masklist = os.listdir(trainmasks)

        # Per-pixel weight maps are generated once and cached in mainpath/weights
        weightslist = []
        if weights:
            weightspath = os.path.join(mainpath, 'weights')
            if not os.path.isdir(weightspath):
                os.mkdir(weightspath)
                print("generating weights")
                for file in masklist:
                    img = Image.open(os.path.join(mainpath, 'label', file))
                    weightmap = Datastore.generateWeights(img)
                    weightmap = Image.fromarray(weightmap)
                    weightmap.save(os.path.join(weightspath, file[:-4] + '.tif'))
                print("generated weights")
            weightslist = os.listdir(weightspath)

        dataset = Datastore.Datastore(filelist, masklist, weightslist, mainpath, transforms=tforms)
        batch_N = 1
        trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_N, shuffle=True, num_workers=0)
        N_train = len(dataset)
        gpu = 0  # flag: set to 1 to run training on the GPU
        startepoch = 0

        if gpu == 1:
            gpu = torch.device("cuda:0")
            print("Connected to device: ", gpu)
            net = net.to(gpu)

        epochs = 50
        lr = 0.001
        val_percent = 0.05
        optimizer = optim.SGD(net.parameters(),
                              lr=lr,
                              momentum=0.9)
        criterion = nn.BCEWithLogitsLoss()
        fig = plt.figure(figsize=(18, 5), dpi=80, facecolor='w', edgecolor='k')
        fig.tight_layout()

        # Load in previous model
        if load:
            try:
                checkpoint = torch.load('model2.pt')
                net.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                startepoch = checkpoint['epoch'] + 1
                loss = checkpoint['loss']
            except FileNotFoundError:
                print(f"No model file found at {mainpath}")

        train(net, optimizer, criterion, trainloader, startepoch, epochs, gpu, batch_N, N_train, mainpath)
    else:
        # Inference only: restore the trained weights and run prediction
        checkpoint = torch.load('model2.pt')
        net.load_state_dict(checkpoint['model_state_dict'])
        predict(net, mainpath)
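
Example #2 depends on a project-specific Datastore.Datastore dataset (as well as UNet, train, and predict) defined elsewhere in the repository. As a rough, hypothetical sketch of what such a dataset could look like, assuming matching file names across the image, label, and weights folders rather than the project's actual class:

import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF


class Datastore(Dataset):
    """Hypothetical stand-in pairing images, masks, and optional weight maps."""

    def __init__(self, filelist, masklist, weightslist, mainpath, transforms=None):
        self.filelist = sorted(filelist)
        self.masklist = sorted(masklist)
        self.weightslist = sorted(weightslist) if weightslist else []
        self.mainpath = mainpath
        self.transforms = transforms

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.mainpath, 'image', self.filelist[idx]))
        mask = Image.open(os.path.join(self.mainpath, 'label', self.masklist[idx]))
        sample = {
            'image': self.transforms(image) if self.transforms else TF.to_tensor(image),
            'mask': TF.to_tensor(mask),  # masks are kept unnormalised
        }
        if self.weightslist:
            weight = Image.open(os.path.join(self.mainpath, 'weights', self.weightslist[idx]))
            sample['weight'] = TF.to_tensor(weight)
        return sample

Sorting the three file lists is one simple way to keep images, masks, and weight maps aligned by name; the real dataset may pair them differently.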