Example #1
0
def portrait_multiCVdataloader(cached_data_file,
                               data_dir,
                               classes,
                               batch_size,
                               scaleIn,
                               w=180,
                               h=320,
                               edge=False,
                               num_work=4,
                               Enc=True,
                               Augset=True):
    """Build three multi-scale training DataLoaders plus a validation DataLoader.

    The dataset statistics (mean/std) and image/annotation lists are either
    computed from ``data_dir`` and cached to ``cached_data_file``, or loaded
    back from that pickle cache when it already exists.

    Args:
        cached_data_file: Path of the pickled dataset cache.
        data_dir: Root directory of the dataset.
        classes: Number of segmentation classes, forwarded to ``ld.LoadData``.
        batch_size: Base batch size; the 224x224 and 0.8-scale loaders use
            1.5x and 1.8x of it respectively (smaller images -> larger batches).
        scaleIn: Label down-scaling factor for ``cvTransforms.ToTensor``.
        w, h: Base crop width/height for the first training pipeline.
        edge: Whether the training datasets should also produce edge maps
            (validation always uses ``edge=True``).
        num_work: Worker processes per DataLoader.
        Enc: Encoder-stage flag forwarded to ``myDataLoader.CVDataset``.
        Augset: When True, mix in the additional Baidu portrait sets.

    Returns:
        (trainLoader, trainLoader2, trainLoader3, valLoader, data) where
        ``data`` is the dict produced by ``ld.LoadData`` / unpickled cache.
    """
    if not os.path.isfile(cached_data_file):
        if Augset:
            # Extra public portrait sets blended in when augmentation is on.
            additional_data = ['/Nukki/baidu_V1/', '/Nukki/baidu_V2/']
            dataLoad = ld.LoadData(data_dir,
                                   classes,
                                   cached_data_file,
                                   additional=additional_data)
            data = dataLoad.processDataAug()
        else:
            dataLoad = ld.LoadData(data_dir, classes, cached_data_file)
            data = dataLoad.processData()

        if data is None:
            print('Error while pickling data. Please check.')
            # Equivalent to exit(-1) but does not depend on the `site` builtin.
            raise SystemExit(-1)
    else:
        # Fix: the original `pickle.load(open(...))` never closed the file;
        # use a context manager so the handle is released deterministically.
        with open(cached_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)

    # Pipeline 1: native (w, h) resolution with a 32-px random crop-resize.
    trainDataset_main = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(w, h),
        cvTransforms.RandomCropResize(32),
        cvTransforms.RandomFlip(),
        cvTransforms.ToTensor(scaleIn),
    ])

    # Pipeline 2: fixed 224x224 (matches the validation resolution).
    trainDataset_main2 = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(224, 224),
        cvTransforms.RandomCropResize(16),
        cvTransforms.RandomFlip(),
        cvTransforms.ToTensor(scaleIn),
    ])

    # Pipeline 3: 80% of the native resolution.
    trainDataset_main3 = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(int(w * 0.8), int(h * 0.8)),
        cvTransforms.RandomCropResize(24),
        cvTransforms.RandomFlip(),
        cvTransforms.ToTensor(scaleIn),
    ])
    valDataset = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(224, 224),
        cvTransforms.ToTensor(scaleIn),
    ])

    print("This stage is Enc" + str(Enc))
    print(" Load public Baidu train dataset")
    trainLoader = torch.utils.data.DataLoader(myDataLoader.CVDataset(
        data['trainIm'],
        data['trainAnnot'],
        transform=trainDataset_main,
        edge=edge,
        Enc=Enc),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_work,
                                              pin_memory=True)

    # Smaller images allow proportionally larger batches.
    trainLoader2 = torch.utils.data.DataLoader(
        myDataLoader.CVDataset(data['trainIm'],
                               data['trainAnnot'],
                               transform=trainDataset_main2,
                               edge=edge,
                               Enc=Enc),
        batch_size=int(1.5 * batch_size),
        shuffle=True,
        num_workers=num_work,
        pin_memory=True)

    trainLoader3 = torch.utils.data.DataLoader(
        myDataLoader.CVDataset(data['trainIm'],
                               data['trainAnnot'],
                               transform=trainDataset_main3,
                               edge=edge,
                               Enc=Enc),
        batch_size=int(1.8 * batch_size),
        shuffle=True,
        num_workers=num_work,
        pin_memory=True)

    print(" Load public val dataset")

    # Validation always requests edge maps, regardless of the `edge` argument.
    valLoader = torch.utils.data.DataLoader(myDataLoader.CVDataset(
        data['valIm'],
        data['valAnnot'],
        transform=valDataset,
        edge=True,
        Enc=Enc),
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=num_work,
                                            pin_memory=True)

    return trainLoader, trainLoader2, trainLoader3, valLoader, data
