コード例 #1
0
ファイル: main.py プロジェクト: snoofalus/PyTorch-ENet
def load_dataset(dataset):
    """Create the train/val/test data loaders and the class weights.

    Configuration is read from the module-level ``args`` namespace (dataset
    directory, image height/width, batch size, worker count, run mode,
    weighing technique, ...) and the weight tensor is moved to the
    module-level ``device``.

    Args:
        dataset: Callable used to build each split. It is invoked as
            ``dataset(args.dataset_dir, transform=..., label_transform=...)``
            for the training split and with an additional ``mode='val'`` /
            ``mode='test'`` keyword for the other two. The returned object
            must expose ``color_encoding`` and support ``len()``.

    Returns:
        A 3-tuple ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding)``. ``class_weights`` is a float tensor on ``device``,
        or ``None`` when ``args.weighing`` is neither ``'enet'`` nor
        ``'mfb'``.

    Raises:
        KeyError: If ``args.dataset`` is CamVid and the encoding has no
            ``'road_marking'`` entry.
        ValueError: If ``args.ignore_unlabeled`` is set, weights were
            computed, and the encoding has no ``'unlabeled'`` entry.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # Labels are resized with nearest-neighbour interpolation so that no new
    # (blended) pixel values are introduced into the label image.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), transforms.InterpolationMode.NEAREST),
        ext_transforms.PILToLongTensor()
    ])

    # Get selected dataset
    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        transform=image_transform,
        label_transform=label_transform)
    train_loader = data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        mode='val',
        transform=image_transform,
        label_transform=label_transform)
    val_loader = data.DataLoader(
        val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        mode='test',
        transform=image_transform,
        label_transform=label_transform)
    test_loader = data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Remove the road_marking class from the CamVid dataset as it's merged
    # with the road class. This deletes the entry from the very object that
    # train_set.color_encoding refers to, not from a copy.
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The builtin next() is used because
    # Python 3 iterators do not define a .next() method.
    if args.mode.lower() == 'test':
        images, labels = next(iter(test_loader))
    else:
        images, labels = next(iter(train_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights from the selected weighing technique
    print("\nWeighing technique:", args.weighing)
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    weighing = args.weighing.lower()
    if weighing == 'enet':
        class_weights = enet_weighing(train_loader, num_classes)
    elif weighing == 'mfb':
        class_weights = median_freq_balancing(train_loader, num_classes)
    else:
        # Any other value means "no class weighting"
        class_weights = None

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        if args.ignore_unlabeled:
            # Position of the 'unlabeled' key in the encoding's key order
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding
コード例 #2
0
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   pin_memory=pin_memory)
    valid_dataset = CamVid(root_dir=root_dir,
                           mode='val',
                           transform=image_transform,
                           label_transform=label_transform)
    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   pin_memory=pin_memory)
    class_encoding = train_dataset.color_encoding
    del class_encoding['road_marking']
    num_classes = len(class_encoding)
    class_weights = median_freq_balancing(train_loader, num_classes)
    class_weights = torch.from_numpy(class_weights).float().to(device)

    print('--- Building model ---')
    model = UNet(in_channels=3, out_channels=num_classes,
                 feature_scale=1).to(device)
    model_name = 'UNet'
    # criterion = nn.CrossEntropyLoss().to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=lr_decay_epoches,
                                    gamma=lr_decay)

    train_obj = Train(model=model,
                      data_loader=train_loader,
コード例 #3
0
ファイル: main.py プロジェクト: iamstg/vegans
def load_dataset(dataset):
    """Create the train/val/test data loaders and the class weights.

    Configuration is read from the module-level ``args`` namespace (dataset
    directory, split files, image height/width, batch size, worker count,
    run mode, weighing technique, optional class-weights file, ...). The
    module-level ``color_mean`` / ``color_std`` are forwarded to the dataset
    and the weight tensor is moved to the module-level ``device``.

    Args:
        dataset: Callable used to build each split. It is invoked as
            ``dataset(args.dataset_dir, <split file>, mode=..., transform=...,
            label_transform=..., color_mean=..., color_std=...)`` with modes
            ``'train'``, ``'val'`` and ``'inference'``. The returned object
            must expose ``color_encoding`` and support ``len()``.

    Returns:
        A 3-tuple ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding)``. ``class_weights`` is a float tensor on ``device``,
        or ``None`` when no weights file is given and ``args.weighing`` is
        neither ``'enet'`` nor ``'mfb'``.

    Raises:
        ValueError: If ``args.ignore_unlabeled`` is set, weights are
            available, and the encoding has no ``'unlabeled'`` entry.
        Exception: Whatever ``np.loadtxt`` raises when
            ``args.class_weights_file`` cannot be read is propagated.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", args.dataset)
    print("Dataset directory:", args.dataset_dir)
    print('Train file:', args.trainFile)
    print('Val file:', args.valFile)
    print('Test file:', args.testFile)
    print("Save directory:", args.save_dir)

    image_transform = transforms.Compose(
        [transforms.Resize((args.height, args.width)),
         transforms.ToTensor()])

    # NOTE(review): the label Resize uses the default interpolation mode;
    # if the label images store class indices, nearest-neighbour resizing
    # may be what is wanted here -- confirm against the dataset class.
    label_transform = transforms.Compose([
        transforms.Resize((args.height, args.width)),
        ext_transforms.PILToLongTensor()
    ])

    # Get selected dataset
    # Load the training set as tensors
    train_set = dataset(
        args.dataset_dir,
        args.trainFile,
        mode='train',
        transform=image_transform,
        label_transform=label_transform,
        color_mean=color_mean,
        color_std=color_std)
    train_loader = data.DataLoader(train_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers)

    # Load the validation set as tensors
    val_set = dataset(
        args.dataset_dir,
        args.valFile,
        mode='val',
        transform=image_transform,
        label_transform=label_transform,
        color_mean=color_mean,
        color_std=color_std)
    val_loader = data.DataLoader(val_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Load the test set as tensors
    test_set = dataset(
        args.dataset_dir,
        args.testFile,
        mode='inference',
        transform=image_transform,
        label_transform=label_transform,
        color_mean=color_mean,
        color_std=color_std)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Get encoding between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # Get number of classes to predict
    num_classes = len(class_encoding)

    # Print information for debugging
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Validation dataset size:", len(val_set))

    # Get a batch of samples to display. The builtin next() is used because
    # Python 3 iterators do not define a .next() method.
    if args.mode.lower() == 'test':
        images, labels = next(iter(test_loader))
    else:
        images, labels = next(iter(train_loader))
    print("Image size:", images.size())
    print("Label size:", labels.size())
    print("Class-color encoding:", class_encoding)

    # Show a batch of samples and labels
    if args.imshow_batch:
        print("Close the figure window to continue...")
        label_to_rgb = transforms.Compose([
            ext_transforms.LongTensorToRGBPIL(class_encoding),
            transforms.ToTensor()
        ])
        color_labels = utils.batch_transform(labels, label_to_rgb)
        utils.imshow_batch(images, color_labels)

    # Get class weights from the selected weighing technique
    print("Weighing technique:", args.weighing)
    # If a class weight file is provided, load the weights from it; a read
    # failure propagates to the caller unchanged.
    class_weights = None
    if args.class_weights_file:
        print('Trying to load class weights from file...')
        class_weights = np.loadtxt(args.class_weights_file)
    if class_weights is None:
        print("Computing class weights...")
        print("(this can take a while depending on the dataset size)")
        weighing = args.weighing.lower()
        if weighing == 'enet':
            class_weights = enet_weighing(train_loader, num_classes)
        elif weighing == 'mfb':
            class_weights = median_freq_balancing(train_loader, num_classes)
        # Any other value leaves class_weights as None (no weighting)

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        # Set the weight of the unlabeled class to 0
        print("Ignoring unlabeled class: ", args.ignore_unlabeled)
        if args.ignore_unlabeled:
            # Position of the 'unlabeled' key in the encoding's key order
            ignore_index = list(class_encoding).index('unlabeled')
            class_weights[ignore_index] = 0

    print("Class weights:", class_weights)

    return (train_loader, val_loader,
            test_loader), class_weights, class_encoding
def get_data_loaders(dataset,
                     train_batch_size,
                     test_batch_size,
                     val_batch_size,
                     single_sample=False):
    """Build unshuffled train/val/test loaders plus median-frequency weights.

    Image size, dataset directory, worker count and the ``ignore_unlabeled``
    switch come from the module-level ``args``; the weight tensor is moved to
    the module-level ``device``.

    Args:
        dataset: Callable invoked as ``dataset(args.dataset_dir,
            transform=..., label_transform=...)`` for the training split and
            with an extra ``mode='val'`` / ``mode='test'`` keyword for the
            other two splits.
        train_batch_size: Batch size of the training loader. With
            ``single_sample`` it is also the number of leading training
            samples that are kept.
        test_batch_size: Batch size of the test loader.
        val_batch_size: Batch size of the validation loader.
        single_sample: When true, the training loader iterates only over the
            first ``train_batch_size`` samples of the training set.

    Returns:
        A 4-tuple ``((train_loader, val_loader, test_loader), class_weights,
        class_encoding, (train_set, val_set, test_set))``. The returned
        ``train_set`` is always the full training set, even when
        ``single_sample`` is true.
    """
    print("\nLoading dataset...\n")

    print("Selected dataset:", dataset)
    print("Dataset directory:", args.dataset_dir)
    print("Save directory:", args.save_dir)

    target_size = (args.height, args.width)
    image_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])
    label_transform = transforms.Compose([
        transforms.Resize(target_size, Image.NEAREST),
        ext_transforms.PILToLongTensor(),
    ])

    # --- Training split -------------------------------------------------
    train_set = dataset(args.dataset_dir,
                        transform=image_transform,
                        label_transform=label_transform)

    # Pick what the training loader iterates over: the whole set, or only
    # its first `train_batch_size` samples.
    train_source = train_set
    if single_sample:
        print("Reducing Training set to single batch of size {}".format(
            train_batch_size))
        train_source = torch.utils.data.Subset(
            train_set, list(range(train_batch_size)))

    # The training loader is deliberately not shuffled in either case.
    train_loader = data.DataLoader(train_source,
                                   batch_size=train_batch_size,
                                   shuffle=False,
                                   num_workers=args.workers)

    # --- Validation split -----------------------------------------------
    val_set = dataset(args.dataset_dir,
                      mode='val',
                      transform=image_transform,
                      label_transform=label_transform)
    val_loader = data.DataLoader(val_set,
                                 batch_size=val_batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)

    # --- Test split -----------------------------------------------------
    test_set = dataset(args.dataset_dir,
                       mode='test',
                       transform=image_transform,
                       label_transform=label_transform)
    test_loader = data.DataLoader(test_set,
                                  batch_size=test_batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Mapping between pixel values in label images and RGB colors
    class_encoding = train_set.color_encoding

    # CamVid merges road_marking into road, so drop the separate entry
    if args.dataset.lower() == 'camvid':
        del class_encoding['road_marking']

    num_classes = len(class_encoding)

    # Debug output
    print("Number of classes to predict:", num_classes)
    print("Train dataset size:", len(train_set))
    print("Test dataset size:", len(test_set))
    print("Validation dataset size:", len(val_set))

    # Class weights always come from median frequency balancing here
    print("Computing class weights...")
    print("(this can take a while depending on the dataset size)")
    class_weights = median_freq_balancing(train_loader, num_classes)

    if class_weights is not None:
        class_weights = torch.from_numpy(class_weights).float().to(device)
        if args.ignore_unlabeled:
            # Zero the weight at the position of the 'unlabeled' key
            unlabeled_pos = list(class_encoding).index('unlabeled')
            class_weights[unlabeled_pos] = 0

    print("Class weights:", class_weights)

    loaders = (train_loader, val_loader, test_loader)
    datasets = (train_set, val_set, test_set)
    return loaders, class_weights, class_encoding, datasets