Example #1
File: utils.py  Project: zlannnn/DG-Net
def get_data_loader_list(root,
                         file_list,
                         batch_size,
                         train,
                         new_size=None,
                         height=256,
                         width=128,
                         num_workers=4,
                         crop=True):
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ]
    # Each optional transform is prepended, so with every option enabled the composed
    # order is: RandomHorizontalFlip -> Resize -> Pad -> RandomCrop -> ToTensor -> Normalize
    transform_list = (
        [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
    )
    transform_list = (
        [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list
    )
    transform_list = (
        [transforms.Resize((height, width), interpolation=3)] + transform_list  # 3 == bicubic
        if new_size is not None
        else transform_list
    )
    transform_list = (
        [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
    )
    transform = transforms.Compose(transform_list)
    dataset = ImageFilelist(root, file_list, transform=transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=True,
                        num_workers=num_workers)
    return loader
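None of the snippets on this page show their imports. A preamble along these lines should make them runnable; the module providing ImageFilelist is a guess, since each project defines its own dataset class.
# Assumed imports for the examples below (not shown in the original snippets)
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from data import ImageFilelist  # project-specific dataset class; module path is a guess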
Example #2
def get_data_loader_list(root,
                         file_list,
                         batch_size,
                         train,
                         new_size=None,
                         height=256,
                         width=256,
                         num_workers=4,
                         crop=True):
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    transform_list = (
        [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
    )
    transform_list = (
        [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list
    )
    transform_list = (
        [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
    )
    transform = transforms.Compose(transform_list)
    dataset = ImageFilelist(root, file_list, transform=transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=train,
                        drop_last=True,
                        num_workers=num_workers)
    return loader
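A minimal call for this variant, with placeholder paths, and the transform order that results when train, crop, and new_size are all active (each step is prepended to the list):
# Hypothetical usage; the root and list paths are placeholders
train_loader = get_data_loader_list(root='./datasets/example',
                                    file_list='./datasets/example/train.txt',
                                    batch_size=8,
                                    train=True,
                                    new_size=286)
# Composed order: RandomHorizontalFlip -> Resize(286) -> RandomCrop((256, 256))
#                 -> ToTensor -> Normalize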
Example #3
def get_data_loader_list(root,
                         file_list,
                         batch_size,
                         train,
                         new_size=None,
                         height=256,
                         width=256,
                         num_workers=4,
                         crop=True,
                         horizontal_flip=True,
                         shuffle=True,
                         channels=1):
    transform = my_transforms(height=height,
                              width=width,
                              new_size=new_size,
                              crop=crop,
                              horizontal_flip=horizontal_flip,
                              channels=channels)
    dataset = ImageFilelist(root, file_list, transform=transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        drop_last=True,
                        num_workers=num_workers)
    return loader
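my_transforms is not defined in this example. A sketch consistent with its parameter names, purely an assumption about what the helper might do:
def my_transforms(height, width, new_size=None, crop=True,
                  horizontal_flip=True, channels=1):
    # Hypothetical helper mirroring the pipelines in the other examples
    transform_list = [transforms.ToTensor(),
                      transforms.Normalize((0.5,) * channels, (0.5,) * channels)]
    if crop:
        transform_list = [transforms.RandomCrop((height, width))] + transform_list
    if new_size is not None:
        transform_list = [transforms.Resize(new_size)] + transform_list
    if horizontal_flip:
        transform_list = [transforms.RandomHorizontalFlip()] + transform_list
    if channels == 1:
        # Single-channel datasets would need a grayscale conversion first
        transform_list = [transforms.Grayscale(num_output_channels=1)] + transform_list
    return transforms.Compose(transform_list)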
Example #4
def get_dataloader(opt):
    if opt.dataroot is None:
        raise ValueError(
            '`dataroot` parameter is required for dataset \"%s\"' %
            opt.dataset)

    if opt.dataset == 'folderall':
        dataset = ImageFolderAll(root=opt.dataroot,
                                 transform=transforms.Compose([
                                     transforms.Resize(opt.imageSize),
                                     transforms.CenterCrop(opt.imageSize),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5),
                                                          (0.5, 0.5, 0.5)),
                                 ]),
                                 return_paths=not opt.is_train)
        nc = 3
    elif opt.dataset == 'filelist':
        assert opt.datalist != '', 'Please specify `--datalist` if you choose `filelist` dataset mode'
        dataset = ImageFilelist(root=opt.dataroot,
                                flist=opt.datalist,
                                transform=transforms.Compose([
                                    transforms.Resize(opt.imageSize),
                                    transforms.CenterCrop(opt.imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5)),
                                ]),
                                return_paths=not opt.is_train)
        nc = 3
    elif opt.dataset == 'pairfilelist':
        assert opt.datalist != '', 'Please specify `--datalist` if you choose `pairfilelist` dataset mode'
        dataset = ImagePairFilelist(root=opt.dataroot,
                                    flist=opt.datalist,
                                    transform=transforms.Compose([
                                        transforms.Resize(opt.imageSize),
                                        transforms.CenterCrop(opt.imageSize),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5),
                                                             (0.5, 0.5, 0.5)),
                                    ]),
                                    transform_matrix=transforms.Compose([
                                        transforms.ToTensor(),
                                    ]),
                                    return_paths=not opt.is_train)
        nc = 3
    else:
        raise ValueError('Dataset type is not implemented!')

    assert dataset
    assert nc > 0
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             drop_last=opt.is_train,
                                             shuffle=opt.is_train,
                                             num_workers=int(opt.workers))
    return dataloader, nc
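get_dataloader reads its settings from an options object, typically produced by argparse. A stand-in with placeholder values, using only the attributes the snippet accesses:
from argparse import Namespace

# Placeholder options; attribute names follow the snippet, values are arbitrary
opt = Namespace(dataset='filelist',
                dataroot='./data',
                datalist='./data/train_list.txt',
                imageSize=256,
                is_train=True,
                batchSize=16,
                workers=4)
dataloader, nc = get_dataloader(opt)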
Example #5
def get_data_loader_list(
    root,
    file_list,
    batch_size,
    train,
    new_size=None,
    height=256,
    width=256,
    num_workers=4,
    crop=True,
):
    """ List-based data loader with transformations
     (horizontal flip, resizing, random crop, normalization are handled)

    Arguments:
        root {str} -- path root
        file_list {str list} -- list of the file names
        batch_size {int} --
        train {bool} -- training mode

    Keyword Arguments:
        new_size {int} -- parameter for resizing (default: {None})
        height {int} -- dimension for random cropping (default: {256})
        width {int} -- dimension for random cropping (default: {256})
        num_workers {int} -- number of workers (default: {4})
        crop {bool} -- crop(default: {True})

    Returns:
        loader -- data loader with transformed dataset
    """
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    transform_list = (
        [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
    )
    transform_list = (
        [transforms.Resize((new_size, new_size))] + transform_list
        if new_size is not None
        else transform_list
    )
    transform_list = (
        [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
    )
    transform = transforms.Compose(transform_list)
    dataset = ImageFilelist(root, file_list, transform=transform)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=train,
        drop_last=True,
        num_workers=num_workers,
    )
    return loader
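Unlike the variants above that pass new_size directly, this one wraps it in a tuple, which forces a square output. A quick illustration of the difference:
from PIL import Image

img = Image.new('RGB', (512, 256))
# Resize(128) keeps the aspect ratio: the shorter side becomes 128
print(transforms.Resize(128)(img).size)          # (256, 128)
# Resize((128, 128)) forces a square output, as in this example
print(transforms.Resize((128, 128))(img).size)   # (128, 128)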
Example #6
def get_data_loader_list(root,
                         file_list,
                         batch_size,
                         train,
                         new_size=None,
                         height=256,
                         width=256,
                         num_workers=4,
                         crop=True,
                         datakind=''):
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    transform_list = (
        [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
    )
    transform_list = (
        [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list
    )
    transform_list = (
        [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
    )
    if datakind == 'selfie2anime':
        transform_list = [
            transforms.ColorJitter(hue=0.15),
            transforms.RandomGrayscale(p=0.25),
            transforms.RandomRotation(35),
            #                          transforms.RandomTranslation,
            transforms.RandomPerspective(distortion_scale=0.35)
        ] + transform_list
    transform = transforms.Compose(transform_list)
    dataset = ImageFilelist(root, file_list, transform=transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=train,
                        drop_last=True,
                        num_workers=num_workers)
    return loader
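A hypothetical call that turns on the extra selfie2anime augmentations (color jitter, random grayscale, rotation, and perspective warp); the paths are placeholders:
loader = get_data_loader_list(root='./datasets/selfie2anime',
                              file_list='./datasets/selfie2anime/trainA.txt',
                              batch_size=1,
                              train=True,
                              new_size=256,
                              datakind='selfie2anime')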
Example #7
# config (a dict), homePath, and config['batch_size'] are defined earlier in the
# original script and are not shown in this snippet.
config['gpu'] = 1
config['trainlist'] = homePath + '/data/WIDER_v0.1/train.lst'
config['testlist'] = homePath + '/data/WIDER_v0.1/test.lst'
config['data_root'] = homePath + '/data'
config['model_name'] = 'twoLinearResnet101.save'

trans = transforms.Compose(
    [
        transforms.Resize([300, 300]),
        transforms.RandomCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)
dataSet = ImageFilelist(root=config['data_root'], flist=config['trainlist'], transform=trans)
testdataSet = ImageFilelist(root=config['data_root'], flist=config['testlist'], transform=trans)
dataloader = torch.utils.data.DataLoader(dataSet, batch_size=config['batch_size'], shuffle=True)
testdataloader = torch.utils.data.DataLoader(testdataSet, batch_size=config['batch_size'],
                                             shuffle=True, num_workers=8)
cerit = F.cross_entropy
net = TwoLinearModel()
print(net)
net.cuda(config['gpu'])
train_net = nn.ModuleList([net.fc1, net.fc2])
optim = torch.optim.Adam(train_net.parameters(), 0.0001)
optim1 = torch.optim.Adam(nn.ModuleList(list(net.resnet.children())[:-1]).parameters(), 0.000001)

for epoch in range(25):
    # train
    net.train()
    for idx, (img, label) in enumerate(dataloader):
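        # The original snippet is truncated here. A plausible continuation of the
        # loop body (an assumption, not the project's actual code):
        img = img.cuda(config['gpu'])
        label = label.cuda(config['gpu'])
        out = net(img)
        loss = cerit(out, label)
        optim.zero_grad()
        optim1.zero_grad()
        loss.backward()
        optim.step()
        optim1.step()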