import os
import pickle

import pandas as pd
from torch.utils import data

# NOTE: the *DataSet and *TrainInform classes used below (CityscapesTrainDataSet,
# CityscapesTrainInform, ParisTrainInform, ...) are assumed to be importable
# from this repository's dataset package.


def build_dataset_train(root, dataset, base_size, crop_size):
    data_dir = os.path.join(root, dataset)
    train_data_list = os.path.join(data_dir, dataset + '_' + 'train_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list,
                                                inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This function only supports the cityscapes dataset, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)
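        # NOTE (assumption): 'datas' is the dict written by collectDataAndSave();
        # only 'mean' and 'std' (per-channel statistics) are consumed below, and
        # a class-weight array is typically stored alongside them.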

    if dataset == "cityscapes":
        TrainDataSet = CityscapesTrainDataSet(data_dir, train_data_list, base_size=base_size, crop_size=crop_size,
                                              mean=datas['mean'], std=datas['std'], ignore_label=255)
        return datas, TrainDataSet
    else:
        raise NotImplementedError("unknown dataset: %s" % dataset)
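

# Usage sketch for the build_dataset_train variant above; the root path and
# sizes are illustrative assumptions, not values fixed by this repository:
#
#   datas, train_set = build_dataset_train('./data', 'cityscapes',
#                                          base_size=1024, crop_size=512)
#   train_loader = data.DataLoader(train_set, batch_size=8, shuffle=True)
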
def build_dataset_train(root, dataset, base_size, crop_size, batch_size, random_scale, num_workers):
    data_dir = os.path.join(root, dataset)
    train_data_list = os.path.join(data_dir, dataset + '_' + 'train_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'paris':
            dataCollect = ParisTrainInform(data_dir, 3, train_set_file=train_data_list,
                                           inform_data_file=inform_data_file)
        elif dataset == 'postdam' or dataset == 'vaihingen':
            dataCollect = IsprsTrainInform(data_dir, 6, train_set_file=train_data_list,
                                           inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This function supports the cityscapes, paris, postdam and vaihingen datasets, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    if dataset == "cityscapes":
        trainLoader = data.DataLoader(
            CityscapesTrainDataSet(data_dir, train_data_list, base_size=base_size, crop_size=crop_size,
                                   mean=datas['mean'], std=datas['std'], ignore_label=255),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=False, drop_last=True)

        return datas, trainLoader

    elif dataset == "paris":
        trainLoader = data.DataLoader(
            ParisTrainDataSet(data_dir, train_data_list, scale=random_scale, crop_size=crop_size,
                              mean=datas['mean'], std=datas['std'], ignore_label=255),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=False, drop_last=True)

        return datas, trainLoader

    elif dataset == "postdam" or dataset == 'vaihingen':
        trainLoader = data.DataLoader(
            IsprsTrainDataSet(data_dir, train_data_list, base_size=base_size, crop_size=crop_size,
                              mean=datas['mean'], std=datas['std'], ignore_label=255),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=False, drop_last=True)

        return datas, trainLoader

    else:
        raise NotImplementedError("unknown dataset: %s" % dataset)
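

# Usage sketch for the DataLoader-returning variant above; the batch size,
# worker count and root path are illustrative assumptions:
#
#   datas, train_loader = build_dataset_train('./data', 'paris', base_size=512,
#                                             crop_size=256, batch_size=8,
#                                             random_scale=True, num_workers=4)
#   for batch in train_loader:
#       ...  # the exact batch tuple layout depends on the dataset class
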
def build_dataset_test(root, dataset, crop_size, mode='whole', gt=False):
    data_dir = os.path.join(root, dataset)
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')
    train_data_list = os.path.join(data_dir, dataset + '_train_list.txt')
    if mode == 'whole':
        test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    else:
        test_data_list = os.path.join(data_dir, dataset + '_test_sliding' + '_list.txt')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list,
                                                inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This function only supports the cityscapes dataset, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    class_dict_df = pd.read_csv(os.path.join('./dataset', dataset, 'class_map.csv'))
    if dataset == "cityscapes":
        # for cityscapes, set gt=True to evaluate on the validation set
        # (ground truth available); set gt=False to run on the unlabeled test set
        if gt:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testdataset = CityscapesValDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'],
                                     std=datas['std'], ignore_label=255)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
            testdataset = CityscapesTestDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'],
                                      std=datas['std'], ignore_label=255)
        return testdataset, class_dict_df
    else:
        raise NotImplementedError("unknown dataset: %s" % dataset)
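

# Usage sketch for build_dataset_test above; the root path and crop size are
# illustrative assumptions:
#
#   testdataset, class_dict_df = build_dataset_test('./data', 'cityscapes',
#                                                   crop_size=512, gt=True)
#   test_loader = data.DataLoader(testdataset, batch_size=1, shuffle=False)
#   # class_dict_df holds the class mapping read from class_map.csv
#   # (its exact columns depend on that file)
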
def build_dataset_train(dataset, input_size, batch_size, train_type,
                        random_scale, random_mirror, num_workers):
    data_dir = os.path.join('/media/ding/Data/datasets', dataset)
    train_data_list = os.path.join(data_dir,
                                   dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/',
                                    dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(
                data_dir,
                19,
                train_set_file=train_data_list,
                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir,
                                            11,
                                            train_set_file=train_data_list,
                                            inform_data_file=inform_data_file)
        elif dataset == 'paris':
            dataCollect = ParisTrainInform(data_dir,
                                           3,
                                           train_set_file=train_data_list,
                                           inform_data_file=inform_data_file)
        elif dataset == 'road':
            dataCollect = RoadTrainInform(data_dir,
                                          2,
                                          train_set_file=train_data_list,
                                          inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This function supports the cityscapes, camvid, paris and road datasets, %s is not included"
                % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":

        trainLoader = data.DataLoader(CityscapesDataSet(data_dir,
                                                        train_data_list,
                                                        crop_size=input_size,
                                                        scale=random_scale,
                                                        mirror=random_mirror,
                                                        mean=datas['mean']),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      drop_last=True)

        valLoader = data.DataLoader(CityscapesValDataSet(data_dir,
                                                         val_data_list,
                                                         f_scale=1,
                                                         mean=datas['mean']),
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    pin_memory=True,
                                    drop_last=True)

        return datas, trainLoader, valLoader

    elif dataset == "camvid":

        trainLoader = data.DataLoader(CamVidDataSet(data_dir,
                                                    train_data_list,
                                                    crop_size=input_size,
                                                    scale=random_scale,
                                                    mirror=random_mirror,
                                                    mean=datas['mean']),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      drop_last=True)

        valLoader = data.DataLoader(CamVidValDataSet(data_dir,
                                                     val_data_list,
                                                     f_scale=1,
                                                     mean=datas['mean']),
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    pin_memory=True)

        return datas, trainLoader, valLoader

    elif dataset == "paris":

        trainLoader = data.DataLoader(ParisDataSet(data_dir,
                                                   train_data_list,
                                                   crop_size=input_size,
                                                   scale=random_scale,
                                                   mirror=random_mirror,
                                                   mean=datas['mean']),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      drop_last=True)

        valLoader = data.DataLoader(ParisValDataSet(data_dir,
                                                    val_data_list,
                                                    f_scale=1,
                                                    mean=datas['mean']),
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    pin_memory=True)

        return datas, trainLoader, valLoader

    elif dataset == "road":

        trainLoader = data.DataLoader(RoadDataSet(data_dir,
                                                  train_data_list,
                                                  crop_size=input_size,
                                                  scale=random_scale,
                                                  mirror=random_mirror,
                                                  mean=datas['mean']),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      drop_last=True)

        valLoader = data.DataLoader(RoadValDataSet(data_dir,
                                                   val_data_list,
                                                   f_scale=1,
                                                   mean=datas['mean']),
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    pin_memory=True)

        return datas, trainLoader, valLoader
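

# Usage sketch for the train/val variant above; all argument values are
# illustrative assumptions:
#
#   datas, train_loader, val_loader = build_dataset_train(
#       'camvid', input_size=(360, 480), batch_size=8, train_type='train',
#       random_scale=True, random_mirror=True, num_workers=4)
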
def build_dataset_sliding_test(dataset, num_workers, none_gt=False):
    data_dir = os.path.join('/media/ding/Data/datasets', dataset)
    dataset_list = os.path.join(data_dir, dataset + '_train_list.txt')
    if dataset == 'cityscapes':
        test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    else:
        test_data_list = os.path.join(data_dir,
                                      dataset + '_sliding_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/',
                                    dataset + '_inform.pkl')

    # inform_data_file caches the dataset statistics: mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(
                data_dir,
                19,
                train_set_file=dataset_list,
                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir,
                                            11,
                                            train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        elif dataset == 'paris':
            dataCollect = ParisTrainInform(data_dir,
                                           3,
                                           train_set_file=dataset_list,
                                           inform_data_file=inform_data_file)
        elif dataset == 'road':
            dataCollect = RoadTrainInform(data_dir,
                                          2,
                                          train_set_file=dataset_list,
                                          inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This function supports the cityscapes, camvid, paris and road datasets, %s is not included"
                % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        # for cityscapes: set none_gt to False to test on the validation set
        # (ground truth available); set none_gt to True to test on the test set
        if none_gt:
            test_data_list = os.path.join(data_dir,
                                          dataset + '_test' + '_list.txt')
            testLoader = data.DataLoader(CityscapesTestDataSet(
                data_dir, test_data_list, mean=datas['mean']),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         pin_memory=True)
        else:
            testLoader = data.DataLoader(CityscapesValDataSet(
                data_dir, test_data_list, mean=datas['mean']),
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=num_workers,
                                         pin_memory=True)

        return datas, testLoader

    elif dataset == "camvid":

        testLoader = data.DataLoader(CamVidValDataSet(data_dir,
                                                      test_data_list,
                                                      mean=datas['mean'],
                                                      std=datas['std']),
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     pin_memory=True)

        return datas, testLoader

    elif dataset == "paris":

        testLoader = data.DataLoader(ParisTestDataSet(data_dir,
                                                      test_data_list,
                                                      mean=datas['mean']),
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     pin_memory=True)

        return datas, testLoader

    # elif dataset == "austin":
    #
    #     testLoader = data.DataLoader(
    #         AustinTestDataSet(data_dir, test_data_list, mean=datas['mean']),
    #         batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
    #
    #     return datas, testLoader

    elif dataset == "road":

        testLoader = data.DataLoader(RoadTestDataSet(data_dir,
                                                     test_data_list,
                                                     mean=datas['mean']),
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=num_workers,
                                     pin_memory=True)

        return datas, testLoader
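

# Usage sketch for build_dataset_sliding_test above:
#
#   datas, test_loader = build_dataset_sliding_test('cityscapes', num_workers=2,
#                                                   none_gt=False)
#   for batch in test_loader:
#       ...  # the exact batch tuple layout depends on the dataset class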