Example #1
File: tests.py Project: yhu9/WNet
def WNetTest():
    # Round-trip test: feed the encoder's output through the decoder.
    encoded = EncoderTest(verbose=False)
    decoder = WNet.UDec(4)
    reproduced = decoder(encoded)
    var = torch.var(reproduced)
    mean = torch.mean(reproduced)
    print('Passed WNet Test with var=%s and mean=%s' % (var, mean))
Example #2
File: tests.py Project: yhu9/WNet
def DecoderTest():
    shape = (2, 4, 224, 224)
    out_shape = (2, 3, 224, 224)
    decoder = WNet.UDec(shape[1])
    data = torch.rand(tuple(shape))
    decoded = decoder(data)
    assert tuple(decoded.shape) == out_shape
    var = torch.var(decoded)
    mean = torch.mean(decoded)
    print('Passed Decoder Test with var=%s and mean=%s' % (var, mean))
Example #3
File: tests.py Project: yhu9/WNet
def EncoderTest(verbose=True):
    shape = (2, 4, 224, 224)
    encoder = WNet.UEnc(shape[1])
    data = torch.rand((shape[0], 3, shape[2], shape[3]))
    encoded = encoder(data)
    assert tuple(encoded.shape) == shape
    var = torch.var(encoded)
    mean = torch.mean(encoded)
    if verbose:
        print('Passed Encoder Test with var=%s and mean=%s' % (var, mean))
    return encoded
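Taken together, the three tests above chain naturally: EncoderTest checks that WNet.UEnc maps an (N, 3, H, W) batch to an (N, K, H, W) encoding, DecoderTest checks the reverse mapping of WNet.UDec, and WNetTest round-trips a real encoding through the decoder. A minimal runner for tests.py could look like the sketch below; the import lines and the __main__ guard are assumptions, not part of the excerpts above.

import torch
import WNet

if __name__ == '__main__':
    EncoderTest()
    DecoderTest()
    WNetTest()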
Example #4
def main():
    args = parser.parse_args()
    model = WNet.WNet(args.squeeze)

    model.load_state_dict(torch.load(args.model))
    model.eval()

    transform = transforms.Compose(
        [transforms.Resize((64, 64)),
         transforms.ToTensor()])

    image = Image.open(args.image).convert('RGB')
    x = transform(image)[None, :, :, :]

    enc, dec = model(x)
    show_image(x[0])
    show_image(enc[0, :1, :, :].detach())
    show_image(dec[0, :, :, :].detach())
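Examples #4 through #6 use a show_image helper whose definition is not included in these excerpts. A minimal sketch, assuming it simply moves a CHW tensor's channels last and displays the result with matplotlib (this body is an assumption, not the project's actual helper):

import matplotlib.pyplot as plt

def show_image(tensor):
    # Assumed helper: CHW tensor -> HWC numpy array, then display.
    img = tensor.permute(1, 2, 0).numpy()
    plt.imshow(img.squeeze())
    plt.show()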
Example #5
def main():
    args = parser.parse_args()
    model = WNet.WNet(args.squeeze)

    model.load_state_dict(
        torch.load(args.model, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])

    image = Image.open(args.image).convert('RGB')
    x = transform(image)[None, :, :, :]

    enc, dec = model(x)
    show_image(x[0])
    plt.imshow(torch.argmax(enc, dim=1).detach()[0])
    plt.show()
    show_image(dec[0, :, :, :].detach())
Example #6
def main():
    args = parser.parse_args()
    model = WNet.WNet(args.squeeze)

    model.load_state_dict(
        torch.load(args.model, map_location=torch.device('cpu')))
    model.eval()

    transform = transforms.Compose(
        [transforms.Resize((64, 64)),
         transforms.ToTensor()])

    img = Image.open("data2/images/train/1head.png").convert('RGB')
    x = transform(img)[None, :, :, :]

    enc, dec = model(x)
    show_image(x[0])
    # TODO: torch sum/ stack?
    show_image(enc[0, :3, :, :].detach())
    # show_image(torch.argmax(enc[:,:,:,:], dim=1))
    # show_image(dec[0, :, :, :].detach())
    # now put enc in crf
    segment = enc[0, :, :, :].detach()
    # put in tensor here?

    orimg = imread("data2/images/train/1head.png")
    img = resize(orimg, (64, 64))
    Q = dense_crf(img, segment.numpy())

    print(type(Q))
    Q = np.argmax(Q, axis=0)
    print(len(Q))

    print(np.unique(Q))
    plt.imshow(Q)
    plt.show()
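Example #6 depends on a dense_crf helper defined elsewhere. A common implementation uses the pydensecrf package, roughly as sketched below; the pairwise parameters (sxy, srgb, compat) and the five inference steps are illustrative assumptions, and the segmentation passed in is expected to be a per-class softmax of shape (K, H, W).

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def dense_crf(img, probs):
    # img: (H, W, 3) RGB image with values in [0, 1], probs: (K, H, W) class probabilities.
    k, h, w = probs.shape
    d = dcrf.DenseCRF2D(w, h, k)
    d.setUnaryEnergy(unary_from_softmax(probs))
    # Spatial smoothness plus colour-dependent smoothness terms.
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=50, srgb=13,
                           rgbim=np.ascontiguousarray((img * 255).astype(np.uint8)),
                           compat=10)
    q = d.inference(5)
    return np.array(q).reshape((k, h, w))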
Example #7
def main():
    args = parser.parse_args()

    dir_path = os.path.dirname(os.path.realpath(__file__))
    print(f"dir_path: {dir_path}")

    image_size = (128, 128)
    transforms = T.Compose([T.Resize(image_size), T.ToTensor()])

    # Download the BSDS500 dataset
    torch_enhance.datasets.BSDS500()

    dataset = datasets.ImageFolder(".data", transform=transforms)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=len(dataset),
                                             shuffle=True,
                                             pin_memory=True)
    data_cuda = [x[0].cuda() for x in iter(dataloader)][0]

    start_epoch = 1
    batch_size = 10
    if args.load:
        wnet = torch.load(args.load, map_location='cpu')
        # result = re.match('.+-\d\.pt', args.load)
    else:
        wnet = WNet.WNet(4)

    wnet = wnet.cuda()

    if args.train:

        if not args.save_model:
            print("Provide a name for the model")
            return

        save_model_dir = f"models/{args.save_model}"
        os.makedirs(save_model_dir, exist_ok=True)

        learning_rate = 0.0003
        optimizer = torch.optim.SGD(wnet.parameters(), learning_rate)

        start_time = datetime.now()
        loss = 0
        for epoch in range(1, 50000):
            # batch, labels = next(iter(dataloader))

            perm = torch.randperm(data_cuda.size(0))
            idx = perm[:batch_size]
            batch = data_cuda[idx]

            # batch = torch.stack(random.sample(data_cuda, batch_size))
            # batch = batch.cuda()
            wnet, loss, enc, dec = train_op(wnet, optimizer, batch)

            if epoch % 1000 == 0:
                learning_rate /= 10
                print(f"Reducing learning rate to {learning_rate}")
                optimizer = torch.optim.SGD(wnet.parameters(), learning_rate)

                model_name = f"{args.save_model}-{epoch}.pt"
                model_path = f"{save_model_dir}/{model_name}"
                print(f"Saving current model as '{model_path}'")
                torch.save(wnet, model_path)

            if epoch % 100 == 0:
                print("==============================")
                print("Epoch = " + str(epoch))
                duration = (datetime.now() - start_time).seconds
                print(f"Loss: {loss}")
                print(f"Duration: {duration}s")

                start_time = datetime.now()

                # show_grid(np.concatenate([enc[:10].cpu().detach().numpy() , dec[:10].cpu().detach().numpy()]))
                show_grid(enc[:10].cpu().detach().numpy())
                show_grid(dec[:10].cpu().detach().numpy())

            # duration = (datetime.now() - start_time).seconds
            # print(f"Duration: {duration}s")

    elif args.predict:

        encodings = []
        for i, (batch, labels) in enumerate(dataloader):
            if i == 20:
                break

            print(f"\rSegmenting image {i}/{len(dataloader)}", end='')
            # The model was moved to CUDA above, so move the batch as well
            # and bring the encoding back to the CPU before converting to numpy.
            batch = batch.cuda()
            enc = wnet.forward(batch, returns="enc")
            encodings.append(enc[0].detach().cpu().numpy())

        encodings = torch.tensor(encodings)
        plt.imshow(
            torchvision.utils.make_grid(encodings, nrow=10).permute(1, 2, 0))
        plt.show()
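The parser that Example #7 reads its flags from (args.load, args.train, args.predict, args.save_model) is defined outside this excerpt. A minimal sketch consistent with those flags; the help strings and defaults are assumptions.

import argparse

parser = argparse.ArgumentParser(description='Train or run a WNet segmentation model')
parser.add_argument('--load', type=str, default=None,
                    help='path of a saved model to resume from')
parser.add_argument('--save_model', type=str, default=None,
                    help='name under which checkpoints are saved')
parser.add_argument('--train', action='store_true',
                    help='run the training loop')
parser.add_argument('--predict', action='store_true',
                    help='segment images with the loaded model')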
Example #8
def test():
    wnet = WNet.WNet(4)
    wnet = wnet.cuda()
    synthetic_data = torch.rand((1, 3, 128, 128)).cuda()
    optimizer = torch.optim.SGD(wnet.parameters(), 0.001)
    train_op(wnet, optimizer, synthetic_data)
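Examples #7 and #8 call train_op(wnet, optimizer, batch), which is defined elsewhere in the project. Below is a simplified sketch of such a step using only a reconstruction (MSE) loss; the actual W-Net training step also optimises a soft N-cut loss on the encoder output, which is omitted here, and the return signature is inferred from how the callers unpack it.

import torch.nn.functional as F

def train_op(model, optimizer, batch):
    # Assumed simplified step: forward pass, reconstruction loss, backward, update.
    enc, dec = model(batch)
    loss = F.mse_loss(dec, batch)  # reconstruction term only; soft N-cut omitted
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return model, loss.item(), enc, dec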
Example #9
File: train.py Project: heblol/WNet
def main():
    # Load the arguments
    args, unknown = parser.parse_known_args()

    # Check if CUDA is available
    CUDA = torch.cuda.is_available()

    # Create empty lists for average N_cut losses and reconstruction losses
    n_cut_losses_avg = []
    rec_losses_avg = []

    # k: number of segmentation classes (args.squeeze)
    k = args.squeeze
    img_size = (224, 224)
    wnet = WNet.WNet(k)
    if CUDA:
        wnet = wnet.cuda()
    learning_rate = 0.003
    optimizer = torch.optim.SGD(wnet.parameters(), lr=learning_rate)

    transform = transforms.Compose(
        [transforms.Resize(img_size),
         transforms.ToTensor()])

    dataset = datasets.ImageFolder(args.input_folder, transform=transform)

    # To train on a single image, set batch_size=1 and shuffle=False
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=10,
                                             shuffle=True)

    # Run for every epoch
    for epoch in range(args.epochs):

        # Every 1000 epochs, divide the SGD learning rate by 10
        if epoch > 0 and epoch % 1000 == 0:
            learning_rate = learning_rate / 10
            optimizer = torch.optim.SGD(wnet.parameters(), lr=learning_rate)

        # Print the current epoch
        print("Epoch = " + str(epoch))

        # Create empty lists for N_cut losses and reconstruction losses
        n_cut_losses = []
        rec_losses = []
        start_time = time.time()

        for (idx, batch) in enumerate(dataloader):
            # To train on a single image, stop after the first batch:
            # if idx > 0: break

            # Train Wnet with CUDA if available
            if CUDA:
                batch[0] = batch[0].cuda()

            wnet, n_cut_loss, rec_loss = train_op(wnet, optimizer, batch[0], k,
                                                  img_size)

            n_cut_losses.append(n_cut_loss.detach())
            rec_losses.append(rec_loss.detach())

        n_cut_losses_avg.append(torch.mean(torch.FloatTensor(n_cut_losses)))
        rec_losses_avg.append(torch.mean(torch.FloatTensor(rec_losses)))
        print("--- %s seconds ---" % (time.time() - start_time))

    images, labels = next(iter(dataloader))

    # Run wnet with cuda if enabled
    if CUDA:
        images = images.cuda()

    enc, dec = wnet(images)

    torch.save(wnet.state_dict(), "model_" + args.name)
    np.save("n_cut_losses_" + args.name, n_cut_losses_avg)
    np.save("rec_losses_" + args.name, rec_losses_avg)
    print("Done")