def WNetTest():
    """Round-trip smoke test: encode random data via EncoderTest, then decode it.

    Asserts nothing beyond "runs without error"; prints the variance and mean
    of the reconstruction for manual inspection.
    """
    encoded = EncoderTest(verbose=False)
    # 4 must match the channel count EncoderTest's encoder produces.
    decoder = WNet.UDec(4)
    reproduced = decoder(encoded)
    var = torch.var(reproduced)
    mean = torch.mean(reproduced)
    # Fix: the message previously read "Passed Decoder Test" — a copy-paste
    # from DecoderTest. This function tests the full encoder->decoder pipeline.
    print('Passed WNet Test with var=%s and mean=%s' % (var, mean))
def DecoderTest():
    """Sanity-check WNet.UDec: random 4-channel input must decode to 3-channel RGB."""
    in_shape = (2, 4, 224, 224)
    expected_shape = (2, 3, 224, 224)
    net = WNet.UDec(in_shape[1])
    decoded = net(torch.rand(in_shape))
    assert tuple(decoded.shape) == expected_shape
    # Print simple statistics so a human can eyeball the output distribution.
    print('Passed Decoder Test with var=%s and mean=%s'
          % (torch.var(decoded), torch.mean(decoded)))
def EncoderTest(verbose=True):
    """Sanity-check WNet.UEnc on random RGB input and return the encoding.

    The encoder maps a (batch, 3, H, W) image batch to (batch, k, H, W);
    the shape assertion pins that contract for k=4.
    """
    batch, k, height, width = 2, 4, 224, 224
    encoder = WNet.UEnc(k)
    encoded = encoder(torch.rand((batch, 3, height, width)))
    assert tuple(encoded.shape) == (batch, k, height, width)
    if verbose:
        print('Passed Encoder Test with var=%s and mean=%s'
              % (torch.var(encoded), torch.mean(encoded)))
    return encoded
def main():
    """Load a trained WNet checkpoint and display input, one encoding channel,
    and the reconstruction for the image given on the command line."""
    args = parser.parse_args()
    model = WNet.WNet(args.squeeze)
    # Fix: pass map_location so a checkpoint saved on a GPU machine still
    # loads on CPU-only hosts (matches the other inference entry points,
    # which already do this).
    model.load_state_dict(
        torch.load(args.model, map_location=torch.device('cpu')))
    model.eval()
    transform = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()])
    image = Image.open(args.image).convert('RGB')
    # [None, ...] adds the batch dimension the model expects.
    x = transform(image)[None, :, :, :]
    enc, dec = model(x)
    show_image(x[0])
    show_image(enc[0, :1, :, :].detach())  # first encoding channel only
    show_image(dec[0, :, :, :].detach())   # reconstruction
def main():
    """CPU inference: show the input image, the argmax segmentation map of the
    encoding, and the decoder's reconstruction."""
    args = parser.parse_args()
    net = WNet.WNet(args.squeeze)
    state = torch.load(args.model, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    to_tensor = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()])
    rgb = Image.open(args.image).convert('RGB')
    x = to_tensor(rgb)[None, :, :, :]  # add batch dimension
    enc, dec = net(x)
    show_image(x[0])
    # Collapse the k encoding channels into a single label map per pixel.
    plt.imshow(torch.argmax(enc, dim=1).detach()[0])
    plt.show()
    show_image(dec[0, :, :, :].detach())
def main():
    """Run WNet on one hard-coded training image, then refine its soft
    encoding with a fully-connected CRF and display the label map."""
    args = parser.parse_args()
    net = WNet.WNet(args.squeeze)
    net.load_state_dict(
        torch.load(args.model, map_location=torch.device('cpu')))
    net.eval()
    to_tensor = transforms.Compose(
        [transforms.Resize((64, 64)), transforms.ToTensor()])
    img = Image.open("data2/images/train/1head.png").convert('RGB')
    x = to_tensor(img)[None, :, :, :]
    enc, dec = net(x)
    show_image(x[0])
    # TODO: torch sum/ stack?
    show_image(enc[0, :3, :, :].detach())
    # CRF refinement: pair the soft encoding with the (resized) original image.
    segment = enc[0, :, :, :].detach()
    orimg = imread("data2/images/train/1head.png")
    img = resize(orimg, (64, 64))
    Q = dense_crf(img, segment.numpy())
    print(type(Q))
    Q = np.argmax(Q, axis=0)  # per-pixel hard labels
    print(len(Q))
    print(np.unique(Q))
    plt.imshow(Q)
    plt.show()
