def load_dataset(dataset):
    """Creates the data loaders and computes the class weights.

    All settings (paths, image size, batch size, weighing technique, ...)
    are read from the module-level ``args``; the class weights are moved to
    the module-level ``device``.

    Keyword arguments:
    - dataset: Dataset class (or factory). It is called as
    ``dataset(root_dir, mode=..., transform=..., label_transform=...)`` and
    the returned object must expose a ``color_encoding`` mapping.

    Returns:
    - A tuple ``((train_loader, val_loader, test_loader), class_weights,
    class_encoding)``. ``class_weights`` is ``None`` when ``args.weighing``
    is neither ``'enet'`` nor ``'mfb'``.

    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels are resized with nearest-neighbour sampling so that no new
    # (interpolated) pixel values are introduced.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        transform=image_transform,
        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        mode='val',
        transform=image_transform,
        label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        mode='test',
        transform=image_transform,
        label_transform=label_transform)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Remove the road_marking class from the CamVid dataset as it's merged
    # with the road class. Note that this edits the mapping owned by
    # train_set in place.
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The built-in next() is used
    # because the iterator's ``.next()`` method is a Python 2 leftover that
    # current DataLoader iterators do not provide.
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights from the selected weighing technique
    print("\nWeighing technique:", args.weighing)
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    weighing = args.weighing.lower()
    if weighing == 'enet':
        class_weights = enet_weighing(train_loader, num_classes)
    elif weighing == 'mfb':
        class_weights = median_freq_balancing(train_loader, num_classes)
    else:
        class_weights = None

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader, test_loader), class_weights, class_encoding
def load_dataset(dataset):
    """Creates the data loaders and loads or computes the class weights.

    All settings are read from the module-level ``args``; ``color_mean``,
    ``color_std`` and ``device`` are module-level names as well.

    Keyword arguments:
    - dataset: Dataset class (or factory). It is called as
    ``dataset(root_dir, list_file, mode=..., transform=...,
    label_transform=..., color_mean=..., color_std=..., load_depth=...,
    seg_classes=...)`` and the returned object must expose a
    ``color_encoding`` mapping.

    Returns:
    - A tuple ``((train_loader, val_loader, test_loader), class_weights,
    class_encoding)``. ``class_weights`` is ``None`` when no weights file is
    given and ``args.weighing`` is neither ``'enet'`` nor ``'mfb'``.

    """
    # Function-scope import so the block does not depend on a file-level one.
    from PIL import Image

    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print('Train file:', args.trainFile)
    print('Val file:', args.valFile)
    print('Test file:', args.testFile)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels hold class ids, so they are resized with nearest-neighbour
    # sampling; an interpolating filter would blend neighbouring ids into
    # values that are not valid classes.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Arguments shared by the three dataset splits
    split_kwargs = dict(
        transform=image_transform,
        label_transform=label_transform,
        color_mean=color_mean,
        color_std=color_std,
        load_depth=(args.arch == 'rgbd'),
        seg_classes=args.seg_classes)

    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir, args.trainFile, mode='train', **split_kwargs)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir, args.valFile, mode='val', **split_kwargs)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir, args.testFile, mode='test', **split_kwargs)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The built-in next() is used
    # because the iterator's ``.next()`` method is a Python 2 leftover that
    # current DataLoader iterators do not provide.
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights from the selected weighing technique
    print("\nWeighing technique:", args.weighing)

    # If a class weight file is provided, load the weights from it. Any
    # error raised by np.loadtxt propagates to the caller unchanged.
    class_weights = None
    if args.class_weights_file:
        print('Trying to load class weights from file...')
        class_weights = np.loadtxt(args.class_weights_file)

    if class_weights is None:
        print("Computing class weights...")
        print("(this can take a while depending on the dataset size)")
        weighing = args.weighing.lower()
        if weighing == 'enet':
            class_weights = enet_weighing(train_loader, num_classes)
        elif weighing == 'mfb':
            class_weights = median_freq_balancing(train_loader, num_classes)

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader, test_loader), class_weights, class_encoding
('motorcycle', (0, 0, 230)), ('bicycle', (119, 11, 32)) ]) full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) # The values above are remapped to the following new_classes = (0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) image_transform = transforms.Compose( [transforms.Resize((args.height, args.width)), transforms.ToTensor()]) label_to_rgb = transforms.Compose( [ext_transforms.LongTensorToRGBPIL(color_encoding), transforms.ToTensor()]) def predict(model, image, device): image = image.to(device) model = model.to(device) while torch.no_grad(): # Make predictions! prediction, _ = model(image) # Predictions is one-hot encoded with "num_classes" channels. # Convert it to a single int using the indices where the maximum (1) occurs _, predictions = torch.max(prediction.data, 1) return predictions
def load_dataset(dataset):
    """Creates the inference data loader and loads optional class weights.

    All settings are read from the module-level ``args``; ``color_mean``,
    ``color_std`` and ``device`` are module-level names as well.

    Keyword arguments:
    - dataset: Dataset class (or factory). It is called as
    ``dataset(root_dir, list_file, mode='inference', transform=...,
    label_transform=..., color_mean=..., color_std=..., load_depth=...,
    seg_classes=...)`` and the returned object must expose a
    ``color_encoding`` mapping.

    Returns:
    - A tuple ``(test_loader, class_weights, class_encoding)``.
    ``class_weights`` is ``None`` when ``args.class_weights_file`` is not
    set.

    """
    # Function-scope import so the block does not depend on a file-level one.
    from PIL import Image

    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print('Test file:', args.testFile)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels hold class ids, so they are resized with nearest-neighbour
    # sampling; an interpolating filter would blend neighbouring ids into
    # values that are not valid classes.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        args.testFile,
        mode='inference',
        transform=image_transform,
        label_transform=label_transform,
        color_mean=color_mean,
        color_std=color_std,
        load_depth=(args.arch == 'rgbd'),
        seg_classes=args.seg_classes)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = test_set.color_encoding

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Test dataset size:", len(test_set))

    # Get a batch of samples to display. The built-in next() is used
    # because the iterator's ``.next()`` method is a Python 2 leftover that
    # current DataLoader iterators do not provide. The unpacking also
    # checks that a batch has the expected number of fields.
    sample = next(iter(test_loader))
    if args.arch == 'rgbd':
        images, labels, _data_path, _depth_path, _label_path = sample
    else:
        images, labels, _data_path, _label_path = sample
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights
    # If a class weight file is provided, load the weights from it. Any
    # error raised by np.loadtxt propagates to the caller unchanged.
    class_weights = None
    if args.class_weights_file:
        print('Trying to load class weights from file...')
        class_weights = np.loadtxt(args.class_weights_file)
    else:
        print('No class weights found...')

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return test_loader, class_weights, class_encoding
def main():
    """Trains the encoder/decoder segmentation model, validating periodically.

    All settings are read from the module-level ``args`` and tensors are
    placed on the module-level ``device``. The function loads the dataset,
    builds the ``Encoder``/``Decoder`` pair, then runs an iteration-based
    training loop. Every ``args.save_pred_every`` iterations the model is
    validated and a checkpoint is saved when the mean IoU improves.

    Raises:
    - ``AssertionError`` if the dataset or save directory does not exist.
    - ``RuntimeError`` if ``args.dataset`` is not supported.

    """
    assert os.path.isdir(
        args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
            args.dataset_dir)

    # Fail fast if the saving directory doesn't exist
    assert os.path.isdir(
        args.save_dir), "The directory \"{0}\" doesn't exist.".format(
            args.save_dir)

    # Import the requested dataset
    if args.dataset.lower() == 'cityscapes':
        from data import Cityscapes as dataset
    else:
        # Should never happen...but just in case it does
        raise RuntimeError("\"{0}\" is not a supported dataset.".format(
            args.dataset))

    # Function-scope import so the block does not depend on a file-level one.
    from PIL import Image

    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels hold class ids, so they are resized with nearest-neighbour
    # sampling; an interpolating filter would blend neighbouring ids into
    # values that are not valid classes.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        mode='train',
        max_iters=args.max_iters,
        transform=image_transform,
        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Single iterator over the training loader that the iteration-based
    # loop below pulls batches from.
    trainloader_iter = enumerate(train_loader)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        mode='val',
        max_iters=args.max_iters,
        transform=image_transform,
        label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        mode='test',
        max_iters=args.max_iters,
        transform=image_transform,
        label_transform=label_transform)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The built-in next() is used
    # because the iterator's ``.next()`` method is a Python 2 leftover that
    # current DataLoader iterators do not provide.
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    print("\nTraining...\n")

    # Define the model with the encoder and decoder from the deeplabv2
    input_encoder = Encoder().to(device)
    decoder_t = Decoder(num_classes).to(device)

    # Define the entropy loss for the segmentation task
    criterion = CrossEntropy2d()

    # Set the optimizer function for model
    optimizer_g = optim.SGD(
        itertools.chain(input_encoder.parameters(), decoder_t.parameters()),
        lr=args.learning_rate,
        momentum=0.9,
        weight_decay=1e-4)
    optimizer_g.zero_grad()

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        (input_encoder, decoder_t, optimizer_g, start_epoch,
         best_miou) = utils.load_checkpoint(input_encoder, decoder_t,
                                            optimizer_g, args.save_dir,
                                            args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    print()
    metric.reset()
    val = Test(input_encoder, decoder_t, val_loader, criterion, metric,
               device)

    # Checkpoints store ``i_iter + 1`` in the epoch slot, so a resumed run
    # continues from that iteration instead of restarting the iteration
    # count (and the learning-rate schedule driven by it) at zero.
    # NOTE(review): the models are never put back into train mode inside
    # this loop; if Test.run_epoch switches them to eval mode, confirm that
    # this is intended.
    for i_iter in range(start_epoch, args.max_iters):
        optimizer_g.zero_grad()
        adjust_learning_rate(optimizer_g, i_iter)

        _, batch_data = next(trainloader_iter)
        inputs = batch_data[0].to(device)
        labels = batch_data[1].to(device)

        # Forward pass, segmentation loss, backward pass and update
        f_i = input_encoder(inputs)
        outputs_i = decoder_t(f_i)
        loss_g = criterion(outputs_i, labels)
        loss_g.backward()
        optimizer_g.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
                i_iter, args.max_iters, loss_g.item()))
            print(">>>> [iter: {0:d}] Validation".format(i_iter))

            # Validate the current model
            loss, (iou, miou) = val.run_epoch(args.print_step)
            print(">>>> [iter: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(i_iter, loss, miou))

            if miou > best_miou:
                # Print the per-class IoU of the new best model
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

                # Save the model if it's the best thus far
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(input_encoder, decoder_t, optimizer_g,
                                      i_iter + 1, best_miou, args)
def load_dataset(dataset):
    """Creates the data loaders and returns fixed ENet class weights.

    All settings are read from the module-level ``args``.

    Keyword arguments:
    - dataset: Dataset class (or factory). It is called as
    ``dataset(root_dir, mode=..., transform=..., label_transform=...)`` and
    the returned object must expose a ``color_encoding`` mapping.

    Returns:
    - A tuple ``((train_loader, val_loader, test_loader), class_weights,
    class_encoding)``. ``class_weights`` is a tensor built from a
    hard-coded list of 9 values when ``args.weighing`` is ``'enet'`` and
    ``None`` otherwise.

    """
    # Function-scope import so the block does not depend on a file-level one.
    from PIL import Image

    print("\n加载数据...\n")

    print("选择的数据:", args.dataset)
    print("Dataset 目录:", args.dataset_dir)
    print("存储目录:", args.save_dir)

    # Image conversion and normalization
    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Conversion with PILToLongTensor: these are labels, so they must not
    # be normalized. For the same reason they are resized with
    # nearest-neighbour sampling; an interpolating filter would blend
    # neighbouring class ids into values that are not valid classes.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()  # (H x W x C) to (C x H x W)
    ])

    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        transform=image_transform,
        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        mode='val',
        transform=image_transform,
        label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=3,
        shuffle=True,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        mode='test',
        transform=image_transform,
        label_transform=label_transform)
    test_loader = data.DataLoader(
        test_set,
        batch_size=3,
        shuffle=True,
        num_workers=args.workers)

    # Get the encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Get the number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The built-in next() is used
    # because the iterator's ``.next()`` method is a Python 2 leftover that
    # current DataLoader iterators do not provide.
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get the class weights
    print("\nWeighing technique:", args.weighing)
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    if args.weighing.lower() == 'enet':
        # Hard-coded weights used in place of the (disabled) call
        # ``enet_weighing(train_loader, num_classes)``.
        # NOTE(review): the list has 9 entries - confirm that it matches
        # num_classes for the dataset in use.
        class_weights = np.array([
            1.44752114, 33.41317956, 43.89576605, 47.85765692, 48.3393951,
            47.18958997, 40.2809274, 46.61960781, 48.28854284
        ])
    else:
        class_weights = None

    if class_weights is not None:
        class_weights = torch.Tensor(class_weights)
        # Setting the weight of the unlabeled class to 0 is disabled here:
        # if args.ignore_unlabeled:
        #     ignore_index = list(class_encoding).index('unlabeled')
        #     class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader, test_loader), class_weights, class_encoding
def run_epoch(self, epoch, iteration_loss=False): """Runs an epoch of training. Keyword arguments: - iteration_loss (``bool``, optional): Prints loss at every step. Returns: - The epoch loss (float). """ self.model.train() epoch_loss = 0.0 self.metric.reset() for step, batch_data in enumerate(self.data_loader): if step > 0: break # Get the inputs and labels inputs = batch_data[0].to(self.device) labels = batch_data[1].to(self.device) #print('labels!!!!!!!!!!!!!!!!!!!!!!',labels) # Forward propagation outputs = self.model(inputs) _, preds = torch.max(outputs.data, 1) # Loss computation loss = self.criterion(outputs, labels) # Backpropagation self.optim.zero_grad() loss.backward() self.optim.step() # Keep track of loss for current epoch epoch_loss += loss.item() # Keep track of the evaluation metric self.metric.add(outputs.detach(), labels.detach()) if iteration_loss: print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) if epoch == 59 or epoch == 99 or epoch == 159 or epoch == 199: if step % 500 == 0: print('Visualization of Train:') label_to_rgb = transforms.Compose([ ext_transforms.LongTensorToRGBPIL(self.color_encoding), transforms.ToTensor() ]) color_labels = utils.batch_transform( labels.cpu(), label_to_rgb) color_outputs = utils.batch_transform( preds.cpu(), label_to_rgb) utils.imshow_batch(color_outputs, color_labels) if epoch == 80: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_80') if epoch == 120: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_120') if epoch == 140: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_140') if epoch == 160: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_160') if epoch == 180: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_180') if epoch == 199: torch.save(self.model.state_dict(), '/home/wan/PyTorch-ENet/model_200') return epoch_loss / len(self.data_loader), self.metric.value()
def train(train_loader, val_loader, class_weights, class_encoding):
    """Trains ENet, validating and logging to TensorBoard after every epoch.

    Hyper-parameters are read from the module-level ``args``; scalars and
    images are logged through the module-level ``writer``.

    Keyword arguments:
    - train_loader: Data loader for the training set.
    - val_loader: Data loader for the validation set.
    - class_weights: Per-class weights passed to ``nn.CrossEntropyLoss``,
    or ``None`` for an unweighted loss.
    - class_encoding: Ordered mapping from class name to colour. Its length
    is the number of classes and its keys label the per-class IoU output.

    Returns:
    - The trained model.

    """
    print("\nTraining...\n")

    num_classes = len(class_encoding)

    # Initialize ENet
    model = ENet(num_classes).to(device)
    # Check if the network architecture is correct
    print(model)

    # We are going to use the CrossEntropyLoss loss function as it's most
    # frequently used in classification problems with multiple classes which
    # fits the problem. This criterion combines LogSoftMax and NLLLoss.
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # ENet authors used Adam as the optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay)

    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    print()
    # Named ``trainer`` so that it does not shadow this function's own name
    trainer = Train(model, train_loader, optimizer, criterion, metric, device)
    val = Test(model, val_loader, criterion, metric, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, (iou, miou) = trainer.run_epoch(args.print_step)

        # The scheduler is stepped after the epoch's optimizer updates;
        # stepping it first skips the first value of the schedule.
        lr_updater.step()

        # Visualization by TensorBoardX
        writer.add_scalar('data/train/loss', epoch_loss, epoch)
        writer.add_scalar('data/train/mean_IoU', miou, epoch)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        # The modulus is the validation period in epochs; with 1 the
        # condition is always true and validation runs after every epoch.
        if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val.run_epoch(args.print_step)

            # Visualization by TensorBoardX
            writer.add_scalar('data/val/loss', loss, epoch)
            writer.add_scalar('data/val/mean_IoU', miou, epoch)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      args)

            # Visualization of the first validation batch in TensorBoard.
            # The helper receives the model and the raw batch, so no
            # prediction is computed here.
            for batch in val_loader:
                inputs = batch[0].to(device)
                labels = batch[1].to(device)
                in_training_visualization(model, inputs, labels,
                                          class_encoding, writer, epoch,
                                          'val')
                break  # only the first batch is visualized

    return model