示例#1
0
def cityCVaux_dataloader(cached_data_file,
                         data_dir,
                         classes,
                         batch_size,
                         scaleIn,
                         size=1024,
                         num_work=6):
    """Build five multi-scale training loaders and one validation loader.

    The dataset description is read from ``cached_data_file`` when that file
    exists. Otherwise it is produced by
    ``ld.LoadData(data_dir, classes, cached_data_file).processData()``.
    The description is a dict read via the keys ``'mean'``, ``'std'``,
    ``'trainIm'``, ``'trainAnnot'``, ``'valIm'`` and ``'valAnnot'``.

    ``size`` selects a schedule of five (scale width, crop) pairs. Each
    training pipeline does the following:

    * normalizes with the dataset mean/std;
    * scales to ``(width, width // 2)``;
    * applies ``RandomCropResize(crop)`` and ``RandomFlip()``;
    * converts with ``ToMultiTensor(scaleIn)``.

    The validation pipeline only normalizes, scales to the first (main)
    width and converts with ``ToMultiTensor(1)``.

    Args:
        cached_data_file: path of the pickled dataset description.
        data_dir: forwarded to ``ld.LoadData`` when no cache file exists.
        classes: forwarded to ``ld.LoadData`` when no cache file exists.
        batch_size: batch size for the main scale and the first two
            auxiliary scales. The two smallest scales use
            ``batch_size + 4``. Validation uses ``batch_size - 2``, clamped
            to at least 1 so small batch sizes do not make DataLoader reject
            a non-positive value.
        scaleIn: forwarded to ``cvTransforms.ToMultiTensor`` for training.
        size: 1024 or 2048 selects the matching schedule. Any other value
            uses the 1024 widths with a different crop list.
        num_work: ``num_workers`` for every DataLoader.

    Returns:
        A tuple ``(trainLoader, trainLoader_scale1, trainLoader_scale2,
        trainLoader_scale3, trainLoader_scale4, valLoader, data)``.

    Raises:
        SystemExit: with code -1 if ``processData()`` returns ``None``.
    """
    import sys

    if size == 1024:
        scale = [1024, 1536, 1280, 768, 512]
        crop = [32, 96, 96, 32, 12]
    elif size == 2048:
        scale = [2048, 1536, 1280, 1024, 768]
        crop = [96, 96, 64, 32, 32]
    else:
        scale = [1024, 1536, 1280, 768, 512]
        crop = [32, 100, 100, 32, 0]

    if not os.path.isfile(cached_data_file):
        dataLoad = ld.LoadData(data_dir, classes, cached_data_file)
        data = dataLoad.processData()
        if data is None:
            print('Error while pickling data. Please check.')
            sys.exit(-1)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files this pipeline produced itself.
        with open(cached_data_file, "rb") as cache_fh:
            data = pickle.load(cache_fh)

    def make_train_transform(width, crop_px):
        # A fresh Normalize per pipeline, as in the original stanzas.
        return cvTransforms.Compose([
            cvTransforms.Normalize(mean=data['mean'], std=data['std']),
            cvTransforms.Scale(width, width // 2),
            cvTransforms.RandomCropResize(crop_px),
            cvTransforms.RandomFlip(),
            cvTransforms.ToMultiTensor(scaleIn),
        ])

    train_transforms = []
    for width, crop_px in zip(scale, crop):
        train_transforms.append(make_train_transform(width, crop_px))
        print("%d , %d image size train with %d crop" %
              (width, width // 2, crop_px))

    valDataset = cvTransforms.Compose([
        cvTransforms.Normalize(mean=data['mean'], std=data['std']),
        cvTransforms.Scale(scale[0], scale[0] // 2),
        cvTransforms.ToMultiTensor(1),
    ])
    print("%d , %d image size validation" % (scale[0], scale[0] // 2))

    # The two smallest scales (indices 3 and 4) get a larger batch.
    train_batch_sizes = [batch_size] * 3 + [batch_size + 4] * 2

    train_loaders = []
    for transform, loader_bs in zip(train_transforms, train_batch_sizes):
        train_loaders.append(torch.utils.data.DataLoader(
            myDataLoader.MyAuxDataset(
                data['trainIm'], data['trainAnnot'], transform=transform),
            batch_size=loader_bs,
            shuffle=True,
            num_workers=num_work,
            pin_memory=True))

    valLoader = torch.utils.data.DataLoader(
        myDataLoader.MyAuxDataset(
            data['valIm'], data['valAnnot'], transform=valDataset),
        batch_size=max(1, batch_size - 2),
        shuffle=False,
        num_workers=num_work,
        pin_memory=True)

    return tuple(train_loaders) + (valLoader, data)