コード例 #1
0
def load_dataset_al_train(dataset, args, batchsize):
    """Build the shuffled loader for the active-learning ('train_al') split.

    Args:
        dataset: dataset class, called as
            ``dataset(root, mode=..., transform=..., label_transform=...)``.
        args: parsed options; reads ``height_train``, ``width_train``,
            ``dataset_dir``, ``workers`` and ``dataset``.
        batchsize: batch size of the returned loader.

    Returns:
        ``(loader, class_encoding)``. ``class_encoding`` is a copy of the
        dataset's ``color_encoding``; for CamVid the ``'road_marking'``
        entry is removed.
    """
    target_hw = (args.height_train, args.width_train)

    # Labels use nearest-neighbour resizing so class ids are never blended.
    label_transform = transforms.Compose([
        transforms.Resize(target_hw, Image.NEAREST),
        ext_transforms.PILToLongTensor(),
    ])
    image_transform = transforms.Compose([
        transforms.Resize(target_hw),
        transforms.ToTensor(),
    ])

    al_split = dataset(args.dataset_dir,
                       mode='train_al',
                       transform=image_transform,
                       label_transform=label_transform)
    loader = data.DataLoader(al_split,
                             batch_size=batchsize,
                             shuffle=True,
                             num_workers=args.workers)

    # Work on a copy so deleting a class does not touch the dataset's mapping.
    encoding = al_split.color_encoding.copy()
    if args.dataset.lower() == 'camvid':
        # CamVid's road_marking class is merged with road.
        del encoding['road_marking']

    print("Size of data in the dataloader", len(al_split))

    return loader, encoding