Example #2
0
def portrait_CVdataloader(cached_data_file,
                          data_dir,
                          classes,
                          batch_size,
                          scaleIn,
                          w=180,
                          h=320,
                          edge=False,
                          num_work=4,
                          Enc=True,
                          Augset=True):
    """Build a single training DataLoader and a validation DataLoader.

    Single-resolution counterpart of ``portrait_multiCVdataloader``: the
    dataset cache is created (or loaded) the same way, but only one training
    pipeline at (w, h) with photometric augmentation is constructed.

    Args:
        cached_data_file: Path of the pickled dataset cache.
        data_dir: Root directory of the dataset.
        classes: Number of segmentation classes, forwarded to ``ld.LoadData``.
        batch_size: Batch size used by both loaders.
        scaleIn: Label down-scaling factor for ``cvTransforms.ToTensor``.
        w, h: Target width/height for both training and validation pipelines.
        edge: Whether the training dataset should also produce edge maps
            (validation always uses ``edge=True``).
        num_work: Worker processes per DataLoader.
        Enc: Encoder-stage flag forwarded to ``myDataLoader.CVDataset``.
        Augset: When True, mix in the additional Baidu portrait sets.

    Returns:
        (trainLoader, valLoader, data) where ``data`` is the dict produced
        by ``ld.LoadData`` / unpickled cache.
    """
    if not os.path.isfile(cached_data_file):
        if Augset:
            # Extra public portrait sets blended in when augmentation is on.
            additional_data = ['/Nukki/baidu_V1/', '/Nukki/baidu_V2/']
            dataLoad = ld.LoadData(data_dir,
                                   classes,
                                   cached_data_file,
                                   additional=additional_data)
            data = dataLoad.processDataAug()
        else:
            dataLoad = ld.LoadData(data_dir, classes, cached_data_file)
            data = dataLoad.processData()

        if data is None:
            print('Error while pickling data. Please check.')
            # Equivalent to exit(-1) but does not depend on the `site` builtin.
            raise SystemExit(-1)
    else:
        # Fix: the original `pickle.load(open(...))` never closed the file;
        # use a context manager so the handle is released deterministically.
        with open(cached_data_file, "rb") as cache_file:
            data = pickle.load(cache_file)

    # Training pipeline: geometric translation + photometric augmentation
    # (color / blur / noise), then normalize, resize and tensorize.
    trainDataset_main = cvTransforms.Compose([
        cvTransforms.Translation(w, h),
        cvTransforms.data_aug_color(),
        cvTransforms.data_aug_blur(),
        cvTransforms.data_aug_noise(),
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(w, h),
        cvTransforms.ToTensor(scaleIn),
    ])

    valDataset = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(w, h),
        cvTransforms.ToTensor(scaleIn),
    ])

    print("This stage is Enc" + str(Enc))
    print(" Load public Baidu train dataset")
    trainLoader = torch.utils.data.DataLoader(myDataLoader.CVDataset(
        data['trainIm'],
        data['trainAnnot'],
        transform=trainDataset_main,
        edge=edge,
        Enc=Enc),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_work,
                                              pin_memory=True)

    print(" Load public val dataset")

    # Validation always requests edge maps, regardless of the `edge` argument.
    valLoader = torch.utils.data.DataLoader(myDataLoader.CVDataset(
        data['valIm'],
        data['valAnnot'],
        transform=valDataset,
        edge=True,
        Enc=Enc),
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=num_work,
                                            pin_memory=True)

    return trainLoader, valLoader, data