def main():
    """Train the capsule network on Cityscapes.

    Parses the command line into the module-level ``args``, builds the
    train/val/test datasets and data loaders, constructs ``capsNet.CapsNet``
    (wrapped with a reconstruction decoder when ``args.with_reconstruction``
    is set), and calls ``train`` on the training loader once per epoch in
    ``range(args.start_epoch, args.epochs)``.
    """
    global args
    args = parser.parse_args()
    print(args)

    # Create the save directory if needed. exist_ok avoids the race between a
    # separate os.path.exists() check and os.makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    splits = ['train', 'val', 'test']

    # Every split uses the same preprocessing: nearest-neighbour resize to an
    # args.imageSize x args.imageSize square, then conversion to a tensor.
    shared_transform = transforms.Compose([
        transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])
    data_transforms = {x: shared_transform for x in splits}

    # Data loading.
    # NOTE(review): these are machine-specific absolute paths; consider making
    # them command-line arguments.
    data_dir = '/media/salman/DATA/General Datasets/cityscapes'
    # JSON file with the class definitions.
    json_path = '/home/salman/pytorch/capsNet/dataset/cityscapesClasses.json'

    image_datasets = {
        x: cityscapesDataset(data_dir, x, data_transforms[x], json_path)
        for x in splits
    }
    # NOTE(review): only the 'train' loader is consumed below; 'val' and
    # 'test' are built but not yet used by this function.
    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=args.batchSize,
                                       shuffle=True,
                                       num_workers=args.workers)
        for x in splits
    }

    # Initialize the network.
    model = capsNet.CapsNet(args.routing_iterations)
    if args.with_reconstruction:
        reconstruction_model = capsNet.ReconstructionNet(20, 20)
        # NOTE(review): no reconstruction-loss weight is handed to train()
        # (see the call below); if the reconstruction term needs scaling
        # (e.g. by 0.0005), pass the weight through explicitly.
        model = capsNet.CapsNetWithReconstruction(model, reconstruction_model)

    if use_gpu:
        model.cuda()

    print(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Dictionary of class id / RGB value pairs for the dataset.
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)

    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch.
        train(dataloaders['train'], model, optimizer, epoch, key)
def main():
    """Train UNet on the MICCAI segmentation dataset, evaluating after every epoch.

    Parses the command line into the module-level ``args``, builds the
    train/test datasets and data loaders, optionally resumes model and
    optimizer state from ``args.resume``, then for each epoch in
    ``range(args.start_epoch, args.epochs)``:

    * runs ``train`` on the training loader,
    * runs ``validate`` on the test loader,
    * prints IoU / precision / recall / F1 from ``utils.Evaluate``,
    * writes ``checkpoint_<epoch>.tar`` into ``args.save_dir``.
    """
    global args
    args = parser.parse_args()
    print(args)

    # --saveTest arrives as a string; convert the two recognised spellings to bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Create the save directory if needed (exist_ok avoids a check-then-create race).
    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    # Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182]) is currently
    # disabled for both splits.
    data_transforms = {
        # TenCrop yields ten crops per image; torch.stack turns them into one
        # tensor, so each training sample carries an extra leading crop dimension.
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ]),
    }

    # Data loading.
    data_dir = 'datasets/miccaiSegRefined'
    # JSON file with the class definitions.
    json_path = 'datasets/miccaiSegClasses.json'

    splits = ['train', 'test']
    image_datasets = {
        x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x], json_path)
        for x in splits
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=args.batchSize,
                                       shuffle=True,
                                       num_workers=args.workers)
        for x in splits
    }

    # Dictionary of class id / RGB value pairs for the dataset.
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    model = UNet(num_classes)

    # Move the model to the GPU *before* building the optimizer (as the
    # torch.optim docs recommend) so that optimizer state restored below is
    # placed on the same device as the parameters.
    if use_gpu:
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Optionally resume from a checkpoint written by save_checkpoint() below.
    # The truthiness guard keeps an unset --resume from reaching os.path.isfile.
    if args.resume and os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # Load onto the CPU so GPU-saved checkpoints also work on CPU-only
        # hosts; load_state_dict copies the tensors onto the model's device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # Restore the Adam state saved alongside the weights so the moment
        # estimates carry over. This also restores the saved param-group
        # hyperparameters (lr, weight_decay), overriding --lr / --wd.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion).
    criterion = nn.CrossEntropyLoss()
    if use_gpu:
        criterion.cuda()

    # Halve the learning rate every 10 scheduler steps.
    # NOTE(review): the scheduler is not checkpointed, so its step count
    # restarts from zero on resume.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Accumulates predictions during validate() and reports the metrics.
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch.
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on the test set.
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics.
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        # Clear the accumulated statistics before the next epoch.
        evaluator.reset()

        # 'epoch' stores the next epoch to run, so a resume continues after this one.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
def main():
    """Evaluate a trained SegNet checkpoint on the MICCAI organ test split.

    Parses the command line into the module-level ``args``, loads weights from
    ``args.model``, runs ``validate`` once over the test split, and prints
    IoU / precision / recall / F1 from ``utils.Evaluate``.

    Raises:
        FileNotFoundError: if ``args.model`` does not name an existing file
            (evaluating randomly initialised weights would only print
            meaningless metrics).
    """
    global args
    args = parser.parse_args()
    print(args)

    # --saveTest arrives as a string; convert the two recognised spellings to bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Create the save directory if needed (exist_ok avoids a check-then-create race).
    os.makedirs(args.save_dir, exist_ok=True)

    cudnn.benchmark = True

    data_transform = transforms.Compose([
        transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
        transforms.ToTensor(),
    ])

    # Data loading.
    # NOTE(review): machine-specific absolute paths; consider CLI arguments.
    data_dir = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrgans'
    # JSON file with the class definitions.
    json_path = '/home/salman/pytorch/segmentationNetworks/datasets/miccaiSegOrganClasses.json'

    image_dataset = miccaiSegDataset(os.path.join(data_dir, 'test'), data_transform, json_path)
    # Evaluation iterates the test set in a fixed order so that repeated runs
    # see the samples identically.
    dataloader = torch.utils.data.DataLoader(image_dataset,
                                             batch_size=args.batchSize,
                                             shuffle=False,
                                             num_workers=args.workers)

    # Dictionary of class id / RGB value pairs for the dataset.
    classes = image_dataset.classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    model = SegNet(args.bnMomentum, num_classes)

    # Load the saved model; refuse to evaluate untrained weights.
    if not args.model or not os.path.isfile(args.model):
        raise FileNotFoundError("=> no checkpoint found at '{}'".format(args.model))

    print("=> loading checkpoint '{}'".format(args.model))
    # Load onto the CPU so GPU-saved checkpoints also work on CPU-only hosts;
    # load_state_dict copies the tensors into the model, which is moved below.
    checkpoint = torch.load(args.model, map_location='cpu')
    # Record the checkpoint's epoch on args (not otherwise used in this function).
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

    print(model)

    # Define loss function (criterion).
    criterion = nn.CrossEntropyLoss()

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Accumulates predictions during validate() and reports the metrics.
    evaluator = utils.Evaluate(key, use_gpu)

    # Evaluate on the test set.
    print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
    validate(dataloader, model, criterion, key, evaluator)

    # Calculate the metrics.
    print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
    IoU = evaluator.getIoU()
    print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
    PRF1 = evaluator.getPRF1()
    precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
    print('Mean Precision: {}, Class-wise Precision: {}'.format(
        torch.mean(precision), precision))
    print('Mean Recall: {}, Class-wise Recall: {}'.format(
        torch.mean(recall), recall))
    print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))