def main():
    """Train or run WNet on the BSDS500 image folder.

    --train:   SGD training with random mini-batches drawn from the whole
               dataset preloaded onto the GPU; periodic LR decay, checkpointing
               and progress printing.
    --predict: run the (CPU-loaded) model over the dataloader and show a grid
               of encodings.
    """
    args = parser.parse_args()
    dir_path = os.path.dirname(os.path.realpath(__file__))
    print(f"dir_path: {dir_path}")
    image_size = (128, 128)
    # NOTE: local name `transforms` shadows any module-level `transforms`
    # import within this function.
    transforms = T.Compose([T.Resize(image_size), T.ToTensor()])
    # Download BSD500 dataset
    torch_enhance.datasets.BSDS500()
    dataset = datasets.ImageFolder(".data", transform=transforms)
    # batch_size=len(dataset): one batch containing the entire dataset.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=len(dataset),
                                             shuffle=True,
                                             pin_memory=True)
    # Preload the whole dataset onto the GPU once; training samples by index.
    # NOTE(review): assumes the full dataset fits in GPU memory — confirm.
    data_cuda = [x[0].cuda() for x in iter(dataloader)][0]
    start_epoch = 1
    batch_size = 10
    if args.load:
        # NOTE(review): loads to CPU but is moved to CUDA just below.
        wnet = torch.load(args.load, map_location='cpu')
        # result = re.match('.+-\d\.pt', args.load)
    else:
        wnet = WNet.WNet(4)
    wnet = wnet.cuda()
    if args.train:
        if not args.save_model:
            print("Provide a name for the model")
            return
        save_model_dir = f"models/{args.save_model}"
        os.makedirs(save_model_dir, exist_ok=True)
        learning_rate = 0.0003
        optimizer = torch.optim.SGD(wnet.parameters(), learning_rate)
        start_time = datetime.now()
        loss = 0
        for epoch in range(1, 50000):
            # batch, labels = next(iter(dataloader))
            # Draw a random mini-batch by index from the preloaded GPU tensor.
            perm = torch.randperm(data_cuda.size(0))
            idx = perm[:batch_size]
            batch = data_cuda[idx]
            # batch = torch.stack(random.sample(data_cuda, batch_size))
            # batch = batch.cuda()
            wnet, loss, enc, dec = train_op(wnet, optimizer, batch)
            if epoch % 1000 == 0:
                # Step-decay the learning rate and checkpoint the full model.
                # Rebuilding the optimizer resets any SGD momentum state.
                learning_rate /= 10
                print(f"Reducing learning rate to {learning_rate}")
                optimizer = torch.optim.SGD(wnet.parameters(), learning_rate)
                model_name = f"{args.save_model}-{epoch}.pt"
                model_path = f"{save_model_dir}/{model_name}"
                print(f"Saving current model as '{model_path}'")
                torch.save(wnet, model_path)
            if epoch % 100 == 0:
                # Progress report plus a visual grid of the latest batch.
                print("==============================")
                print("Epoch = " + str(epoch))
                duration = (datetime.now() - start_time).seconds
                print(f"Loss: {loss}")
                print(f"Duration: {duration}s")
                start_time = datetime.now()
                # show_grid(np.concatenate([enc[:10].cpu().detach().numpy() , dec[:10].cpu().detach().numpy()]))
                show_grid(enc[:10].cpu().detach().numpy())
                show_grid(dec[:10].cpu().detach().numpy())
                # duration = (datetime.now() - start_time).seconds
                # print(f"Duration: {duration}s")
    elif args.predict:
        encodings = []
        # Segment at most 20 batches from the dataloader.
        for i, (batch, labels) in enumerate(dataloader):
            if i == 20:
                break
            print(f"\rSegmenting image {i}/{len(dataloader)}", end='')
            # batch = batch.cuda()
            # NOTE(review): batch stays on CPU while wnet was moved to CUDA
            # above — this forward likely fails with a device mismatch; confirm.
            enc = wnet.forward(batch, returns="enc")
            encodings.append(enc[0].detach().numpy())
        # NOTE(review): torch.tensor on a list of ndarrays is slow/deprecated;
        # torch.from_numpy(np.stack(...)) would be the usual form.
        encodings = torch.tensor(encodings)
        plt.imshow(
            torchvision.utils.make_grid(encodings, nrow=10).permute(1, 2, 0))
        plt.show()
