コード例 #1
0
    def __init__(self, args, transform, dtype):
        """Collect the image paths for the gallery or the query split.

        Args:
            args: Object exposing ``datadir``, the dataset root that the
                split sub-directory is appended to.
            transform: Stored on the instance as ``self.transform``.
            dtype: ``'test'`` selects the ``/gallery`` sub-directory; any
                other value selects ``/query``.
        """
        self.transform = transform
        self.loader = default_loader

        # Only 'test' maps to the gallery folder; everything else is query.
        subdir = '/gallery' if dtype == 'test' else '/query'
        image_dir = args.datadir + subdir

        self.imgs = list(list_pictures(image_dir))
コード例 #2
0
    def test(self):
        """Rank gallery images for every query image and save the result.

        Puts the model in eval mode, extracts features for the query and
        gallery loaders, computes the query-by-gallery distance matrix with
        ``cdist`` and, for each query, sorts the gallery distances in
        ascending order. The per-query rankings are collected into one dict
        and handed to ``saveCSV`` as ``PRCV2020_preid_result.csv``.

        Query identifiers are read from ``query_test_image_name.txt`` under
        ``args.datadir``; gallery identifiers come from the files found in
        ``<datadir>/test_images``. In both cases the identifier is the part
        of the name after the last ``_`` and before the first following
        ``.``.
        """
        self.model.eval()
        qf = self.extract_feature(self.query_loader).numpy()
        gf = self.extract_feature(self.test_loader).numpy()

        print('query shape:', qf.shape)
        print('gallery shape:', gf.shape)

        # Debug aid: check whether the extracted image features are normalized
        # print(qf[0])
        # sum_qf = []
        # sum_gf = []
        # for i in range(len(qf)):
        #     sum_qf.append(np.sum(qf[i] ** 2))
        # for i in range(len(gf)):
        #     sum_gf.append(np.sum(gf[i] ** 2))
        #
        # print('sum_qf length:', len(sum_qf))
        # print('sum_gf length:', len(sum_gf))

        # Compute the distance between every query and every gallery feature
        dist = cdist(qf, gf)
        print(dist.shape)

        # The results are saved as a CSV file
        res_df = pd.DataFrame(dist)

        # NOTE(review): `args` is neither a parameter nor an attribute here;
        # presumably a module-level namespace -- confirm against the file.
        data_path = args.datadir
        gallery_data_path = data_path + '/test_images'
        query_filename = 'query_test_image_name.txt'
        query_file_txt = os.path.join(data_path, query_filename)
        # Use a context manager so the list file is closed once it is read.
        with open(query_file_txt, 'r') as query_file:
            query_inx_list = [
                path.rsplit('_', 1)[1].split('.')[0] for path in query_file
            ]

        gallery_inx_list = [
            path.rsplit('_', 1)[1].split('.')[0]
            for path in list_pictures(gallery_data_path)
        ]
        # Label rows with query ids and columns with gallery ids.
        res_df.index = query_inx_list
        res_df.columns = gallery_inx_list
        res_dict_dictlist = {}
        for i in res_df.index:
            # One-column frame for query `i`, sorted nearest-first, turned
            # into {i: {gallery_id: distance, ...}}.
            res_dict_i = pd.DataFrame(res_df.loc[i]).sort_values(
                by=[i], axis=0, ascending=True).to_dict()
            res_dict_dictlist.update(res_dict_i)

        saveCSV('PRCV2020_preid_result.csv', res_dict_dictlist)
コード例 #3
0
    def __init__(self, args, transform, dtype):
        """Index the images of one dataset split and build the id-to-label map.

        Args:
            args: Object exposing ``datadir``, the dataset root that the
                split sub-directory is appended to.
            transform: Stored on the instance as ``self.transform``.
            dtype: ``'train'`` or ``'test'`` select the matching
                ``bounding_box_*`` sub-directory; any other value selects
                ``/query``.
        """
        self.transform = transform
        self.loader = default_loader

        split_dirs = {
            'train': '/bounding_box_train',
            'test': '/bounding_box_test',
        }
        # Anything that is not train/test falls back to the query folder.
        image_dir = args.datadir + split_dirs.get(dtype, '/query')

        # Keep only pictures whose id (as reported by self.id) is not -1.
        kept = []
        for picture in list_pictures(image_dir):
            if self.id(picture) != -1:
                kept.append(picture)
        self.imgs = kept

        # Map each unique id to its position in self.unique_ids.
        self._id2label = dict(
            (_id, idx) for idx, _id in enumerate(self.unique_ids)
        )
コード例 #4
0
    def __init__(self, dataset_root, transform, split='train'):
        """Index the images of the requested split under ``dataset_root``.

        Args:
            dataset_root: Directory that contains the split sub-directories.
            transform: Stored on the instance as ``self.transform``.
            split: One of ``'train'``, ``'gallery'`` or ``'query'``.

        Raises:
            Exception: If ``split`` is not one of the three supported values.
        """
        if split not in ('train', 'gallery', 'query'):
            raise Exception('Invalid dataset split.')

        self.transform = transform
        self.loader = default_loader
        self.split = split

        # Sub-directory that holds the images of each split.
        split_dirs = {
            'train': 'bounding_box_train',
            'gallery': 'bounding_box_test',
            'query': 'query',
        }
        image_dir = ospj(dataset_root, split_dirs[split])

        self.imgs = list(list_pictures(image_dir))

        # The id-to-label map is only built for the training split.
        if split == 'train':
            self._id2label = dict(
                (_id, idx) for idx, _id in enumerate(self.unique_ids)
            )
コード例 #5
0
    def __init__(self, args, transform, dtype):
        """Collect image paths for the query or the gallery set.

        For ``dtype == 'query'`` the file names are read line by line from
        ``query_test_image_name.txt`` under ``args.datadir`` and joined onto
        ``<datadir>/query_images``. For ``dtype == 'gallery'`` the paths are
        whatever ``list_pictures`` returns for ``<datadir>/test_images``.

        Args:
            args: Object exposing ``datadir``, the dataset root.
            transform: Stored on the instance as ``self.transform``.
            dtype: ``'query'`` or ``'gallery'``.
        """
        self.transform = transform
        self.loader = default_loader
        print('reading data !!!!!!')

        data_path = args.datadir
        if dtype == 'query':
            query_filename = 'query_test_image_name.txt'
            query_file_txt = os.path.join(data_path, query_filename)

            data_path += '/query_images'
            # Read the list inside a context manager so the file handle is
            # closed as soon as the names have been collected.
            with open(query_file_txt, 'r') as query_file:
                self.imgs = [
                    os.path.join(data_path, path.strip())
                    for path in query_file
                ]

        elif dtype == 'gallery':
            data_path += '/test_images'
            self.imgs = list(list_pictures(data_path))  # img path list
        # NOTE(review): any other dtype leaves self.imgs unset in this
        # method, so the line below would raise AttributeError -- confirm
        # that callers only pass 'query' or 'gallery'.
        print(dtype + ' images num:', len(self.imgs))