Ejemplo n.º 1
0
def get_test_dataloader(data_path, image_size, batch_size, mean, std,
                        num_workers=4):
    """Build the CamVid validation ``DataLoader``.

    Args:
        data_path: root directory passed to ``CamVid``.
        image_size: size passed to ``transforms.Resize``.
        batch_size: number of samples per batch.
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        num_workers: worker processes used by the loader (default 4).

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``'val'`` split that
        iterates in a fixed (unshuffled) order.
    """
    valid_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Normalize with the caller-supplied statistics.
        transforms.Normalize(mean, std),
    ])

    valid_dataset = CamVid(
        data_path,
        'val',
        transforms=valid_transforms,
    )

    # Evaluation does not need random order; a fixed order keeps
    # validation runs reproducible.
    return torch.utils.data.DataLoader(valid_dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       shuffle=False)
Ejemplo n.º 2
0
def get_train_dataloader(data_path, image_size, batch_size, mean, std,
                         num_workers=4):
    """Build the augmented CamVid training ``DataLoader``.

    Args:
        data_path: root directory passed to ``CamVid``.
        image_size: size passed to ``transforms.Resize``.
        batch_size: number of samples per batch.
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        num_workers: worker processes used by the loader (default 4).

    Returns:
        A shuffling ``torch.utils.data.DataLoader`` over the ``'train'``
        split.
    """
    train_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = CamVid(
        data_path,
        'train',
        transforms=train_transforms,
    )

    # Use the batch_size argument rather than a global CLI namespace, so the
    # function works when called from any script.
    return torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       shuffle=True)
Ejemplo n.º 3
0
import time

import torch
import numpy as np

import transforms
from conf import settings
#from dataset.camvid_lmdb import CamVid
from dataset.camvid import CamVid

if __name__ == '__main__':


    train_dataset = CamVid(
        'data',
        image_set='train',
        download=True
    )
    valid_dataset = CamVid(
        'data',
        image_set='val',
        download=True
    )

    train_transforms = transforms.Compose([
            transforms.Resize(settings.IMAGE_SIZE),
            transforms.RandomRotation(15, fill=train_dataset.ignore_index),
            transforms.RandomGaussianBlur(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4),
            transforms.ToTensor(),
Ejemplo n.º 4
0
    # my_trans.RandScale([0.5, 2.0]),
    # my_trans.RandomGaussianBlur(),
    my_trans.RandomHorizontalFlip(),
    # my_trans.Crop([args.height, args.width],crop_type='rand', padding=mean, ignore_label=255),
    my_trans.ToTensor(),  # without div 255
    my_trans.Normalize(mean=mean, std=std)
])
# Validation pipeline: no augmentation, just tensor conversion (values are
# not divided by 255) followed by mean/std normalization.
val_transform = my_trans.Compose([
    my_trans.ToTensor(),
    my_trans.Normalize(mean=mean, std=std),
])

# CamVid root on the training host.
data_dir = '/data/zzg/CamVid/'

train_dataset = CamVid(data_dir, mode='train', p=None, transform=train_transform)
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=4,
    pin_memory=True,
    shuffle=True,
)

valid_dataset = CamVid(
    data_dir,
    mode='val',
    p=None,
    transform=val_transform,
)
valloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=args.batch_size,
    num_workers=4,
    pin_memory=True,
    shuffle=False,
)