def test():
    """Smoke test: one training step of WNet(4) on random GPU data."""
    net = WNet.WNet(4)
    net = net.cuda()
    fake_batch = torch.rand((1, 3, 128, 128)).cuda()
    sgd = torch.optim.SGD(net.parameters(), 0.001)
    train_op(net, sgd, fake_batch)
def main():
    """Train WNet on an image folder for --epochs epochs, tracking per-epoch
    average N-cut and reconstruction losses, then save the model state dict
    and the loss curves to disk."""
    # Load the arguments
    args, unknown = parser.parse_known_args()
    # Check if CUDA is available
    CUDA = torch.cuda.is_available()
    # Create empty lists for average N_cut losses and reconstruction losses
    n_cut_losses_avg = []
    rec_losses_avg = []
    # Squeeze k (number of segmentation classes / encoder output channels)
    k = args.squeeze
    img_size = (224, 224)
    wnet = WNet.WNet(k)
    if (CUDA):
        wnet = wnet.cuda()
    learning_rate = 0.003
    optimizer = torch.optim.SGD(wnet.parameters(), lr=learning_rate)
    transform = transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor()])
    dataset = datasets.ImageFolder(args.input_folder, transform=transform)
    # Train 1 image set batch size=1 and set shuffle to False
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10,
                                             shuffle=True)
    # Run for every epoch
    for epoch in range(args.epochs):
        # At 1000 epochs divide SGD learning rate by 10
        # (rebuilding the optimizer also discards any momentum state)
        if (epoch > 0 and epoch % 1000 == 0):
            learning_rate = learning_rate / 10
            optimizer = torch.optim.SGD(wnet.parameters(), lr=learning_rate)
        # Print out every epoch:
        print("Epoch = " + str(epoch))
        # Create empty lists for N_cut losses and reconstruction losses
        n_cut_losses = []
        rec_losses = []
        start_time = time.time()
        for (idx, batch) in enumerate(dataloader):
            # Train 1 image idx > 1
            # if(idx > 1): break
            # Train Wnet with CUDA if available
            if CUDA:
                batch[0] = batch[0].cuda()
            wnet, n_cut_loss, rec_loss = train_op(wnet, optimizer, batch[0],
                                                  k, img_size)
            # .detach() keeps the stored losses from retaining the graph.
            n_cut_losses.append(n_cut_loss.detach())
            rec_losses.append(rec_loss.detach())
        # Record the epoch's mean losses for the saved curves below.
        n_cut_losses_avg.append(torch.mean(torch.FloatTensor(n_cut_losses)))
        rec_losses_avg.append(torch.mean(torch.FloatTensor(rec_losses)))
        print("--- %s seconds ---" % (time.time() - start_time))
    images, labels = next(iter(dataloader))
    # Run wnet with cuda if enabled
    if CUDA:
        images = images.cuda()
    # NOTE(review): enc/dec are computed but never used afterwards —
    # presumably a leftover sanity forward pass; confirm before removing.
    enc, dec = wnet(images)
    torch.save(wnet.state_dict(), "model_" + args.name)
    np.save("n_cut_losses_" + args.name, n_cut_losses_avg)
    np.save("rec_losses_" + args.name, rec_losses_avg)
    print("Done")