def main():
    """Train an e-Lab LinkNet segmentation model (SGD + NLL loss).

    Relies on module-level globals defined elsewhere in this file:
    ``args`` (parsed CLI options), color constants ``CP_*``, the ``Train``,
    ``Test`` and ``save_model`` helpers, ``transforms``, ``DataLoader``,
    ``nn``, ``call`` and ``os``/``torch`` imports.
    """
    print(CP_R + "e-Lab Segmentation Training Script" + CP_C)

    #################################################################
    # Initialization step
    torch.manual_seed(args.seed)
    torch.set_default_tensor_type('torch.FloatTensor')

    #################################################################
    # Acquire dataset loader object
    # Normalization factor based on ResNet stats
    prep_data = transforms.Compose([
        #transforms.RandomCrop(900),
        transforms.Resize(args.img_size, 0),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    prep_target = transforms.Compose([
        #transforms.RandomCrop(900),
        transforms.Resize(args.img_size, 0),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    if args.dataset == 'cs':
        import data.segmented_data as segmented_data
        print("{}Cityscapes dataset in use{}!!!".format(CP_G, CP_C))
    else:
        print("{}Invalid data-loader{}".format(CP_R, CP_C))
        # Fail fast: without a dataset module the loader construction below
        # would die with an opaque NameError on `segmented_data`.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    # Training data loader
    data_obj_train = segmented_data.SegmentedData(
        root=args.datapath, mode='train',
        transform=prep_data, target_transform=prep_target)
    data_loader_train = DataLoader(
        data_obj_train, batch_size=args.bs, shuffle=True,
        num_workers=args.workers)
    data_len_train = len(data_obj_train)

    # Testing data loader
    data_obj_test = segmented_data.SegmentedData(
        root=args.datapath, mode='test',
        transform=prep_data, target_transform=prep_target)
    data_loader_test = DataLoader(
        data_obj_test, batch_size=args.bs, shuffle=True,
        num_workers=args.workers)
    data_len_test = len(data_obj_test)

    #################################################################
    # Load model
    print('{}{:=<80}{}'.format(CP_R, '', CP_C))
    print('{}Models will be saved in: {}{}'.format(CP_Y, CP_C, str(args.save)))
    if not os.path.exists(str(args.save)):
        os.mkdir(str(args.save))
    if args.saveAll:
        if not os.path.exists(str(args.save) + '/all'):
            os.mkdir(str(args.save) + '/all')

    epoch = 0
    if args.resume:
        # Load previous model state
        checkpoint = torch.load(args.save + '/model_resume.pt')
        epoch = checkpoint['epoch']
        model = checkpoint['model_def']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum, weight_decay=args.wd)
        # BUG FIX: was `optimizer.load_stat_dict` (no such method ->
        # AttributeError on every resume).
        optimizer.load_state_dict(checkpoint['optim_state'])
        # BUG FIX: format string had two placeholders for three arguments,
        # so the epoch number was silently dropped from the message.
        print('{}Loaded model from previous checkpoint epoch # {}({})'.format(
            CP_G, CP_C, epoch))
    else:
        # Load fresh model definition
        if args.model == 'linknet':
            # Save model definiton script
            call(["cp", "./models/linknet.py", args.save])

            from models.linknet import LinkNet
            model = LinkNet(len(data_obj_train.class_name()))

        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum, weight_decay=args.wd)

    # Criterion
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    # BUG FIX: nn.NLLLoss2d was deprecated and later removed; nn.NLLLoss
    # accepts 2-D (N, C, H, W) inputs and is the drop-in replacement.
    criterion = nn.NLLLoss()

    # Save arguements used for training
    args_log = open(args.save + '/args.log', 'w')
    args_log.write(str(args))
    args_log.close()

    error_log = list()
    prev_error = 1000

    train = Train(model, data_loader_train, optimizer, criterion,
                  args.lr, args.wd, args.visdom)
    test = Test(model, data_loader_test, criterion, args.visdom)

    while epoch <= args.maxepoch:
        train_error = train.forward()
        test_error = test.forward()
        print('{}{:-<80}{}'.format(CP_R, '', CP_C))
        print('{}Epoch #: {}{:03}'.format(CP_B, CP_C, epoch))
        print('{}Training Error: {}{:.6f} | {}Testing Error: {}{:.6f}'.format(
            CP_B, CP_C, train_error, CP_B, CP_C, test_error))

        error_log.append((train_error, test_error))

        # Save weights and model definition
        # BUG FIX: the checkpoint previously stored the class name `LinkNet`,
        # which is unbound on the --resume path (NameError); store the live
        # model object instead, matching what the resume branch expects.
        prev_error = save_model({
            'epoch': epoch,
            'model_def': model,
            'state_dict': model.state_dict(),
            'optim_state': optimizer.state_dict(),
        }, test_error, prev_error, args.save, args.saveAll)

        # Rewrite the full error log every epoch so progress survives an
        # interrupted run.
        logger = open(args.save + '/error.log', 'w')
        logger.write('{:10} {:10}'.format('Train Error', 'Test Error'))
        logger.write('\n{:-<20}'.format(''))
        for total_error in error_log:
            logger.write('\n{:.6f} {:.6f}'.format(total_error[0], total_error[1]))
        logger.close()

        # BUG FIX: `epoch` was never incremented -> infinite training loop.
        epoch += 1
def main():
    """Train an e-Lab LinkNet segmentation model (refined pipeline).

    Differences from the earlier revision in this file: Adam optimizer,
    ResNet-18-initialized encoder, class-frequency-weighted NLL loss, and
    mIoU-based checkpointing via a ``ConfusionMatrix`` metric.

    Relies on module-level globals defined elsewhere in this file:
    ``args``, color constants ``CP_*``, ``Train``, ``Test``, ``save_model``,
    ``ConfusionMatrix``, ``transforms``, ``DataLoader``, ``nn``, ``cudnn``,
    ``call``, ``np``, ``os`` and ``torch``.
    """
    print(CP_R + "e-Lab Segmentation Training Script" + CP_C)
    #################################################################
    # Initialization step
    torch.manual_seed(args.seed)
    cudnn.benchmark = True
    torch.set_default_tensor_type('torch.FloatTensor')

    #################################################################
    # Acquire dataset loader object
    # Normalization factor based on ResNet stats
    prep_data = transforms.Compose([
        #transforms.Crop((512, 512)),
        transforms.Resize((1024, 2048)),
        transforms.ToTensor(),
        # BUG FIX: mean/std were wrapped in ONE list, passing a single
        # positional argument to Normalize (std missing -> TypeError).
        # NOTE(review): channel order is reversed w.r.t. the usual ResNet
        # stats ([0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]) — confirm
        # whether the loader really yields BGR images.
        transforms.Normalize([0.406, 0.456, 0.485], [0.225, 0.224, 0.229])
    ])

    prep_target = transforms.Compose([
        #transforms.Crop((512, 512)),
        transforms.Resize((512, 1024)),
        # NOTE(review): `basic=True` is not a torchvision parameter; assumes
        # a project-local transforms module — verify the import at file top.
        transforms.ToTensor(basic=True),
    ])

    if args.dataset == 'cs':
        import data.segmented_data as segmented_data
        print("{}Cityscapes dataset in use{}!!!".format(CP_G, CP_C))
    else:
        print("{}Invalid data-loader{}".format(CP_R, CP_C))
        # Fail fast instead of hitting a NameError on `segmented_data` below.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    # Training data loader
    data_obj_train = segmented_data.SegmentedData(
        root=args.datapath, mode='train',
        transform=prep_data, target_transform=prep_target)
    data_loader_train = DataLoader(
        data_obj_train, batch_size=args.bs, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    data_len_train = len(data_obj_train)

    # Validation data loader (serves as the "test" split here)
    data_obj_test = segmented_data.SegmentedData(
        root=args.datapath, mode='val',
        transform=prep_data, target_transform=prep_target)
    data_loader_test = DataLoader(
        data_obj_test, batch_size=args.bs, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    data_len_test = len(data_obj_test)

    class_names = data_obj_train.class_name()
    n_classes = len(class_names)

    #################################################################
    # Load model
    epoch = 0
    prev_iou = 0.0001

    # Load fresh model definition
    print('{}{:=<80}{}'.format(CP_R, '', CP_C))
    print('{}Models will be saved in: {}{}'.format(CP_Y, CP_C, str(args.save)))
    if not os.path.exists(str(args.save)):
        os.mkdir(str(args.save))
    if args.saveAll:
        if not os.path.exists(str(args.save) + '/all'):
            os.mkdir(str(args.save) + '/all')

    if args.model == 'linknet':
        # Save model definiton script
        call(["cp", "./models/linknet.py", args.save])

        from models.linknet import LinkNet
        from torchvision.models import resnet18
        model = LinkNet(n_classes)

        # Copy weights of resnet18 into encoder.
        # NOTE(review): only `.weight` of leaf modules is copied (Linear
        # skipped); biases and batch-norm running stats keep their defaults
        # — confirm this partial init is intentional.
        pretrained_model = resnet18(pretrained=True)
        for i, j in zip(model.modules(), pretrained_model.modules()):
            if not list(i.children()):
                if not isinstance(i, nn.Linear) and len(i.state_dict()) > 0:
                    i.weight.data = j.weight.data

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.lr)#,
            #momentum=args.momentum, weight_decay=args.wd)

    if args.resume:
        # Load previous model state
        checkpoint = torch.load(args.save + '/model_resume.pth')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state'])
        prev_iou = checkpoint['min_error']
        print('{}Loaded model from previous checkpoint epoch # {}({})'.format(
            CP_G, CP_C, epoch))

    # Criterion
    print("Model initialized for training...")

    # Cache the per-class pixel histogram so the full training-set pass is
    # only paid once.
    hist_path = os.path.join(args.save, 'hist')
    if os.path.isfile(hist_path + '.npy'):
        hist = np.load(hist_path + '.npy')
        print('{}Loaded cached dataset stats{}!!!'.format(CP_Y, CP_C))
    else:
        # Get class weights based on training data
        # BUG FIX: `np.float` was removed in NumPy 1.24; use builtin float.
        hist = np.zeros(n_classes, dtype=float)
        for batch_idx, (x, yt) in enumerate(data_loader_train):
            h, bins = np.histogram(yt.numpy(), list(range(n_classes + 1)))
            hist += h
        hist = hist / hist.max()   # Normalize histogram

        print('{}Saving dataset stats{}...'.format(CP_Y, CP_C))
        np.save(hist_path, hist)

    # Inverse-log frequency weighting; rarer classes get larger weight.
    criterion_weight = 1 / np.log(1.02 + hist)
    # Class 0 weight zeroed — presumably the unlabeled/void class; confirm
    # against the dataset's class list.
    criterion_weight[0] = 0
    # `Variable` wrapper dropped: deprecated no-op since PyTorch 0.4.
    criterion = nn.NLLLoss(torch.from_numpy(criterion_weight).float().cuda())
    print('{}Using weighted criterion{}!!!'.format(CP_Y, CP_C))
    #criterion = cross_entropy2d

    # Save arguements used for training
    args_log = open(args.save + '/args.log', 'w')
    for k in args.__dict__:
        args_log.write(k + ' : ' + str(args.__dict__[k]) + '\n')
    args_log.close()

    # Setup Metrics
    metrics = ConfusionMatrix(n_classes, class_names)

    train = Train(model, data_loader_train, optimizer, criterion,
                  args.lr, args.wd, args.bs, args.visdom)
    test = Test(model, data_loader_test, criterion, metrics,
                args.bs, args.visdom)

    # Save error values in log file
    logger = open(args.save + '/error.log', 'w')
    logger.write('{:10} {:10}'.format('Train Error', 'Test Error'))
    logger.write('\n{:-<20}'.format(''))

    while epoch <= args.maxepoch:
        print('{}{:-<80}{}'.format(CP_R, '', CP_C))
        print('{}Epoch #: {}{:03}'.format(CP_B, CP_C, epoch))
        train_error = train.forward()
        test_error, accuracy, avg_accuracy, iou, miou, conf_mat = test.forward()
        logger.write('\n{:.6f} {:.6f} {:.6f}'.format(
            train_error, test_error, miou))
        print('{}Training Error: {}{:.6f} | {}Testing Error: {}{:.6f} |{}Mean IoU: {}{:.6f}'.format(
            CP_B, CP_C, train_error, CP_B, CP_C, test_error, CP_G, CP_C, miou))

        # Save weights and model definition ('min_error' deliberately holds
        # the value from BEFORE this epoch's update, matching resume logic).
        prev_iou = save_model({
            'epoch': epoch,
            'model_def': model,
            'state_dict': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'min_error': prev_iou
        }, class_names, conf_mat, miou, prev_iou, avg_accuracy, iou,
            args.save, args.saveAll)
        epoch += 1

    logger.close()