def main():
    """Train the WNet autoencoder with reconstruction + soft N-cut losses."""
    print("PyTorch Version: ", torch.__version__)
    if torch.cuda.is_available():
        print("Cuda is available. Using GPU")
    config = Config()

    ###################################
    # Image loading and preprocessing #
    ###################################
    # TODO: Maybe we should crop a large square, then resize that down to our patch size?
    # For now, data augmentation must not introduce any missing pixels
    # TODO: Add data augmentation noise
    train_xform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.Resize(128),
        # For now, cropping down to 224
        transforms.RandomCrop(config.input_size + config.variationalTranslation),
        transforms.RandomHorizontalFlip(),  # TODO: Add colorjitter, random erasing
        transforms.ToTensor(),
    ])
    val_xform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.Resize(128),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
    ])
    # TODO: Load validation segmentation maps too (for evaluation purposes)

    train_dataset = AutoencoderDataset("train", train_xform)
    val_dataset = AutoencoderDataset("val", val_xform)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, num_workers=4, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=4, num_workers=4, shuffle=False)
    util.clear_progress_dir()

    ###################################
    # Model Setup                     #
    ###################################
    autoencoder = WNet()
    if torch.cuda.is_available():
        autoencoder = autoencoder.cuda()
    optimizer = torch.optim.Adam(autoencoder.parameters())
    if config.debug:
        print(autoencoder)
    util.enumerate_params([autoencoder])

    # Use the current time to save the model at end of each epoch
    modelName = str(datetime.now())

    ###################################
    # Loss Criterion                  #
    ###################################
    def reconstruction_loss(x, x_prime):
        # Summed binary cross-entropy between input and its reconstruction
        binary_cross_entropy = F.binary_cross_entropy(x_prime, x, reduction='sum')
        return binary_cross_entropy

    ###################################
    # Training Loop                   #
    ###################################
    autoencoder.train()
    # One fixed validation batch reused each epoch for progress snapshots
    progress_images, progress_expected = next(iter(val_dataloader))

    for epoch in range(config.num_epochs):
        running_loss = 0.0
        for i, [inputs, outputs] in enumerate(train_dataloader, 0):
            if config.showdata:
                print(inputs.shape)
                print(outputs.shape)
                print(inputs[0])
                plt.imshow(inputs[0].permute(1, 2, 0))
                plt.show()

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                outputs = outputs.cuda()

            optimizer.zero_grad()
            segmentations, reconstructions = autoencoder(inputs)

            l_soft_n_cut = soft_n_cut_loss(inputs, segmentations)
            # When variational translation is on, reconstruct the (shifted)
            # target crop instead of the input crop itself.
            l_reconstruction = reconstruction_loss(
                inputs if config.variationalTranslation == 0 else outputs,
                reconstructions)

            loss = (l_reconstruction + l_soft_n_cut)
            # We only need to do retain graph =true if we're backpropping from multiple heads
            loss.backward(retain_graph=False)
            optimizer.step()

            if config.debug and (i % 50) == 0:
                print(i)

            # print statistics
            running_loss += loss.item()

            if config.showSegmentationProgress and i == 0:  # If first batch in epoch
                util.save_progress_image(autoencoder, progress_images, epoch)
                optimizer.zero_grad()  # Don't change gradient on validation

        epoch_loss = running_loss / len(train_dataloader.dataset)
        print(f"Epoch {epoch} loss: {epoch_loss:.6f}")

        if config.saveModel:
            util.save_model(autoencoder, modelName)
# Inference script: load a trained WNet checkpoint and run it over the test set.
from configure import Config
from model import WNet
from Ncuts import NCutsLoss
from DataLoader import DataLoader
import time
import os
import torch  # fix: torch was used below (torch.optim / torch.load) but never imported
import torchvision
import pdb
from PIL import Image

config = Config()

if __name__ == '__main__':
    dataset = DataLoader(config.datapath, "test")
    dataloader = dataset.torch_loader()
    model = WNet()
    model.cuda(config.cuda_dev)
    model.eval()  # fix: inference mode (matches the sibling test script)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.init_lr)  #optimizer (unused at test time)
    # Load the checkpoint directly onto the GPU
    with open(config.model_tested, 'rb') as f:
        para = torch.load(f, map_location="cuda:0")
    # fix: removed leftover pdb.set_trace() debugging breakpoint
    # strict=False: tolerate extra/missing keys in the checkpoint
    model.load_state_dict(para['state_dict'], strict=False)
    for step, [x] in enumerate(dataloader):
        print('Step' + str(step + 1))
        #NCuts Loss
        x = x.cuda(config.cuda_dev)
        pred, rec_image = model(x)
        # Class-index map rescaled into [0, 255] (the /3 presumably assumes
        # 4 segmentation classes — TODO confirm against config)
        seg = (pred.argmax(dim=1).to(torch.float) / 3 * 255).cpu().detach().numpy()
        rec_image = rec_image.cpu().detach().numpy() * 255
# Inference script (BSDS variant): load a trained WNet checkpoint and
# compute segmentations over the test set.
from configure import Config  # fix: Config was used below but never imported
from model import WNet
from Ncuts import NCutsLoss
from DataLoader import DataLoader
import time
import os
import numpy as np  # fix: np was used below but never imported
import torch  # fix: torch was used below but never imported
import torchvision
import pdb
from PIL import Image

config = Config()
# Restrict which GPUs are visible before any CUDA context is created
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_dev_list

if __name__ == '__main__':
    dataset = DataLoader(config.bsds, "test")
    dataloader = dataset.torch_loader()
    model = WNet()
    model.cuda()
    model.eval()  # inference mode
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.init_lr)  #optimizer (unused at test time)
    # Load the checkpoint directly onto the GPU
    with open(config.model_tested, 'rb') as f:
        para = torch.load(f, map_location="cuda:0")
    model.load_state_dict(para['state_dict'])
    for step, [x] in enumerate(dataloader):
        print('Step' + str(step + 1))
        #NCuts Loss
        x = x.cuda()
        pred, pad_pred = model(x)
        # Per-pixel class indices of the predicted segmentation
        seg = (pred.argmax(dim=1)).cpu().detach().numpy()
        # Recover uint8 HWC images from the normalized NCHW tensor
        x = x.cpu().detach().numpy() * 255
        x = np.transpose(x.astype(np.uint8), (0, 2, 3, 1))