# Full batches per epoch (remainder dropped), and a step count spanning
# ten epochs.
epoch_steps = len(train_dataset) // args.batch_size
save_steps = 10 * epoch_steps
Ejemplo n.º 5
0
    root_path = os.path.dirname(os.path.abspath(__file__))

    # Per-run checkpoint and TensorBoard log directories, both keyed by
    # settings.TIME_NOW.
    checkpoint_path = os.path.join(root_path, settings.CHECKPOINT_FOLDER,
                                   settings.TIME_NOW)
    log_dir = os.path.join(root_path, settings.LOG_FOLDER, settings.TIME_NOW)

    # exist_ok=True avoids the check-then-create race of
    # os.path.exists() followed by os.makedirs().
    os.makedirs(checkpoint_path, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # From here on checkpoint_path is a filename template; presumably filled
    # later via .format(epoch=..., type=...).
    checkpoint_path = os.path.join(checkpoint_path, '{epoch}-{type}.pth')

    writer = SummaryWriter(log_dir=log_dir)

    train_dataset = CamVid(settings.DATA_PATH, 'train')
    valid_dataset = CamVid(settings.DATA_PATH, 'val')

    # Training augmentation. Rotation and scaling receive the dataset's
    # ignore_index (presumably as the fill value for padded mask pixels).
    train_transforms = transforms.Compose([
        transforms.RandomRotation(value=train_dataset.ignore_index),
        transforms.RandomScale(value=train_dataset.ignore_index),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.Resize(settings.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(settings.MEAN, settings.STD),
    ])

    valid_transforms = transforms.Compose([
        transforms.Resize(settings.IMAGE_SIZE),
Ejemplo n.º 6
0
def get_dataloader(args):
    """Build the ``(train_loader, test_loader)`` pair for ``args.dataset``.

    Supported dataset names (matched case-insensitively): ``mnist``,
    ``cifar10``, ``cifar100``, ``caltech101``, ``camvid`` and ``nyuv2``.

    Args:
        args: namespace read for ``dataset``, ``data_root`` and
            ``batch_size``. ``caltech101``, ``camvid`` and ``nyuv2`` also
            read ``test_batch_size``; ``caltech101`` also reads ``download``.

    Returns:
        Tuple ``(train_loader, test_loader)`` of
        ``torch.utils.data.DataLoader``. The train loader shuffles; the
        test loader iterates in a fixed order.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset.lower()

    if name == 'mnist':
        # Train and test share one deterministic pipeline (no augmentation).
        mnist_transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_set = datasets.MNIST(args.data_root, train=True, download=True,
                                   transform=mnist_transform)
        test_set = datasets.MNIST(args.data_root, train=False, download=True,
                                  transform=mnist_transform)
        test_batch_size = args.batch_size

    elif name in ('cifar10', 'cifar100'):
        cifar_cls = datasets.CIFAR10 if name == 'cifar10' else datasets.CIFAR100
        # NOTE(review): CIFAR-100 is normalized with the CIFAR-10 statistics
        # below; kept unchanged so existing checkpoints see identical inputs.
        cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                               (0.2023, 0.1994, 0.2010))
        train_set = cifar_cls(args.data_root, train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.RandomCrop(32, padding=4),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  cifar_normalize,
                              ]))
        test_set = cifar_cls(args.data_root, train=False, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 cifar_normalize,
                             ]))
        test_batch_size = args.batch_size

    elif name == 'caltech101':
        train_set = Caltech101(args.data_root, train=True,
                               download=args.download,
                               transform=transforms.Compose([
                                   transforms.Resize(128),
                                   transforms.RandomCrop(128),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,), (0.5,)),
                               ]))
        test_set = Caltech101(args.data_root, train=False,
                              download=args.download,
                              transform=transforms.Compose([
                                  transforms.Resize(128),
                                  transforms.CenterCrop(128),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        test_batch_size = args.test_batch_size

    # ---- Segmentation datasets (ext_transforms, presumably applied jointly
    # to image and mask) ----
    elif name in ('camvid', 'nyuv2'):
        seg_cls = CamVid if name == 'camvid' else NYUv2
        train_set = seg_cls(args.data_root, split='train',
                            transform=ext_transforms.ExtCompose([
                                ext_transforms.ExtResize(256),
                                ext_transforms.ExtRandomCrop(128, pad_if_needed=True),
                                ext_transforms.ExtRandomHorizontalFlip(),
                                ext_transforms.ExtToTensor(),
                                ext_transforms.ExtNormalize((0.5,), (0.5,)),
                            ]))
        test_set = seg_cls(args.data_root, split='test',
                           transform=ext_transforms.ExtCompose([
                               ext_transforms.ExtResize(256),
                               ext_transforms.ExtToTensor(),
                               ext_transforms.ExtNormalize((0.5,), (0.5,)),
                           ]))
        test_batch_size = args.test_batch_size

    else:
        # Fail clearly instead of hitting an unbound local at the return.
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=2)
    # Aggregate evaluation metrics do not depend on sample order, so the
    # test loader is never shuffled.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=test_batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader
Ejemplo n.º 7
0
    # add_argument returns an argparse Action, not the parsed namespace, so
    # its result is not bound to ``args``.
    parser.add_argument('-weight', type=str, required=True,
                        help='weight file path')
    parser.add_argument('-b', type=int, default=10,
                        help='batch size for dataloader')

    args = parser.parse_args()

    # Deterministic evaluation pipeline: resize, tensor conversion and
    # normalization only.
    valid_transforms = transforms.Compose([
        transforms.Resize(settings.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(settings.MEAN, settings.STD)
    ])

    valid_dataset = CamVid(
        settings.DATA_PATH,
        'val',
        valid_transforms
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.b, num_workers=4)

    metrics = Metrics(valid_dataset.class_num, valid_dataset.ignore_index)

    loss_fn = nn.CrossEntropyLoss()

    # Restore the trained weights given by -weight, move the model to the
    # GPU and switch to inference mode.
    net = UNet(3, valid_dataset.class_num)
    net.load_state_dict(torch.load(args.weight))
    net = net.cuda()

    net.eval()
Ejemplo n.º 8
0
                        help='initial learning rate')
    def _str2bool(value):
        """Parse a command-line boolean such as 'true', 'False', '1' or 'no'."""
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        # argparse turns a ValueError from ``type`` into a usage error.
        raise ValueError('not a boolean: %r' % (value,))

    # type=bool would map every non-empty string (including "False") to
    # True, so parse the flag value explicitly.
    parser.add_argument('-stop_div', type=_str2bool, default=True,
                        help='stops when loss diverges')
    parser.add_argument('-num_it', type=int, default=100,
                        help='number of iterations')
    parser.add_argument('-skip_start', type=int, default=10,
                        help='number of batches to trim from the start')
    parser.add_argument('-skip_end', type=int, default=5,
                        help='number of batches to trim from the end')
    parser.add_argument('-weight_decay', type=float,
                        default=0, help='weight decay factor')
    parser.add_argument('-net', type=str, required=True, help='network name')
    args = parser.parse_args()

    train_dataset = CamVid(
        settings.DATA_PATH,
        'train'
    )

    # Training augmentation. Rotation and scaling receive the dataset's
    # ignore_index (presumably as the fill value for padded mask pixels).
    train_transforms = transforms.Compose([
        transforms.RandomRotation(value=train_dataset.ignore_index),
        transforms.RandomScale(value=train_dataset.ignore_index),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.Resize(settings.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(settings.MEAN, settings.STD),
    ])

    # Attach the pipeline after construction, since the rotation/scale
    # transforms need the dataset's ignore_index.
    train_dataset.transforms = train_transforms