コード例 #2
0
def predict():
    """Segment one Cityscapes image with a saved ENet model and show an overlay.

    Uses the module-level ``target_size``, ``data_dir`` and ``device``. The
    checkpoint ``ENet_cityscapes_mine.pth`` is loaded from ``save/``. The
    prediction is blended over the input and displayed with ``cv2.imshow``
    until a key is pressed.
    """
    image_transform = transforms.Compose(
        [transforms.Resize(target_size),
         transforms.ToTensor()])

    # Labels must be resized with nearest-neighbour; interpolation would
    # invent class ids at region boundaries.
    label_transform = transforms.Compose(
        [transforms.Resize(target_size, Image.NEAREST),
         ext_transforms.PILToLongTensor()])

    # The test split is only instantiated to obtain the class/colour encoding.
    test_set = Cityscapes(data_dir,
                          mode='test',
                          transform=image_transform,
                          label_transform=label_transform)
    class_encoding = test_set.color_encoding
    num_classes = len(class_encoding)
    model = ENet(num_classes).to(device)

    # utils.load_checkpoint restores into an optimizer, so one is created even
    # though only the model is used here.
    optimizer = optim.Adam(model.parameters())
    model = utils.load_checkpoint(model, optimizer, 'save',
                                  'ENet_cityscapes_mine.pth')[0]
    # Switch to inference mode (affects layers such as dropout/batch-norm).
    model.eval()

    # convert('RGB') guards against RGBA/greyscale PNGs, which would feed the
    # network the wrong number of channels.
    image = Image.open(
        'images/mainz_000000_008001_leftImg8bit.png').convert('RGB')
    inputs = image_transform(image).unsqueeze(0).to(device)
    image = np.array(image)

    # Make predictions without building an autograd graph.
    with torch.no_grad():
        logits = model(inputs)
    _, predictions = torch.max(logits, 1)
    # Shift ids down by one; the original author notes the result is 0~18.
    prediction = predictions.cpu().numpy()[0] - 1

    mask_color = np.asarray(label_to_color_image(prediction, 'cityscapes'),
                            dtype=np.uint8)
    # INTER_NEAREST keeps palette colours exact when upsampling the mask.
    mask_color = cv2.resize(mask_color, (image.shape[1], image.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
    print(image.shape)
    print(mask_color.shape)
    res = cv2.addWeighted(image, 0.3, mask_color, 0.7, 0.6)
    # PIL produces RGB while cv2.imshow expects BGR.
    # NOTE(review): assumes label_to_color_image also returns RGB — confirm.
    cv2.imshow('combined', cv2.cvtColor(res, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
コード例 #3
0
ファイル: main.py プロジェクト: snoofalus/PyTorch-ENet
def load_dataset(dataset):
    """Build the train/val/test loaders and the class weights.

    Reads the module-level ``args`` and ``device``.

    Args:
        dataset: dataset class, instantiated once per split.

    Returns:
        ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding)``. ``class_weights`` is a float tensor on ``device``,
        or ``None`` when ``args.weighing`` is neither ``'enet'`` nor ``'mfb'``.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour resizing so label ids are never interpolated.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width),
                          transforms.InterpolationMode.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        transform=image_transform,
        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        mode='val',
        transform=image_transform,
        label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        mode='test',
        transform=image_transform,
        label_transform=label_transform)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Encoding between label pixel values and RGB colours. Copy it so that
    # deleting a class below does not mutate the dataset's own mapping,
    # which may be shared (e.g. a class attribute — confirm in the dataset).
    class_encoding = train_set.color_encoding.copy()

    # Remove the road_marking class from the CamVid dataset as it's merged
    # with the road class
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    num_classes = len(class_encoding)

    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Fetch one batch to report tensor sizes. The next() builtin is used
    # because DataLoader iterators no longer provide a .next() method.
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights from the selected weighing technique
    print("\nWeighing technique:", args.weighing)
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    weighing = args.weighing.lower()
    if weighing == 'enet':
        class_weights = enet_weighing(train_loader, num_classes)
    elif weighing == 'mfb':
        class_weights = median_freq_balancing(train_loader, num_classes)
    else:
        class_weights = None

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding
コード例 #4
0
def load_dataset_activelearning(dataset, args):
    """Build the (unshuffled) train and validation loaders for active learning.

    Args:
        dataset: dataset class, instantiated once per split. Its samples are
            unpacked as 4-tuples whose first two items are image and label.
        args: parsed options (backbone, dataset, dataset_dir, save_dir,
            height, width, batch_size_AL, workers).

    Returns:
        ``((train_loader, val_loader), class_encoding)``. ``class_encoding``
        is a copy of the dataset's ``color_encoding``; for CamVid the
        ``'road_marking'`` entry is removed.
    """
    print("\nLoading dataset...")

    print("Selected Backbone:", args.backbone)
    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour resizing keeps label ids intact.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    train_set = dataset(args.dataset_dir,
                        transform=image_transform,
                        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size_AL,
        shuffle=False,
        pin_memory=True,
        num_workers=args.workers)

    val_set = dataset(args.dataset_dir,
                      mode='val',
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size_AL,
        shuffle=False,
        pin_memory=True,
        num_workers=args.workers)

    # Copy so deleting a class does not mutate the dataset's own mapping.
    class_encoding = train_set.color_encoding.copy()

    # Remove the road_marking class from the CamVid dataset as it's merged
    # with the road class
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    # One step per batch: len(train_loader) accounts for batch_size_AL,
    # whereas the previous len(train_set) // 1 ignored it.
    print("Required steps for each epoch: {}".format(len(train_loader)))
    print("Validation dataset size:", len(val_set))

    # Fetch one batch to report tensor sizes (next() builtin, since DataLoader
    # iterators no longer provide a .next() method).
    images, labels, _, _ = next(iter(train_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("\n")
    print("Class-color encoding:", class_encoding)

    return (train_loader, val_loader), class_encoding
コード例 #5
0
ファイル: main.py プロジェクト: iamstg/vegans
def load_dataset(dataset):
    """Build train/val/test loaders and the class weights.

    Reads the module-level ``args``, ``color_mean``, ``color_std`` and
    ``device``. Class weights are loaded from ``args.class_weights_file``
    when it is set. Otherwise they are computed with ``args.weighing``:
    ``'enet'`` or ``'mfb'`` compute weights, and any other value yields
    ``None``.

    Args:
        dataset: dataset class, called as
            ``dataset(root, list_file, mode=..., transform=..., ...)``.

    Returns:
        ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding)``.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print('Train file:', args.trainFile)
    print('Val file:', args.valFile)
    print('Test file:', args.testFile)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour so resized labels contain only valid class ids
    # (the default bilinear filter blends ids at region boundaries).
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Keyword arguments shared by every split.
    split_kwargs = dict(transform=image_transform,
                        label_transform=label_transform,
                        color_mean=color_mean,
                        color_std=color_std)

    train_set = dataset(args.dataset_dir, args.trainFile, mode='train',
                        **split_kwargs)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers)

    val_set = dataset(args.dataset_dir, args.valFile, mode='val',
                      **split_kwargs)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    test_set = dataset(args.dataset_dir, args.testFile, mode='inference',
                       **split_kwargs)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Encoding between label pixel values and RGB colours
    class_encoding = train_set.color_encoding
    num_classes = len(class_encoding)

    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Fetch one batch to report tensor sizes (next() builtin, since DataLoader
    # iterators no longer provide a .next() method).
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    print("Weighing technique:", args.weighing)
    class_weights = None
    if args.class_weights_file:
        print('Trying to load class weights from file...')
        # Load errors propagate to the caller unchanged.
        class_weights = np.loadtxt(args.class_weights_file)
    if class_weights is None:
        print("Computing class weights...")
        print("(this can take a while depending on the dataset size)")
        weighing = args.weighing.lower()
        if weighing == 'enet':
            class_weights = enet_weighing(train_loader, num_classes)
        elif weighing == 'mfb':
            class_weights = median_freq_balancing(train_loader, num_classes)

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        print("Ignoring unlabeled class: ", args.ignore_unlabeled)
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding
コード例 #6
0
ファイル: inference.py プロジェクト: krrish94/ENet-ScanNet
def load_dataset(dataset):
    """Build the inference test loader and optional class weights.

    Reads the module-level ``args``, ``color_mean``, ``color_std`` and
    ``device``. Class weights are only available when
    ``args.class_weights_file`` is set; otherwise ``None`` is returned.

    Args:
        dataset: dataset class, called as
            ``dataset(root, list_file, mode='inference', ...)``.

    Returns:
        ``(test_loader, class_weights, class_encoding)``.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print('Test file:', args.testFile)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour so resized labels contain only valid class ids.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the test set as tensors; depth is loaded only for the RGB-D arch.
    test_set = dataset(args.dataset_dir, args.testFile, mode='inference',
                       transform=image_transform,
                       label_transform=label_transform,
                       color_mean=color_mean, color_std=color_std,
                       load_depth=(args.arch == 'rgbd'),
                       seg_classes=args.seg_classes)
    test_loader = data.DataLoader(test_set, batch_size=args.batch_size,
                                  shuffle=False, num_workers=args.workers)

    # Encoding between label pixel values and RGB colours
    class_encoding = test_set.color_encoding
    num_classes = len(class_encoding)

    print("Number of classes to predict:", num_classes)
    print("Test dataset size:", len(test_set))

    # Fetch one batch to report tensor sizes. RGB-D batches carry one extra
    # path element; the strict unpacking checks the batch layout.
    if args.arch == 'rgbd':
        images, labels, _data_path, _depth_path, _label_path = next(
            iter(test_loader))
    else:
        images, labels, _data_path, _label_path = next(iter(test_loader))

    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    class_weights = None
    if args.class_weights_file:
        print('Trying to load class weights from file...')
        # Load errors propagate to the caller unchanged.
        class_weights = np.loadtxt(args.class_weights_file)
    else:
        print('No class weights found...')

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return test_loader, class_weights, class_encoding
コード例 #7
0
ファイル: main.py プロジェクト: zack-yu666/PyTorch-deeplabv2
def main():
    """Train the encoder/decoder segmentation model on Cityscapes.

    Reads the module-level ``args`` and ``device``. The function:

    - builds the train/val/test loaders;
    - optionally resumes from a checkpoint;
    - runs SGD iterations up to ``args.max_iters``;
    - validates every ``args.save_pred_every`` iterations;
    - saves a checkpoint whenever the validation mean IoU improves.
    """
    assert os.path.isdir(
        args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
            args.dataset_dir)

    # Fail fast if the saving directory doesn't exist
    assert os.path.isdir(
        args.save_dir), "The directory \"{0}\" doesn't exist.".format(
            args.save_dir)

    # Import the requested dataset
    if args.dataset.lower() == 'cityscapes':
        from data import Cityscapes as dataset
    else:
        # Should never happen...but just in case it does
        raise RuntimeError("\"{0}\" is not a supported dataset.".format(
            args.dataset))
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour so resized labels contain only valid class ids.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    train_set = dataset(args.dataset_dir,
                        mode='train',
                        max_iters=args.max_iters,
                        transform=image_transform,
                        label_transform=label_transform)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers)

    val_set = dataset(args.dataset_dir,
                      mode='val',
                      max_iters=args.max_iters,
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    test_set = dataset(args.dataset_dir,
                       mode='test',
                       max_iters=args.max_iters,
                       transform=image_transform,
                       label_transform=label_transform)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Encoding between label pixel values and RGB colours
    class_encoding = train_set.color_encoding
    num_classes = len(class_encoding)

    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Fetch one batch to report tensor sizes (next() builtin, since DataLoader
    # iterators no longer provide a .next() method).
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    print("\nTraining...\n")

    # Define the model with the encoder and decoder from the deeplabv2
    input_encoder = Encoder().to(device)
    decoder_t = Decoder(num_classes).to(device)

    # Define the entropy loss for the segmentation task
    criterion = CrossEntropy2d()

    # One SGD optimizer over both encoder and decoder parameters
    optimizer_g = optim.SGD(itertools.chain(input_encoder.parameters(),
                                            decoder_t.parameters()),
                            lr=args.learning_rate,
                            momentum=0.9,
                            weight_decay=1e-4)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        input_encoder, decoder_t, optimizer_g, start_epoch, best_miou = utils.load_checkpoint(
            input_encoder, decoder_t, optimizer_g, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    print()

    metric.reset()

    val = Test(input_encoder, decoder_t, val_loader, criterion, metric, device)

    train_iter = iter(train_loader)
    # Continue from the saved iteration (save_checkpoint stores i_iter + 1),
    # so a resumed run does not restart the learning-rate schedule.
    for i_iter in range(start_epoch, args.max_iters):
        # Validation may leave the models in eval mode (TODO confirm in
        # Test.run_epoch); every training step must run in train mode.
        input_encoder.train()
        decoder_t.train()

        optimizer_g.zero_grad()
        adjust_learning_rate(optimizer_g, i_iter)

        try:
            batch_data = next(train_iter)
        except StopIteration:
            # The loader may yield fewer than max_iters batches (e.g. with
            # batch_size > 1); start a new pass instead of crashing.
            train_iter = iter(train_loader)
            batch_data = next(train_iter)
        inputs = batch_data[0].to(device)
        labels = batch_data[1].to(device)

        outputs_i = decoder_t(input_encoder(inputs))
        loss_g = criterion(outputs_i, labels)
        loss_g.backward()
        optimizer_g.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
                i_iter, args.max_iters, loss_g.item()))
            print(">>>> [iter: {0:d}] Validation".format(i_iter))

            # Validate first; the checkpoint is saved below only on improvement
            loss, (iou, miou) = val.run_epoch(args.print_step)

            print(">>>> [iter: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(i_iter, loss, miou))

            # Report per-class IoU and save the model if it's the best so far
            if miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(input_encoder, decoder_t, optimizer_g,
                                      i_iter + 1, best_miou, args)
コード例 #8
0
def load_dataset(dataset):
    """Build train/val/test loaders and the class weights.

    Reads the module-level ``args``. With ``args.weighing == 'enet'`` a fixed,
    precomputed 9-entry weight vector is used. Any other value yields ``None``.

    Args:
        dataset: dataset class, instantiated once per split.

    Returns:
        ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding)``. ``class_weights`` is a CPU float tensor or ``None``.
    """
    print("\n加载数据...\n")

    print("选择的数据:", args.dataset)
    print("Dataset 目录:", args.dataset_dir)
    print("存储目录:", args.save_dir)

    # Resize and convert images to tensors
    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels are converted with PILToLongTensor and must not be normalised.
    # Nearest-neighbour resizing keeps every pixel a valid class id.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()  # (H x W x C) -> (C x H x W)
    ])

    # Load the training set as tensors
    train_set = dataset(args.dataset_dir,
                        transform=image_transform,
                        label_transform=label_transform)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(args.dataset_dir,
                      mode='val',
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(val_set,
                                 batch_size=3,
                                 shuffle=True,
                                 num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(args.dataset_dir,
                       mode='test',
                       transform=image_transform,
                       label_transform=label_transform)
    test_loader = data.DataLoader(test_set,
                                  batch_size=3,
                                  shuffle=True,
                                  num_workers=args.workers)

    # Encoding between label pixel values and RGB colours
    class_encoding = train_set.color_encoding

    # Number of classes to predict
    num_classes = len(class_encoding)

    # Debugging information
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Fetch one batch to report tensor sizes (next() builtin, since DataLoader
    # iterators no longer provide a .next() method).
    sample_loader = test_loader if args.mode.lower() == 'test' else train_loader
    images, labels = next(iter(sample_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Class weights
    print("\nWeighing technique:", args.weighing)
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    if args.weighing.lower() == 'enet':
        # Hard-coded weights used in place of
        # enet_weighing(train_loader, num_classes).
        class_weights = np.array([
            1.44752114, 33.41317956, 43.89576605, 47.85765692, 48.3393951,
            47.18958997, 40.2809274, 46.61960781, 48.28854284
        ])
    else:
        class_weights = None

    if class_weights is not None:
        class_weights = torch.Tensor(class_weights)
        # Zeroing the weight of the unlabeled class is intentionally disabled:
        # if args.ignore_unlabeled:
        #     ignore_index = list(class_encoding).index('unlabeled')
        #     class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding
def get_data_loaders(dataset,
                     train_batch_size,
                     test_batch_size,
                     val_batch_size,
                     single_sample=False):
    """Build unshuffled train/val/test loaders and median-frequency weights.

    Reads the module-level ``args`` and ``device``.

    Args:
        dataset: dataset class, instantiated once per split.
        train_batch_size: batch size of the training loader.
        test_batch_size: batch size of the test loader.
        val_batch_size: batch size of the validation loader.
        single_sample: if True, restrict the training set to its first
            ``train_batch_size`` samples (at most the whole dataset).

    Returns:
        ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding, (train_set, val_set, test_set))``.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Nearest-neighbour resizing keeps label ids intact.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), Image.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Load the training set as tensors
    train_set = dataset(args.dataset_dir,
                        transform=image_transform,
                        label_transform=label_transform)

    train_source = train_set
    if single_sample:
        print("Reducing Training set to single batch of size {}".format(
            train_batch_size))
        # Clamp so a batch size larger than the dataset cannot produce
        # out-of-range Subset indices.
        n_keep = min(train_batch_size, len(train_set))
        train_source = torch.utils.data.Subset(train_set, list(range(n_keep)))

    # shuffle=False is deliberate (the original author marked it as changed).
    train_loader = data.DataLoader(
        train_source,
        batch_size=train_batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(args.dataset_dir,
                      mode='val',
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(val_set,
                                 batch_size=val_batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(args.dataset_dir,
                       mode='test',
                       transform=image_transform,
                       label_transform=label_transform)
    test_loader = data.DataLoader(test_set,
                                  batch_size=test_batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Encoding between label pixel values and RGB colours. Copy it so that
    # deleting a class below does not mutate the dataset's own mapping.
    class_encoding = train_set.color_encoding.copy()

    # Remove the road_marking class from the CamVid dataset as it's merged
    # with the road class
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    num_classes = len(class_encoding)

    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Test dataset size:", len(test_set))
    print("Validation dataset size:", len(val_set))

    # Class weights are always computed with median-frequency balancing
    # over the (possibly reduced) training loader.
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    class_weights = median_freq_balancing(train_loader, num_classes)

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding, (train_set, val_set,
                                                          test_set)
コード例 #10
0
    # checkpoint
    optimizer = optim.Adam(model.parameters())

    # Load the previoulsy saved model state to the ENet model
    model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                  args.name)[0]
    # print(model)
    #inference(model, test_loader, w_class, class_encoding)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width)),
        ext_transforms.PILToLongTensor()
    ])

    if args.arch.lower() == 'rgb':
        image = scannet_loader(args.data_path)
        image = image.unsqueeze(dim=0)
        # print(image.size())
    elif args.arch.lower() == 'rgbd':
        image = scannet_loader_depth(args.data_path, args.depth_path)
        image = image.unsqueeze(dim=0)
    else:
        # This condition will not occur (argparse will fail if an invalid option is specified)
        raise RuntimeError(
            'Invalid network architecture for dataloader specified.')
    print(class_encoding)
    predict(model, image, class_encoding)