def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 downsample_scale=0.25,
                 sampling_num=100,
                 sub_graph_nodes=24,
                 transform_func='default'):
        """Load per-scene 1DSfM metadata and pre-sample sub-graphs.

        Args:
            iccv_res_dir: root directory holding, per dataset name, the
                ``ImageList.txt``, ``calibration.txt``, ``bundle.out``,
                ``inoutMat.npy``, ``covis_map.npy`` and
                ``edge_feat_pos_cache.bin`` files read below.
            image_dir: root directory of the raw images (stored only).
            dataset_list: list of dicts, each with a ``'name'`` key.
            lmdb_paths: optional ``(lmdb_db_path, meta_pickle_path)`` pair;
                when given, the LMDB is opened and ``use_lmdb`` is set.
            downsample_scale: image scale factor (stored only).
            sampling_num: number of sub-graphs sampled per dataset.
            sub_graph_nodes: sampling size passed to ``SamplingGenerator``.
            transform_func: ``'default'`` for ImageNet mean/std normalisation,
                another callable, or ``None``.
        """
        # sampling_count: sampling the numbers of subgraph for a dataset
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.sampling_num = sampling_num
        self.image_dir = image_dir
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.downsample_scale = downsample_scale
        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            # Close the meta file deterministically (it was leaked before).
            # NOTE: pickle can execute code on load; only use trusted files.
            with open(lmdb_paths[1], 'rb') as meta_file:
                self.lmdb_meta = pickle.load(meta_file)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            self.dataset_names.append(dataset_name)
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            frame_list = read_image_list(
                os.path.join(scene_dir, 'ImageList.txt'))
            # Drop the first two characters and everything after the first
            # '.' (presumably a "./" prefix and the extension; TODO confirm).
            self.frame_list[dataset_name] = [
                f[2:].split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(
                os.path.join(scene_dir, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.edge_sampler = {}
        self.covis_map = {}
        self.edge_local_feat_cache = {}

        print(
            '[1dsfm dataset Init] load in. and out. edges and sampling sub_graphs'
        )
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            Es, Cs = read_poses(os.path.join(scene_dir, 'bundle.out'))
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            inoutMat = np.load(os.path.join(scene_dir, 'inoutMat.npy'))
            covis_map = np.load(os.path.join(scene_dir, 'covis_map.npy'))

            with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'),
                      'rb') as f:
                edge_feat_pos_cache = pickle.load(f)

            # random generate a sub graph
            # todo: fix sub_graph_nodes
            gen = SamplingGenerator(len(Cs), inoutMat)
            gen.setSamplingSize(sub_graph_nodes)
            gen.setSamplingNumber(sampling_num)
            gen.generation()
            self.edge_sampler[dataset_name] = gen
            self.covis_map[dataset_name] = covis_map
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

        print('[1dsfm Init] Done')
# Example #2

def read_lmdb(dataset, lmdb, processed_edge_dict, processed_node_dict):
    """Read back every node/edge feature referenced by ``dataset`` from ``lmdb``.

    Each sample is drawn through a ``DataLoader`` (default batch size 1).
    For every sub-graph edge, the feature array of its first node and of the
    edge itself is fetched with ``lmdb.read_ndarray_by_key``. A full pass
    therefore exercises every key the dataset will need; whether a missing
    key raises depends on ``LMDBModel`` (presumably it does — TODO confirm).

    Args:
        dataset: dataset whose samples unpack into the 16-tuple below.
        lmdb: object providing ``read_ndarray_by_key(key)``.
        processed_edge_dict: accepted for API compatibility; unused.
        processed_node_dict: accepted for API compatibility; unused.
    """
    train_loader = DataLoader(dataset, num_workers=0, shuffle=True)

    # Context manager guarantees the progress bar is closed.
    with tqdm(total=len(dataset)) as pbar:
        for sample in train_loader:
            # Unpack the full tuple so a shape mismatch still fails loudly.
            (dataset_name, _, _, _, _, _, _, _, _, sub_id2img_id, _,
             edge_subnode_idx, _, _, _, _) = sample
            scene = dataset_name[0]  # collated batch of size 1

            for e in edge_subnode_idx:
                sub_n1, sub_n2 = e[0].item(), e[1].item()
                # Collated dict values are 1-element tensors.
                n1 = int(sub_id2img_id[sub_n1])
                n2 = int(sub_id2img_id[sub_n2])
                # Node features are keyed by image ID (not sub-graph index),
                # matching the '%s,%d' % (dataset_name, imageID) reader.
                lmdb.read_ndarray_by_key('%s,%d' % (scene, n1))
                lmdb.read_ndarray_by_key('%s,%d-%d' % (scene, n1, n2))

            pbar.update(1)


# Dump to lmdb: read back node/edge features for the training set.
# init lmdb
lmdb = LMDBModel(out_node_edge_feat_lmdb_path, read_only=True)
# Default to empty bookkeeping dicts so read_lmdb can still run when no meta
# file has been written yet (previously this raised NameError).
processed_edge_dict, processed_node_dict = {}, {}
if os.path.exists(out_node_edge_feat_meta_path):
    with open(out_node_edge_feat_meta_path, 'rb') as f:
        processed_edge_dict, processed_node_dict = pickle.load(f)
read_lmdb(train_set, lmdb, processed_edge_dict, processed_node_dict)
class OneDSFMDataset(Dataset):
    """Sub-graph dataset over 1DSfM scenes.

    Every scene in ``dataset_list`` contributes ``sampling_num`` pre-sampled
    sub-graphs; flat index ``idx`` maps to scene ``idx // sampling_num`` and
    sub-graph ``idx % sampling_num``.
    """

    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 downsample_scale=0.25,
                 sampling_num=100,
                 sub_graph_nodes=24,
                 transform_func='default'):
        """Load per-scene metadata and pre-sample sub-graphs.

        Args:
            iccv_res_dir: root directory holding, per dataset name, the
                ``ImageList.txt``, ``calibration.txt``, ``bundle.out``,
                ``inoutMat.npy``, ``covis_map.npy`` and
                ``edge_feat_pos_cache.bin`` files read below.
            image_dir: root of the raw images (``<name>/images/<frame>.jpg``).
            dataset_list: list of dicts, each with a ``'name'`` key.
            lmdb_paths: optional ``(lmdb_db_path, meta_pickle_path)``; when
                given, images are read from the LMDB instead of disk.
            downsample_scale: resize factor applied to every image.
            sampling_num: number of sub-graphs sampled per dataset.
            sub_graph_nodes: sampling size passed to ``SamplingGenerator``.
            transform_func: ``'default'`` for ImageNet mean/std normalisation,
                another callable applied to the CHW tensor, or ``None``.
        """
        # sampling_count: sampling the numbers of subgraph for a dataset
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.sampling_num = sampling_num
        self.image_dir = image_dir
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.downsample_scale = downsample_scale
        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            # Close the meta file deterministically (it was leaked before).
            # NOTE: pickle can execute code on load; only use trusted files.
            with open(lmdb_paths[1], 'rb') as meta_file:
                self.lmdb_meta = pickle.load(meta_file)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            self.dataset_names.append(dataset_name)
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            frame_list = read_image_list(
                os.path.join(scene_dir, 'ImageList.txt'))
            # Drop the first two characters and everything after the first
            # '.' (presumably a "./" prefix and the extension; TODO confirm).
            self.frame_list[dataset_name] = [
                f[2:].split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(
                os.path.join(scene_dir, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.edge_sampler = {}
        self.covis_map = {}
        self.edge_local_feat_cache = {}

        print(
            '[1dsfm dataset Init] load in. and out. edges and sampling sub_graphs'
        )
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            Es, Cs = read_poses(os.path.join(scene_dir, 'bundle.out'))
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            inoutMat = np.load(os.path.join(scene_dir, 'inoutMat.npy'))
            covis_map = np.load(os.path.join(scene_dir, 'covis_map.npy'))

            with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'),
                      'rb') as f:
                edge_feat_pos_cache = pickle.load(f)

            # random generate a sub graph
            # todo: fix sub_graph_nodes
            gen = SamplingGenerator(len(Cs), inoutMat)
            gen.setSamplingSize(sub_graph_nodes)
            gen.setSamplingNumber(sampling_num)
            gen.generation()
            self.edge_sampler[dataset_name] = gen
            self.covis_map[dataset_name] = covis_map
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

        print('[1dsfm Init] Done')

    def __len__(self):
        """Total number of sub-graphs across all scenes."""
        return self.num_dataset * self.sampling_num

    def __getitem__(self, idx):
        """Assemble one sampled sub-graph.

        Returns:
            13-tuple ``(imgs, img_ori_dim, cam_Es, cam_Cs, cam_Ks,
            out_graph_mat, img_id2sub_id, sub_id2img_id, out_covis_mat,
            edge_subnode_idx, edge_type, edge_local_matches_n1,
            edge_local_matches_n2)``.

        Raises:
            FileNotFoundError: an image could not be read from disk.
            KeyError: a sampled edge has no entry in the edge feature cache
                (previously this reused the previous edge's matches or
                raised UnboundLocalError).
        """
        # Flat index -> (scene, sub-graph); each scene owns sampling_num slots.
        dataset_idx, sub_graph_id = divmod(int(idx), self.sampling_num)

        dataset_name = self.dataset_names[dataset_idx]
        frame_list = self.frame_list[dataset_name]
        sampled_subgraphs = self.edge_sampler[dataset_name]
        edge_local_feat_cache = self.edge_local_feat_cache[dataset_name]
        covis_map = self.covis_map[dataset_name]

        subgraph_nodes = sampled_subgraphs.sampling_node[sub_graph_id]
        subgraph_edges = sampled_subgraphs.sampling_edge[sub_graph_id]
        subgraph_label = sampled_subgraphs.sampling_edge_label[sub_graph_id]
        sub_graph_nodes = len(subgraph_nodes)

        imgs = []
        img_ori_dim = []
        cam_Es, cam_Cs, cam_Ks = [], [], []
        img_id2sub_id = {}
        sub_id2img_id = {}
        for i, imageID in enumerate(subgraph_nodes):
            if self.use_lmdb is True:
                img_key = dataset_name + '/' + frame_list[imageID] + '.jpg'
                img = self.lmdb_db.read_ndarray_by_key(img_key, dtype=np.uint8)
                h, w = self.lmdb_meta[img_key]['dim']
                # The LMDB copy is stored at half the original resolution.
                img = img.reshape(int(h * 0.5), int(w * 0.5), 3)
            else:
                img_path = os.path.join(self.image_dir, dataset_name, 'images',
                                        frame_list[imageID] + '.jpg')
                img = cv2.imread(img_path)
                if img is None:
                    raise FileNotFoundError('Cannot read image: %s' % img_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                h, w = img.shape[:2]
            img_ori_dim.append((h, w))
            img = cv2.resize(img,
                             dsize=(int(w * self.downsample_scale),
                                    int(h * self.downsample_scale)))
            img = img.astype(np.float32) / 255.0

            # HWC -> CHW float tensor in [0, 1] before the transform.
            img = torch.from_numpy(img).permute(2, 0, 1)
            if self.transform_func is not None:
                img = self.transform_func(img)
            imgs.append(img)

            cam_Cs.append(torch.from_numpy(self.Cs[dataset_name][imageID]).float())
            cam_Es.append(torch.from_numpy(self.Es[dataset_name][imageID]).float())
            cam_Ks.append(torch.from_numpy(self.K[dataset_name][imageID]).float())

            img_id2sub_id[imageID] = i
            sub_id2img_id[i] = imageID

        cam_Cs = torch.stack(cam_Cs, dim=0)
        cam_Es = torch.stack(cam_Es, dim=0)
        cam_Ks = torch.stack(cam_Ks, dim=0)

        # Symmetric label / co-visibility adjacency over sub-graph indices.
        out_graph_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)
        out_covis_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)

        edge_local_matches_n1 = []
        edge_local_matches_n2 = []
        edge_subnode_idx = []
        edge_type = torch.zeros(len(subgraph_edges), dtype=torch.long)

        for i, edge in enumerate(subgraph_edges):
            n1, n2 = edge[0], edge[1]
            s1, s2 = img_id2sub_id[n1], img_id2sub_id[n2]  # subgraph indices
            edge_subnode_idx.append((s1, s2))

            label = subgraph_label[i]
            covis_value = covis_map[n1, n2]
            if covis_value == 0:
                # Co-visibility may only be stored in one direction.
                covis_value = covis_map[n2, n1]
            out_graph_mat[s1, s2] = label
            out_graph_mat[s2, s1] = label
            out_covis_mat[s1, s2] = covis_value
            out_covis_mat[s2, s1] = covis_value

            fwd_key = '%d-%d' % (n1, n2)
            bwd_key = '%d-%d' % (n2, n1)
            if fwd_key in edge_local_feat_cache:
                edge_cache = edge_local_feat_cache[fwd_key]
                pts1 = torch.from_numpy(edge_cache['n1_feat_pos'])
                pts2 = torch.from_numpy(edge_cache['n2_feat_pos'])
            elif bwd_key in edge_local_feat_cache:
                # Stored in the opposite direction: swap the endpoints.
                edge_cache = edge_local_feat_cache[bwd_key]
                pts1 = torch.from_numpy(edge_cache['n2_feat_pos'])
                pts2 = torch.from_numpy(edge_cache['n1_feat_pos'])
            else:
                raise KeyError('Edge %s of dataset %s not found in edge cache'
                               % (fwd_key, dataset_name))
            edge_type[i] = 1 if edge_cache['type'] == 'I' else 0

            edge_local_matches_n1.append(pts1)
            edge_local_matches_n2.append(pts2)

        out_graph_mat = torch.from_numpy(out_graph_mat)
        out_covis_mat = torch.from_numpy(out_covis_mat)

        return (imgs, img_ori_dim, cam_Es, cam_Cs, cam_Ks, out_graph_mat,
                img_id2sub_id, sub_id2img_id, out_covis_mat, edge_subnode_idx,
                edge_type, edge_local_matches_n1, edge_local_matches_n2)
class CaptureDataset(Dataset):
    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 node_edge_lmdb=None,
                 img_max_dim=480,
                 sampling_num_range=(100, 500),
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 sampling_undefined_edge=False,
                 load_img=True,
                 transform_func='default',
                 training=True):
        """Load captured scenes, their edge caches and sampled sub-graphs.

        Args:
            iccv_res_dir: root directory holding, per ``ds['name']``, the
                ``<bundle_prefix>.list.txt``, ``calibration.txt``, the bundle
                file, ``inoutMat.npy``, ``covis.npy`` and
                ``edge_feat_pos_cache.bin``.
            image_dir: root directory of the raw images.
            dataset_list: list of dicts with ``'name'`` and ``'bundle_prefix'``.
            lmdb_paths: optional ``(image_lmdb_path, meta_pickle_path)``; only
                the meta pickle is loaded here.
            node_edge_lmdb: required dict with ``'node'`` and ``'edge'`` LMDB
                paths.
            img_max_dim: target image size stored for ``__getitem__``.
            sampling_num_range: ``(min, max)`` sub-graphs sampled per scene.
            sub_graph_nodes: sampling size passed to ``SamplingGenerator``.
            sample_res_cache: optional pickle path used to save (first run)
                or load (later runs) the sampling results.
            sampling_undefined_edge: forwarded as ``use_undefine`` to
                ``SamplingGenerator.generation``.
            load_img: whether ``__getitem__`` loads images.
            transform_func: ``'default'`` (ImageNet mean/std normalisation),
                another callable, or ``None``.
            training: unused; kept for API compatibility.
        """
        # sampling_count: sampling the numbers of subgraph for a dataset
        assert node_edge_lmdb is not None

        def edge_cached(cache, a, b):
            # An undirected edge may be stored under either "a-b" or "b-a".
            return '%d-%d' % (a, b) in cache or '%d-%d' % (b, a) in cache

        self.load_img = load_img
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.image_dir = image_dir
        self.sampling_num_range = sampling_num_range
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.img_max_dim = img_max_dim
        self.sampling_undefined_edge = sampling_undefined_edge

        if lmdb_paths is not None:
            self.use_lmdb = True
            # NOTE(review): the image LMDB (lmdb_paths[0]) is deliberately not
            # opened, so ``self.lmdb_db`` is never set here.
            # NOTE: pickle can execute code on load; only use trusted files.
            with open(lmdb_paths[1], 'rb') as meta_file:
                self.lmdb_meta = pickle.load(meta_file)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # edge and node feature lmdb
        self.edge_lmdb = LMDBModel(node_edge_lmdb['edge'],
                                   lock=False,
                                   read_only=True)
        self.node_lmdb = LMDBModel(node_edge_lmdb['node'],
                                   lock=False,
                                   read_only=True)

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            bundle_prefix = ds['bundle_prefix']
            self.dataset_names.append(dataset_name)

            frame_list = read_image_list(
                os.path.join(iccv_res_dir, dataset_name,
                             bundle_prefix + '.list.txt'))
            # Keep only the part before the first '.', i.e. drop the extension.
            self.frame_list[dataset_name] = [
                f.split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(
                os.path.join(iccv_res_dir, dataset_name, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.covis_map = {}
        self.edge_local_feat_cache = {}
        self.inout_mat = {}
        self.total_sample_num = 0

        # Edge counts of the largest / smallest scene, for ratio sampling.
        max_scene_edges = 0
        min_scene_edges = None
        print('[Captured dataset Init] load in. and out. edges')
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            Es, Cs = read_poses(os.path.join(scene_dir, ds['bundle_prefix']))
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            self.inout_mat[dataset_name] = np.load(
                os.path.join(scene_dir, 'inoutMat.npy'))
            self.covis_map[dataset_name] = np.load(
                os.path.join(scene_dir, 'covis.npy'))

            with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'),
                      'rb') as f:
                edge_feat_pos_cache = pickle.load(f)
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

            num_edges = len(edge_feat_pos_cache)
            max_scene_edges = max(max_scene_edges, num_edges)
            min_scene_edges = (num_edges if min_scene_edges is None else
                               min(min_scene_edges, num_edges))

        # Clamp the reference so that a huge scene does not push every other
        # scene's sampling ratio towards zero (cap at 40x the smallest scene).
        if min_scene_edges and min_scene_edges * 40 < max_scene_edges:
            max_scene_edges = 40 * min_scene_edges

        # Sampling ------------------------------------------------------------
        self.edge_sampler = {}
        self.samples = []  # (dataset_id, sub-graph sample_id)

        if sample_res_cache is None or not os.path.exists(sample_res_cache):
            print('[Captured dataset Init] sampling sub_graphs')
            for ds_id, ds in enumerate(dataset_list):
                dataset_name = ds['name']
                edge_feat_pos_cache = self.edge_local_feat_cache[dataset_name]

                n_Cameras = len(self.Cs[dataset_name])
                # Relabelled in place: uncached pairs -> 0, cached pairs that
                # are not already 1 -> -1.
                inoutMat = self.inout_mat[dataset_name]
                for i in range(n_Cameras):
                    for j in range(n_Cameras):
                        if not edge_cached(edge_feat_pos_cache, i, j):
                            inoutMat[i, j] = 0
                        elif inoutMat[i, j] != 1:
                            inoutMat[i, j] = -1
                num_edges = len(edge_feat_pos_cache)

                # determine sampling number based on ratio of edges among other scenes
                sample_ratio = float(num_edges) / float(max(max_scene_edges, 1))
                print('%s: Sampling Ratio: %.2f' %
                      (dataset_name, sample_ratio))
                sample_num = int(sampling_num_range[1] * sample_ratio)
                sample_num = min(max(sample_num, sampling_num_range[0]),
                                 sampling_num_range[1])

                # todo: fix sub_graph_nodes
                gen = SamplingGenerator(n_Cameras, inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sample_num)
                gen.generation(use_undefine=self.sampling_undefined_edge,
                               get_max_node=False)

                # Sanity check: report sub-graphs containing uncached edges.
                num_bad = sum(
                    1 for edges in gen.sampling_edge
                    if not all(edge_cached(edge_feat_pos_cache, e[0], e[1])
                               for e in edges))
                if num_bad:
                    print('[Captured dataset Init] %s: %d sampled sub-graphs '
                          'contain edges without cached features' %
                          (dataset_name, num_bad))

                filtered_sampled_num = len(gen.sampling_node)
                print('[Captured dataset Init] %s: (filtered: %d, all: %d)' %
                      (dataset_name, filtered_sampled_num, num_edges))

                self.samples += [(ds_id, i)
                                 for i in range(filtered_sampled_num)]
                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            if sample_res_cache is not None:
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.samples, self.edge_sampler], f)
                print('[Captured Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)
        else:
            with open(sample_res_cache, 'rb') as f:
                self.samples, self.edge_sampler = pickle.load(f)
            print('[Captured Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Captured Init] Done, %d samples' % len(self.samples))
        print('Rt_rel_12: n2 to n1')
        # random.shuffle(self.samples)

    def __len__(self):
        """Return how many (dataset_id, sub-graph id) samples are indexed."""
        sample_pairs = self.samples
        return len(sample_pairs)

    def __getitem__(self, idx):

        dataset_idx, sub_graph_id = self.samples[idx]

        dataset_name = self.dataset_names[dataset_idx]
        frame_list = self.frame_list[dataset_name]
        sampling_node, sampling_edge, sampling_edge_label = self.edge_sampler[
            dataset_name]
        edge_local_feat_cache = self.edge_local_feat_cache[dataset_name]

        subgraph_nodes = sampling_node[sub_graph_id]
        subgraph_edges = sampling_edge[sub_graph_id]
        subgraph_label = sampling_edge_label[sub_graph_id]
        sub_graph_nodes = len(subgraph_nodes)

        # todo: read image
        imgs = []
        img_names = []
        img_ori_dim = []
        cam_Es, cam_Cs, cam_Ks = [], [], []
        img_id2sub_id = {}
        sub_id2img_id = {}
        # print(dataset_name)
        node_feats = []

        for i, imageID in enumerate(subgraph_nodes):

            # image ID
            img_key = dataset_name + '/' + frame_list[imageID] + '.jpg'
            if self.load_img:
                if self.use_lmdb is True:
                    img = self.lmdb_db.read_ndarray_by_key(img_key,
                                                           dtype=np.uint8)
                    h, w = self.lmdb_meta[img_key]['dim']
                    res_h, res_w = self.lmdb_meta[img_key]['lmdb_dim']
                    img = img.reshape(int(res_h), int(res_w), 3)
                else:
                    img_path = os.path.join(self.image_dir, dataset_name,
                                            frame_list[imageID] + '.jpg')
                    img = cv2.imread(img_path)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    h, w = img.shape[:2]

                # resize the image
                res_h, res_w = img.shape[:2]
                min_dim = res_h if res_h < res_w else res_w
                down_factor = float(self.img_max_dim) / float(min_dim)
                img = cv2.resize(img,
                                 dsize=(int(res_w * down_factor),
                                        int(res_h * down_factor)))
                img = img.astype(np.float32) / 255.0

                img = torch.from_numpy(img)
                img = img.permute(2, 0, 1)
                if self.transform_func is not None:
                    img = self.transform_func(img)

                imgs.append(img)

            else:

                if self.use_lmdb is True:
                    h, w = self.lmdb_meta[img_key]['dim']
                    res_h, res_w = self.lmdb_meta[img_key]['lmdb_dim']

                img_ori_dim.append((h, w))
                img_names.append(img_key)

            try:
                node_feat = self.node_lmdb.read_ndarray_by_key(
                    '%s,%d' % (dataset_name, imageID))
            except Exception:
                print('No key found on %s,%d:%s' %
                      (dataset_name, imageID, img_key))

            node_feats.append(node_feat)

            camera_C = self.Cs[dataset_name][imageID]
            cam_Cs.append(torch.from_numpy(camera_C).float())
            camera_E = self.Es[dataset_name][imageID]
            cam_Es.append(torch.from_numpy(camera_E).float())
            camera_K = self.K[dataset_name][imageID]
            cam_Ks.append(torch.from_numpy(camera_K).float())

            img_id2sub_id[imageID] = i
            sub_id2img_id[i] = imageID

        node_feats = np.asarray(node_feats)
        node_feats = torch.from_numpy(node_feats)

        cam_Cs = torch.stack(cam_Cs, dim=0)
        cam_Es = torch.stack(cam_Es, dim=0)
        cam_Ks = torch.stack(cam_Ks, dim=0)

        # todo: read edge to adjacent matrix
        out_graph_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)
        out_covis_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)

        edge_local_matches_n1 = []
        edge_local_matches_n2 = []
        edge_subnode_idx = []
        edge_type = torch.zeros(len(subgraph_edges), dtype=torch.long)
        edge_rel_Rt = []
        edge_feats = []

        for i, edge in enumerate(subgraph_edges):
            # remap index to subgraph
            reconnect_idx = (img_id2sub_id[edge[0]], img_id2sub_id[edge[1]])
            edge_subnode_idx.append(reconnect_idx)

            label = subgraph_label[i]
            covis_value = self.covis_map[dataset_name][edge[0], edge[1]]
            if covis_value == 0:
                covis_value = self.covis_map[dataset_name][edge[1], edge[0]]
            out_graph_mat[reconnect_idx[0], reconnect_idx[1]] = label
            out_graph_mat[reconnect_idx[1], reconnect_idx[0]] = label
            out_covis_mat[reconnect_idx[0], reconnect_idx[1]] = covis_value
            out_covis_mat[reconnect_idx[1], reconnect_idx[0]] = covis_value

            n1 = edge[0]
            n2 = edge[1]

            if '%d-%d' % (n1, n2) in edge_local_feat_cache:
                edge_cache = edge_local_feat_cache['%d-%d' % (n1, n2)]
                # pts1 = torch.from_numpy(edge_cache['n1_feat_pos'])
                # pts2 = torch.from_numpy(edge_cache['n2_feat_pos'])
                edge_feat = self.edge_lmdb.read_ndarray_by_key(
                    key='%s,%d-%d' % (dataset_name, n1, n2))

                edge_type[i] = 1 if edge_cache['type'] == 'I' else 0
                Rt = edge_cache['Rt'].astype(np.float32)
                Rt_inv = cam_opt.camera_pose_inv(Rt[:3, :3], Rt[:3, 3])
                edge_rel_Rt.append(torch.from_numpy(Rt_inv))
            elif '%d-%d' % (n2, n1) in edge_local_feat_cache:
                edge_cache = edge_local_feat_cache['%d-%d' % (n2, n1)]
                # pts1 = torch.from_numpy(edge_cache['n2_feat_pos'])
                # pts2 = torch.from_numpy(edge_cache['n1_feat_pos'])
                try:
                    edge_feat = self.edge_lmdb.read_ndarray_by_key(
                        key='%s,%d-%d' % (dataset_name, n2, n1))
                except Exception:
                    print('No dataset name %s' % dataset_name)

                edge_type[i] = 1 if edge_cache['type'] == 'I' else 0
                Rt_n2ton1 = edge_cache['Rt'].astype(np.float32)
                # Rt_n2ton1 = cam_opt.camera_pose_inv(
                #     Rt_n1ton2[:3, :3], Rt_n1ton2[:3, 3])

                edge_rel_Rt.append(torch.from_numpy(Rt_n2ton1))
            else:
                print("edge not found")
            edge_feats.append(edge_feat)

            # edge_local_matches_n1.append(pts1)
            # edge_local_matches_n2.append(pts2)
        edge_feats = np.asarray(edge_feats)
        edge_feats = torch.from_numpy(edge_feats)

        out_graph_mat = torch.from_numpy(out_graph_mat)
        out_covis_mat = torch.from_numpy(out_covis_mat)

        if len(edge_subnode_idx) != len(edge_rel_Rt):
            raise Exception("Error")

        return idx, img_names, torch.zeros(
            1
        ), img_ori_dim, cam_Es, cam_Cs, cam_Ks, out_graph_mat, img_id2sub_id, sub_id2img_id, out_covis_mat, edge_subnode_idx, edge_type, torch.zeros(
            1), torch.zeros(1), edge_rel_Rt, node_feats, edge_feats
    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 node_edge_lmdb=None,
                 img_max_dim=480,
                 sampling_num_range=[100, 500],
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 sampling_undefined_edge=False,
                 load_img=True,
                 transform_func='default',
                 training=True):
        # sampling_count: sampling the numbers of subgraph for a dataset
        assert node_edge_lmdb is not None
        # assert lmdb_paths is not None

        self.load_img = load_img
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.image_dir = image_dir
        self.sampling_num_range = sampling_num_range
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.img_max_dim = img_max_dim
        self.sampling_undefined_edge = sampling_undefined_edge

        if lmdb_paths is not None:
            self.use_lmdb = True
            # self.lmdb_db = LMDBModel(lmdb_paths[0])
            self.lmdb_meta = pickle.load(open(lmdb_paths[1], 'rb'))
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # edge and node feature lmdb
        self.edge_lmdb = LMDBModel(node_edge_lmdb['edge'],
                                   lock=False,
                                   read_only=True)
        self.node_lmdb = LMDBModel(node_edge_lmdb['node'],
                                   lock=False,
                                   read_only=True)

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            bundle_prefix = ds['bundle_prefix']

            self.dataset_names.append(dataset_name)

            frame_list = read_image_list(
                os.path.join(iccv_res_dir, dataset_name,
                             bundle_prefix + '.list.txt'))
            frame_list = [f.split('.')[0].strip() for f in frame_list]

            self.frame_list[dataset_name] = frame_list

            K, img_dim = read_calibration(
                os.path.join(iccv_res_dir, dataset_name, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.covis_map = {}
        self.edge_local_feat_cache = {}
        self.inout_mat = {}
        self.total_sample_num = 0

        max_scene_edges = 0  # determine the max edges for ratio sampling
        min_scene_edges = 1400000
        print('[Captured dataset Init] load in. and out. edges')
        # z_flip = np.diag([1, 1, -1])
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            bundle_prefix = ds['bundle_prefix']

            Es, Cs = read_poses(
                os.path.join(iccv_res_dir, dataset_name, bundle_prefix))
            # Es = [np.matmul(z_flip, E) for E in Es]
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            inoutMat = np.load(
                os.path.join(iccv_res_dir, dataset_name, 'inoutMat.npy'))
            covis_map = np.load(
                os.path.join(iccv_res_dir, dataset_name, 'covis.npy'))

            with open(
                    os.path.join(iccv_res_dir, dataset_name,
                                 'edge_feat_pos_cache.bin'), 'rb') as f:
                edge_feat_pos_cache = pickle.load(f)

                # check refine_Rt:
                sampled_key = list(edge_feat_pos_cache.keys())[0]
                # if 'refine_Rt' not in edge_feat_pos_cache[sampled_key]:
                #     raise Exception('dataset: %s has no refine_Rt' % dataset_name)

            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache
            num_edges = len(edge_feat_pos_cache)
            if num_edges > max_scene_edges:
                max_scene_edges = num_edges
            if num_edges < min_scene_edges:
                min_scene_edges = num_edges

            self.inout_mat[dataset_name] = inoutMat
            self.covis_map[dataset_name] = covis_map
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

        if min_scene_edges * 40 < max_scene_edges:
            # sampling ratio from the scene has most edges should be clamped.
            max_scene_edges = 40 * min_scene_edges
        """ Sampling ---------------------------------------------------------------------------------------------------
        """
        self.edge_sampler = {}
        self.samples = []  # (dataset_id, sub-graph sample_id)

        if sample_res_cache is None or not os.path.exists(sample_res_cache):

            print('[Captured dataset Init] sampling sub_graphs')
            for ds_id, ds in enumerate(dataset_list):
                dataset_name = ds['name']
                edge_feat_pos_cache = self.edge_local_feat_cache[dataset_name]

                n_Cameras = len(self.Cs[dataset_name])
                inoutMat = self.inout_mat[dataset_name]
                # edge_feat_pos_cache = self.edge_local_feat_cache[dataset_name]
                for i in range(n_Cameras):
                    for j in range(n_Cameras):
                        if inoutMat[i, j] != 1 and (
                            ("%d-%d" % (i, j)) in edge_feat_pos_cache or
                            ("%d-%d" % (j, i)) in edge_feat_pos_cache):
                            inoutMat[i, j] = -1
                for i in range(n_Cameras):
                    for j in range(n_Cameras):
                        if ("%d-%d" % (i, j) not in edge_feat_pos_cache) and (
                                "%d-%d" % (j, i) not in edge_feat_pos_cache):
                            inoutMat[i, j] = 0
                num_edges = len(self.edge_local_feat_cache[dataset_name])

                # determine sampling number based on ratio of edges among other scenes
                sample_ratio = float(num_edges) / float(max_scene_edges)
                print('%s: Sampling Ratio: %.2f' %
                      (dataset_name, sample_ratio))
                sample_num = int(sampling_num_range[1] * sample_ratio)
                if sample_num < sampling_num_range[0]:
                    sample_num = sampling_num_range[0]
                if sample_num > sampling_num_range[1]:
                    sample_num = sampling_num_range[1]

                # todo: fix sub_graph_nodes
                gen = SamplingGenerator(n_Cameras, inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sample_num)
                gen.generation(use_undefine=self.sampling_undefined_edge,
                               get_max_node=False)

                print("test inoutmat")
                for edges in gen.sampling_edge:
                    flag = False
                    for edge in edges:

                        # if (edge[0] == 29 and edge[1] == 21) or (edge[1] == 29 and edge[0] == 21) and dataset_name=='furniture13':
                        #     lenth = 0
                        #     print(lenth)

                        if ("%d-%d" %
                            (edge[0], edge[1]) in edge_feat_pos_cache) or (
                                "%d-%d" %
                                (edge[1], edge[0]) in edge_feat_pos_cache):
                            continue
                        else:
                            #print("%d-%d" % (edge[0], edge[1]))
                            flag = True
                    if flag:
                        print("bad")
                filtered_sampled_num = len(gen.sampling_node)
                print('[Captured dataset Init] %s: (filtered: %d, all: %d)' %
                      (dataset_name, filtered_sampled_num, num_edges))

                self.samples += [(ds_id, i)
                                 for i in range(filtered_sampled_num)]
                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            if sample_res_cache is not None:
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.samples, self.edge_sampler], f)
                print('[Captured Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)

        elif os.path.exists(sample_res_cache):
            with open(sample_res_cache, 'rb') as f:
                s = pickle.load(f)
                self.samples, self.edge_sampler = s
            print('[Captured Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Captured Init] Done, %d samples' % len(self.samples))
        print('Rt_rel_12: n2 to n1')
    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 img_max_dim=480,
                 sampling_num_range=None,
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 sampling_undefined_edge=False,
                 transform_func='default'):
        """Load per-scene poses/graphs and sample sub-graphs.

        NOTE(review): this is a second ``__init__`` in the same class body as
        the preceding one (which also took ``node_edge_lmdb``, ``load_img``
        and ``training``). Being later, this definition is the one Python
        keeps. Confirm which constructor is intended.

        Args:
            iccv_res_dir: root directory with one sub-directory per scene. Each
                sub-directory is read for ``<bundle_prefix>.list.txt``,
                ``calibration.txt``, the bundle file ``<bundle_prefix>``,
                ``inoutMat.npy``, ``covis.npy`` and ``edge_feat_pos_cache.bin``.
            image_dir: stored as ``self.image_dir``.
            dataset_list: list of dicts with at least ``'name'`` and
                ``'bundle_prefix'`` keys.
            lmdb_paths: optional pair. ``lmdb_paths[0]`` is opened with
                ``LMDBModel`` and ``lmdb_paths[1]`` is a pickled metadata file.
            img_max_dim: stored as ``self.img_max_dim``.
            sampling_num_range: ``[lo, hi]`` bounds on the number of
                sub-graphs sampled per scene. Defaults to ``[100, 500]``.
            sub_graph_nodes: sub-graph size passed to ``SamplingGenerator``.
            sample_res_cache: optional pickle path. If it exists, the samples
                are loaded from it. Otherwise they are generated and, when a
                path is given, written to it.
            sampling_undefined_edge: forwarded as ``use_undefine`` to
                ``SamplingGenerator.generation``.
            transform_func: the string ``'default'`` is replaced by an
                ImageNet mean/std ``Normalize``. Any other value is stored
                unchanged.
        """
        if sampling_num_range is None:
            # Fresh list per instance: a shared mutable default would be
            # aliased by self.sampling_num_range across all instances.
            sampling_num_range = [100, 500]

        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.image_dir = image_dir
        self.sampling_num_range = sampling_num_range
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.img_max_dim = img_max_dim
        self.sampling_undefined_edge = sampling_undefined_edge

        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            with open(lmdb_paths[1], 'rb') as f:
                self.lmdb_meta = pickle.load(f)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            bundle_prefix = ds['bundle_prefix']
            self.dataset_names.append(dataset_name)

            frame_list = read_image_list(
                os.path.join(iccv_res_dir, dataset_name,
                             bundle_prefix + '.list.txt'))
            # keep everything before the first '.', i.e. drop the extension
            self.frame_list[dataset_name] = [
                f.split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(
                os.path.join(iccv_res_dir, dataset_name, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.covis_map = {}
        self.edge_local_feat_cache = {}
        self.inout_mat = {}
        self.total_sample_num = 0

        scene_edge_counts = []  # number of cached edges per scene
        print('[Captured dataset Init] load in. and out. edges')
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            Es, Cs = read_poses(os.path.join(scene_dir, ds['bundle_prefix']))
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            self.inout_mat[dataset_name] = np.load(
                os.path.join(scene_dir, 'inoutMat.npy'))
            self.covis_map[dataset_name] = np.load(
                os.path.join(scene_dir, 'covis.npy'))

            with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'),
                      'rb') as f:
                edge_feat_pos_cache = pickle.load(f)
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache
            scene_edge_counts.append(len(edge_feat_pos_cache))

        # Edge count that maps to sampling ratio 1.0: the largest scene,
        # clamped to 6x the smallest so one huge scene does not dominate.
        # Empty scenes are excluded; otherwise a single zero-edge scene would
        # clamp the reference to 0 and every ratio below would divide by zero.
        nonempty_counts = [c for c in scene_edge_counts if c > 0]
        if nonempty_counts:
            max_scene_edges = min(max(nonempty_counts),
                                  6 * min(nonempty_counts))
        else:
            max_scene_edges = 0

        # Sampling ------------------------------------------------------------
        self.edge_sampler = {}
        self.samples = []  # (dataset_id, sub-graph sample_id)

        if sample_res_cache is None or not os.path.exists(sample_res_cache):
            print('[Captured dataset Init] sampling sub_graphs')
            for ds_id, ds in enumerate(dataset_list):
                dataset_name = ds['name']

                n_Cameras = len(self.Cs[dataset_name])
                inoutMat = self.inout_mat[dataset_name]
                num_edges = len(self.edge_local_feat_cache[dataset_name])

                # sampling number proportional to this scene's share of edges,
                # then clamped into [sampling_num_range[0], sampling_num_range[1]]
                if max_scene_edges > 0:
                    sample_ratio = num_edges / max_scene_edges
                else:
                    sample_ratio = 0.0
                sample_num = int(sampling_num_range[1] * sample_ratio)
                sample_num = min(max(sample_num, sampling_num_range[0]),
                                 sampling_num_range[1])

                # todo: fix sub_graph_nodes
                gen = SamplingGenerator(n_Cameras, inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sample_num)
                gen.generation(use_undefine=self.sampling_undefined_edge)

                filtered_sampled_num = len(gen.sampling_node)
                print('[Captured dataset Init] %s: (filtered: %d, all: %d)' %
                      (dataset_name, filtered_sampled_num, num_edges))

                self.samples += [(ds_id, i)
                                 for i in range(filtered_sampled_num)]
                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            if sample_res_cache is not None:
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.samples, self.edge_sampler], f)
                print('[Captured Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)
        else:
            # NOTE: unpickling can execute arbitrary code; only point
            # sample_res_cache at trusted, locally generated files.
            with open(sample_res_cache, 'rb') as f:
                self.samples, self.edge_sampler = pickle.load(f)
            print('[Captured Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Captured Init] Done, %d samples' % len(self.samples))
class CaptureDataset(Dataset):
    """Sub-graph dataset over captured scenes.

    Each item is one sampled sub-graph of a scene's camera graph. It holds the
    images of the sub-graph's nodes, their camera poses and intrinsics, the
    sub-graph adjacency (edge labels) and co-visibility matrices, and per-edge
    local feature matches read from ``edge_feat_pos_cache.bin``.
    """

    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 img_max_dim=480,
                 sampling_num_range=None,
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 sampling_undefined_edge=False,
                 transform_func='default'):
        """Load per-scene poses/graphs and sample sub-graphs.

        NOTE(review): this definition is shadowed by the later ``__init__``
        in this class body (Python keeps the last definition). See the note
        on that method.

        Args:
            iccv_res_dir: root directory with one sub-directory per scene. Each
                sub-directory is read for ``<bundle_prefix>.list.txt``,
                ``calibration.txt``, the bundle file ``<bundle_prefix>``,
                ``inoutMat.npy``, ``covis.npy`` and ``edge_feat_pos_cache.bin``.
            image_dir: root for ``<scene>/<frame>.jpg`` when LMDB is not used.
            dataset_list: list of dicts with at least ``'name'`` and
                ``'bundle_prefix'`` keys.
            lmdb_paths: optional pair. ``lmdb_paths[0]`` is opened with
                ``LMDBModel`` and ``lmdb_paths[1]`` is a pickled metadata file.
            img_max_dim: target length of the SHORTER image side in
                ``__getitem__``.
            sampling_num_range: ``[lo, hi]`` bounds on the number of
                sub-graphs sampled per scene. Defaults to ``[100, 500]``.
            sub_graph_nodes: sub-graph size passed to ``SamplingGenerator``.
            sample_res_cache: optional pickle path. If it exists, the samples
                are loaded from it. Otherwise they are generated and, when a
                path is given, written to it.
            sampling_undefined_edge: forwarded as ``use_undefine`` to
                ``SamplingGenerator.generation``.
            transform_func: ``'default'`` is replaced by an ImageNet mean/std
                ``Normalize``. ``None`` disables the transform in
                ``__getitem__``. Any other value is applied to each image.
        """
        if sampling_num_range is None:
            # Fresh list per instance: a shared mutable default would be
            # aliased by self.sampling_num_range across all instances.
            sampling_num_range = [100, 500]

        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.image_dir = image_dir
        self.sampling_num_range = sampling_num_range
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.img_max_dim = img_max_dim
        self.sampling_undefined_edge = sampling_undefined_edge

        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            with open(lmdb_paths[1], 'rb') as f:
                self.lmdb_meta = pickle.load(f)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            bundle_prefix = ds['bundle_prefix']
            self.dataset_names.append(dataset_name)

            frame_list = read_image_list(
                os.path.join(iccv_res_dir, dataset_name,
                             bundle_prefix + '.list.txt'))
            # keep everything before the first '.', i.e. drop the extension
            self.frame_list[dataset_name] = [
                f.split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(
                os.path.join(iccv_res_dir, dataset_name, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.covis_map = {}
        self.edge_local_feat_cache = {}
        self.inout_mat = {}
        self.total_sample_num = 0

        scene_edge_counts = []  # number of cached edges per scene
        print('[Captured dataset Init] load in. and out. edges')
        for ds in tqdm(dataset_list):
            dataset_name = ds['name']
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            Es, Cs = read_poses(os.path.join(scene_dir, ds['bundle_prefix']))
            self.Es[dataset_name] = Es
            self.Cs[dataset_name] = Cs

            self.inout_mat[dataset_name] = np.load(
                os.path.join(scene_dir, 'inoutMat.npy'))
            self.covis_map[dataset_name] = np.load(
                os.path.join(scene_dir, 'covis.npy'))

            with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'),
                      'rb') as f:
                edge_feat_pos_cache = pickle.load(f)
            self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache
            scene_edge_counts.append(len(edge_feat_pos_cache))

        max_scene_edges = self._reference_edge_count(scene_edge_counts)

        # Sampling ------------------------------------------------------------
        self.edge_sampler = {}
        self.samples = []  # (dataset_id, sub-graph sample_id)

        if sample_res_cache is None or not os.path.exists(sample_res_cache):
            print('[Captured dataset Init] sampling sub_graphs')
            for ds_id, ds in enumerate(dataset_list):
                dataset_name = ds['name']

                n_Cameras = len(self.Cs[dataset_name])
                inoutMat = self.inout_mat[dataset_name]
                num_edges = len(self.edge_local_feat_cache[dataset_name])
                sample_num = self._sample_num(num_edges, max_scene_edges,
                                              sampling_num_range)

                # todo: fix sub_graph_nodes
                gen = SamplingGenerator(n_Cameras, inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sample_num)
                gen.generation(use_undefine=self.sampling_undefined_edge)

                filtered_sampled_num = len(gen.sampling_node)
                print('[Captured dataset Init] %s: (filtered: %d, all: %d)' %
                      (dataset_name, filtered_sampled_num, num_edges))

                self.samples += [(ds_id, i)
                                 for i in range(filtered_sampled_num)]
                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            if sample_res_cache is not None:
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.samples, self.edge_sampler], f)
                print('[Captured Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)
        else:
            # NOTE: unpickling can execute arbitrary code; only point
            # sample_res_cache at trusted, locally generated files.
            with open(sample_res_cache, 'rb') as f:
                self.samples, self.edge_sampler = pickle.load(f)
            print('[Captured Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Captured Init] Done, %d samples' % len(self.samples))

    @staticmethod
    def _reference_edge_count(edge_counts, max_ratio=6):
        """Return the edge count that corresponds to sampling ratio 1.0.

        This is the largest per-scene edge count, clamped to ``max_ratio``
        times the smallest so one huge scene does not dominate. Scenes with
        zero edges are ignored, because including them would clamp the
        reference to 0. Returns 0 when no scene has any edge.
        """
        nonempty = [c for c in edge_counts if c > 0]
        if not nonempty:
            return 0
        return min(max(nonempty), max_ratio * min(nonempty))

    @staticmethod
    def _sample_num(num_edges, reference_edges, sampling_num_range):
        """Number of sub-graphs to sample for a scene with ``num_edges``.

        The value is proportional to ``num_edges / reference_edges`` (0 when
        the reference is 0). It is scaled by ``sampling_num_range[1]`` and
        clamped into ``[sampling_num_range[0], sampling_num_range[1]]``.
        """
        if reference_edges > 0:
            ratio = num_edges / reference_edges
        else:
            ratio = 0.0
        sample_num = int(sampling_num_range[1] * ratio)
        return min(max(sample_num, sampling_num_range[0]),
                   sampling_num_range[1])

    @staticmethod
    def _lookup_edge_matches(edge_local_feat_cache, n1, n2):
        """Return ``(pts_n1, pts_n2, edge_type)`` for the edge ``n1``-``n2``.

        The cache is keyed ``'<a>-<b>'`` in either direction. When only the
        reversed key ``'n2-n1'`` exists, its ``n1``/``n2`` feature positions
        are swapped so the result is always oriented as (n1, n2).
        ``edge_type`` is 1 when the cached ``'type'`` is ``'I'``, else 0.

        Raises:
            KeyError: if neither direction is cached. Previously the caller
                silently reused the matches of the preceding edge.
        """
        key = '%d-%d' % (n1, n2)
        if key in edge_local_feat_cache:
            edge_cache = edge_local_feat_cache[key]
            pts1 = edge_cache['n1_feat_pos']
            pts2 = edge_cache['n2_feat_pos']
        else:
            rev_key = '%d-%d' % (n2, n1)
            if rev_key not in edge_local_feat_cache:
                raise KeyError('edge %s (or %s) not found in the edge feature '
                               'cache' % (key, rev_key))
            edge_cache = edge_local_feat_cache[rev_key]
            pts1 = edge_cache['n2_feat_pos']
            pts2 = edge_cache['n1_feat_pos']
        edge_type = 1 if edge_cache['type'] == 'I' else 0
        return torch.from_numpy(pts1), torch.from_numpy(pts2), edge_type

    def __len__(self):
        """Number of (scene, sub-graph) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Build the training sample for sub-graph ``self.samples[idx]``.

        Returns a 14-tuple:
            img_keys: ``'<scene>/<frame>.jpg'`` per node.
            imgs: list of CHW float tensors, pixel values scaled by 1/255 and
                passed through ``self.transform_func`` when it is not None.
                Each is resized so its shorter side is ``img_max_dim``.
            img_ori_dim: ``(h, w)`` per node before resizing (from LMDB
                metadata ``'dim'`` or the decoded image).
            cam_Es, cam_Cs, cam_Ks: per-node pose/center/intrinsics arrays
                converted to float tensors and stacked along dim 0.
            out_graph_mat: ``(N, N)`` float32 tensor of symmetric edge labels,
                where N is the number of sub-graph nodes.
            img_id2sub_id, sub_id2img_id: scene image id <-> sub-graph index.
            out_covis_mat: ``(N, N)`` float32 tensor of symmetric
                co-visibility values.
            edge_subnode_idx: per-edge ``(i, j)`` sub-graph indices.
            edge_type: long tensor, 1 for edges whose cached type is ``'I'``.
            edge_local_matches_n1, edge_local_matches_n2: per-edge feature
                position tensors, oriented as ``(edge[0], edge[1])``.

        Raises:
            FileNotFoundError: if an image cannot be read from ``image_dir``.
            KeyError: if a sampled edge is missing from the feature cache.
        """
        dataset_idx, sub_graph_id = self.samples[idx]

        dataset_name = self.dataset_names[dataset_idx]
        frame_list = self.frame_list[dataset_name]
        sampling_node, sampling_edge, sampling_edge_label = self.edge_sampler[
            dataset_name]
        edge_local_feat_cache = self.edge_local_feat_cache[dataset_name]
        covis_map = self.covis_map[dataset_name]

        subgraph_nodes = sampling_node[sub_graph_id]
        subgraph_edges = sampling_edge[sub_graph_id]
        subgraph_label = sampling_edge_label[sub_graph_id]
        sub_graph_nodes = len(subgraph_nodes)

        imgs = []
        img_keys = []
        img_ori_dim = []
        cam_Es, cam_Cs, cam_Ks = [], [], []
        img_id2sub_id = {}
        sub_id2img_id = {}
        for i, imageID in enumerate(subgraph_nodes):
            img_key = dataset_name + '/' + frame_list[imageID] + '.jpg'
            if self.use_lmdb is True:
                img = self.lmdb_db.read_ndarray_by_key(img_key, dtype=np.uint8)
                h, w = self.lmdb_meta[img_key]['dim']
                res_h, res_w = self.lmdb_meta[img_key]['lmdb_dim']
                img = img.reshape(int(res_h), int(res_w), 3)
            else:
                img_path = os.path.join(self.image_dir, dataset_name,
                                        frame_list[imageID] + '.jpg')
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread reports failure by returning None, which
                    # would otherwise surface as an opaque cvtColor error.
                    raise FileNotFoundError('cannot read image: %s' % img_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                h, w = img.shape[:2]
            img_ori_dim.append((h, w))
            img_keys.append(img_key)

            # Rescale so the SHORTER side equals img_max_dim.
            res_h, res_w = img.shape[:2]
            scale = float(self.img_max_dim) / float(min(res_h, res_w))
            img = cv2.resize(img,
                             dsize=(int(res_w * scale), int(res_h * scale)))
            img = img.astype(np.float32) / 255.0

            img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW
            if self.transform_func is not None:
                img = self.transform_func(img)
            imgs.append(img)

            cam_Cs.append(
                torch.from_numpy(self.Cs[dataset_name][imageID]).float())
            cam_Es.append(
                torch.from_numpy(self.Es[dataset_name][imageID]).float())
            cam_Ks.append(
                torch.from_numpy(self.K[dataset_name][imageID]).float())

            img_id2sub_id[imageID] = i
            sub_id2img_id[i] = imageID

        cam_Cs = torch.stack(cam_Cs, dim=0)
        cam_Es = torch.stack(cam_Es, dim=0)
        cam_Ks = torch.stack(cam_Ks, dim=0)

        # edges -> symmetric adjacency / co-visibility matrices (subgraph ids)
        out_graph_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)
        out_covis_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)

        edge_local_matches_n1 = []
        edge_local_matches_n2 = []
        edge_subnode_idx = []
        edge_type = torch.zeros(len(subgraph_edges), dtype=torch.long)

        for i, edge in enumerate(subgraph_edges):
            n1, n2 = edge[0], edge[1]
            a, b = img_id2sub_id[n1], img_id2sub_id[n2]
            edge_subnode_idx.append((a, b))

            label = subgraph_label[i]
            # co-visibility may be stored in only one triangle of the matrix
            covis_value = covis_map[n1, n2]
            if covis_value == 0:
                covis_value = covis_map[n2, n1]
            out_graph_mat[a, b] = label
            out_graph_mat[b, a] = label
            out_covis_mat[a, b] = covis_value
            out_covis_mat[b, a] = covis_value

            pts1, pts2, e_type = self._lookup_edge_matches(
                edge_local_feat_cache, n1, n2)
            edge_type[i] = e_type
            edge_local_matches_n1.append(pts1)
            edge_local_matches_n2.append(pts2)

        out_graph_mat = torch.from_numpy(out_graph_mat)
        out_covis_mat = torch.from_numpy(out_covis_mat)

        return img_keys, imgs, img_ori_dim, cam_Es, cam_Cs, cam_Ks, out_graph_mat, img_id2sub_id, sub_id2img_id, out_covis_mat, edge_subnode_idx, edge_type, edge_local_matches_n1, edge_local_matches_n2

    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 downsample_scale=0.25,
                 sampling_num=100,
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 manual_modification_dict=None,
                 transform_func='default'):
        """Alternative constructor (reads ``ImageList.txt``/``covis_map.npy``).

        NOTE(review): this second ``__init__`` replaces the one defined
        earlier in this class (Python keeps the last definition in a class
        body), so it is the constructor actually used. It never sets
        ``self.samples`` or ``self.img_max_dim``. ``__len__`` and
        ``__getitem__`` read both, so instances built through it raise
        AttributeError on ``len()``/indexing. Its visible code matches
        ``AmbiLocalFeatDataset.__init__`` and looks like an accidental paste.
        It is kept only so the effective constructor signature is unchanged.
        Delete it once callers are confirmed to use the ``img_max_dim`` /
        ``sampling_num_range`` constructor.
        """
        # sampling_count: sampling the numbers of subgraph for a dataset
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.sampling_num = sampling_num
        self.image_dir = image_dir
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.downsample_scale = downsample_scale
        self.manual_modification_dict = manual_modification_dict

        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            with open(lmdb_paths[1], 'rb') as f:
                self.lmdb_meta = pickle.load(f)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            self.dataset_names.append(dataset_name)

            frame_list = read_image_list(
                os.path.join(iccv_res_dir, dataset_name, 'ImageList.txt'))
            # drop the first two characters, then everything from the first '.'
            frame_list = [f[2:].split('.')[0].strip() for f in frame_list]

            self.frame_list[dataset_name] = frame_list

            K, _ = read_calibration(
                os.path.join(iccv_res_dir, dataset_name, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.edge_sampler = {}

        self.e_sampling_node = dict()
        self.e_sampling_edge = dict()
        self.e_sampling_edge_label = dict()

        self.covis_map = {}
        self.edge_local_feat_cache = {}

        if sample_res_cache is None or not os.path.exists(sample_res_cache):

            print(
                '[Ambi dataset Init] load in. and out. edges and sampling sub_graphs'
            )
            for ds in tqdm(dataset_list):
                dataset_name = ds['name']
                bundle_file_path = ds['bundle_file']

                Es, Cs = read_poses(
                    os.path.join(iccv_res_dir, dataset_name, bundle_file_path))
                self.Es[dataset_name] = Es
                self.Cs[dataset_name] = Cs

                n_Cameras = len(Cs)
                inoutMat = np.load(
                    os.path.join(iccv_res_dir, dataset_name, 'inoutMat.npy'))
                covis_map = np.load(
                    os.path.join(iccv_res_dir, dataset_name, 'covis_map.npy'))

                with open(
                        os.path.join(iccv_res_dir, dataset_name,
                                     'edge_feat_pos_cache.bin'), 'rb') as f:
                    edge_feat_pos_cache = pickle.load(f)
                self.covis_map[dataset_name] = covis_map
                self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

                gen = SamplingGenerator(n_Cameras, inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sampling_num)
                gen.generation()

                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            # In this branch the cache file is known not to exist yet.
            if sample_res_cache is not None:
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.edge_sampler], f)
                print('[Ambi Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)

        else:
            # Cached samples: still load poses, covisibility and edge caches.
            for ds in tqdm(dataset_list):
                dataset_name = ds['name']
                bundle_file_path = ds['bundle_file']

                Es, Cs = read_poses(
                    os.path.join(iccv_res_dir, dataset_name, bundle_file_path))
                self.Es[dataset_name] = Es
                self.Cs[dataset_name] = Cs

                covis_map = np.load(
                    os.path.join(iccv_res_dir, dataset_name, 'covis_map.npy'))

                with open(
                        os.path.join(iccv_res_dir, dataset_name,
                                     'edge_feat_pos_cache.bin'), 'rb') as f:
                    edge_feat_pos_cache = pickle.load(f)
                self.covis_map[dataset_name] = covis_map
                self.edge_local_feat_cache[dataset_name] = edge_feat_pos_cache

            # NOTE: unpickling can execute arbitrary code; trusted files only.
            with open(sample_res_cache, 'rb') as f:
                s = pickle.load(f)
                self.edge_sampler = s[0]
            print('[Ambi Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Ambi Init] Done')
class AmbiLocalFeatDataset(Dataset):
    """Sub-graph dataset over several reconstructed scenes with cached local matches.

    Each item is one pre-sampled sub-graph of one scene. It contains:

    - the images of the sub-graph's nodes;
    - their camera parameters;
    - the labelled adjacency and co-visibility matrices of its edges;
    - the cached local feature positions for each edge.

    ``len(self) == num_dataset * sampling_num``. Index ``idx`` selects scene
    ``idx // sampling_num`` and sub-graph ``idx % sampling_num``.
    """

    def __init__(self,
                 iccv_res_dir,
                 image_dir,
                 dataset_list,
                 lmdb_paths=None,
                 downsample_scale=0.25,
                 sampling_num=100,
                 sub_graph_nodes=24,
                 sample_res_cache=None,
                 manual_modification_dict=None,
                 transform_func='default'):
        """Load per-scene metadata and the sampled sub-graphs.

        Args:
            iccv_res_dir: root directory with one sub-directory per scene. Each
                holds ``ImageList.txt``, ``calibration.txt``, the bundle file,
                ``covis_map.npy``, ``edge_feat_pos_cache.bin`` and, when
                sub-graphs are (re)sampled, ``inoutMat.npy``.
            image_dir: root of ``<scene>/<frame>.jpg`` images, used when no LMDB
                is given.
            dataset_list: list of dicts with at least ``'name'`` and
                ``'bundle_file'`` keys.
            lmdb_paths: optional ``(lmdb_path, meta_pickle_path)``. When given,
                images are read from the LMDB instead of ``image_dir``.
            downsample_scale: resize factor applied to each image's original size.
            sampling_num: number of sub-graphs sampled per scene.
            sub_graph_nodes: number of nodes per sampled sub-graph.
            sample_res_cache: optional pickle path. If the file exists, the
                sampled sub-graphs are loaded from it. Otherwise they are
                sampled and then written to it.
            manual_modification_dict: optional mapping
                ``'<dataset_idx>_<sub_graph_id>' -> [(sub_i, sub_j, label), ...]``.
                These are extra edges, in sub-graph node indices, appended to
                that item in ``__getitem__``.
            transform_func: callable applied to each CHW float image tensor.
                Pass ``None`` for no transform, or ``'default'`` for ImageNet
                mean/std normalisation.
        """
        self.num_dataset = len(dataset_list)
        self.iccv_res_dir = iccv_res_dir
        self.sampling_num = sampling_num
        self.image_dir = image_dir
        self.sub_graph_nodes = sub_graph_nodes
        self.transform_func = transform_func
        self.downsample_scale = downsample_scale
        self.manual_modification_dict = manual_modification_dict

        if lmdb_paths is not None:
            self.use_lmdb = True
            self.lmdb_db = LMDBModel(lmdb_paths[0])
            # Close the metadata file deterministically instead of leaking it.
            with open(lmdb_paths[1], 'rb') as meta_file:
                self.lmdb_meta = pickle.load(meta_file)
        else:
            self.use_lmdb = False

        if self.transform_func == 'default':
            self.transform_func = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        # read image list and calibration
        self.frame_list = {}
        self.K = {}
        self.dataset_names = []
        for ds in dataset_list:
            dataset_name = ds['name']
            self.dataset_names.append(dataset_name)
            scene_dir = os.path.join(iccv_res_dir, dataset_name)

            frame_list = read_image_list(os.path.join(scene_dir, 'ImageList.txt'))
            # Drop the first two characters (presumably a './' prefix) and
            # everything from the first '.' on (the extension).
            self.frame_list[dataset_name] = [
                f[2:].split('.')[0].strip() for f in frame_list
            ]

            K, _ = read_calibration(os.path.join(scene_dir, 'calibration.txt'))
            self.K[dataset_name] = K

        self.Es = {}
        self.Cs = {}

        self.edge_sampler = {}

        # Never written inside this class; kept for attribute compatibility.
        self.e_sampling_node = dict()
        self.e_sampling_edge = dict()
        self.e_sampling_edge_label = dict()

        self.covis_map = {}
        self.edge_local_feat_cache = {}

        if sample_res_cache is None or not os.path.exists(sample_res_cache):
            print(
                '[Ambi dataset Init] load in. and out. edges and sampling sub_graphs'
            )
            for ds in tqdm(dataset_list):
                dataset_name = ds['name']
                Cs = self._load_scene_data(ds)

                inoutMat = np.load(
                    os.path.join(iccv_res_dir, dataset_name, 'inoutMat.npy'))
                gen = SamplingGenerator(len(Cs), inoutMat)
                gen.setSamplingSize(sub_graph_nodes)
                gen.setSamplingNumber(sampling_num)
                gen.generation()

                self.edge_sampler[dataset_name] = (gen.sampling_node,
                                                   gen.sampling_edge,
                                                   gen.sampling_edge_label)

            # Only write the cache if it still does not exist.
            if sample_res_cache is not None and not os.path.exists(
                    sample_res_cache):
                with open(sample_res_cache, 'wb') as f:
                    pickle.dump([self.edge_sampler], f)
                print('[Ambi Init] Save subgraph fast cache to %s.' %
                      sample_res_cache)
        else:
            # Sub-graphs come from the cache, but __getitem__ still needs the
            # per-scene poses, co-visibility maps and edge feature caches.
            for ds in tqdm(dataset_list):
                self._load_scene_data(ds)

            # NOTE: unpickling can execute arbitrary code; only load trusted caches.
            with open(sample_res_cache, 'rb') as f:
                self.edge_sampler = pickle.load(f)[0]
            print('[Ambi Init] Load subgraph fast cache from %s.' %
                  sample_res_cache)

        print('[Ambi Init] Done')

    def _load_scene_data(self, ds):
        """Load poses, co-visibility map and edge feature cache for one scene.

        Fills ``self.Es``, ``self.Cs``, ``self.covis_map`` and
        ``self.edge_local_feat_cache`` under ``ds['name']``, and returns the
        scene's ``Cs`` as produced by ``read_poses``.
        """
        dataset_name = ds['name']
        scene_dir = os.path.join(self.iccv_res_dir, dataset_name)

        Es, Cs = read_poses(os.path.join(scene_dir, ds['bundle_file']))
        self.Es[dataset_name] = Es
        self.Cs[dataset_name] = Cs

        self.covis_map[dataset_name] = np.load(
            os.path.join(scene_dir, 'covis_map.npy'))
        with open(os.path.join(scene_dir, 'edge_feat_pos_cache.bin'), 'rb') as f:
            self.edge_local_feat_cache[dataset_name] = pickle.load(f)
        return Cs

    def __len__(self):
        return self.num_dataset * self.sampling_num

    @staticmethod
    def _lookup_edge_matches(edge_local_feat_cache, n1, n2):
        """Return ``(pts1, pts2, edge_type)`` for edge ``(n1, n2)``, or ``None``.

        The cache is keyed by ``'<a>-<b>'``. When only the reversed key
        ``'<n2>-<n1>'`` exists, its two point sets are swapped so that ``pts1``
        always belongs to ``n1``. Points are returned as float32 tensors.
        ``edge_type`` is 1 when the cached ``'type'`` is ``'I'``, else 0.
        """
        fwd_key = '%d-%d' % (n1, n2)
        if fwd_key in edge_local_feat_cache:
            entry = edge_local_feat_cache[fwd_key]
            pts1, pts2 = entry['n1_feat_pos'], entry['n2_feat_pos']
        else:
            rev_key = '%d-%d' % (n2, n1)
            if rev_key not in edge_local_feat_cache:
                return None
            entry = edge_local_feat_cache[rev_key]
            pts1, pts2 = entry['n2_feat_pos'], entry['n1_feat_pos']
        return (torch.from_numpy(pts1).float(),
                torch.from_numpy(pts2).float(),
                1 if entry['type'] == 'I' else 0)

    @staticmethod
    def _covis_value(covis_map, n1, n2):
        """Read ``covis_map[n1, n2]``, falling back to ``[n2, n1]`` when it is 0."""
        value = covis_map[n1, n2]
        if value == 0:
            value = covis_map[n2, n1]
        return value

    @staticmethod
    def _set_symmetric(mat, ij, value):
        """Write ``value`` at ``(i, j)`` and ``(j, i)`` of a square matrix."""
        mat[ij[0], ij[1]] = value
        mat[ij[1], ij[0]] = value

    def _load_image(self, dataset_name, frame_name):
        """Read one frame and return ``(img_key, (h, w), img_tensor)``.

        ``(h, w)`` is the original size: the LMDB meta ``'OriDim'`` in LMDB
        mode, or the decoded file's size otherwise. The image is resized to
        ``downsample_scale`` times that size, scaled by 1/255, permuted to CHW
        and passed through ``transform_func`` when it is set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file in non-LMDB mode.
        """
        img_key = dataset_name + '/' + frame_name + '.jpg'
        if self.use_lmdb:
            img = self.lmdb_db.read_ndarray_by_key(img_key, dtype=np.uint8)
            h, w = self.lmdb_meta[img_key]['OriDim']
            # The stored buffer holds half of the recorded original size.
            img = img.reshape(int(h * 0.5), int(w * 0.5), 3)
        else:
            img_path = os.path.join(self.image_dir, dataset_name,
                                    frame_name + '.jpg')
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread reports failure by returning None
                raise FileNotFoundError('Cannot read image: %s' % img_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]

        img = cv2.resize(img,
                         dsize=(int(w * self.downsample_scale),
                                int(h * self.downsample_scale)))
        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
        if self.transform_func is not None:
            img = self.transform_func(img)
        return img_key, (h, w), img

    def __getitem__(self, idx):
        """Assemble sub-graph ``idx % sampling_num`` of scene ``idx // sampling_num``.

        Returns:
            tuple ``(img_keys, meta_dict, imgs, img_ori_dim, cam_Es, cam_Cs,
            cam_Ks, out_graph_mat, img_id2sub_id, sub_id2img_id, out_covis_mat,
            edge_subnode_idx, edge_type, edge_local_matches_n1,
            edge_local_matches_n2)``.

            - ``out_graph_mat`` and ``out_covis_mat`` are ``(N, N)`` float32
              tensors over the N sub-graph nodes.
            - ``edge_type`` is a long tensor with one slot per sampled and
              manually added edge.
            - The match lists hold one float tensor per edge whose matches
              were found in the cache.

        Raises:
            KeyError: if a sampled edge has no entry, in either direction, in
                the scene's edge feature cache.
            FileNotFoundError: if an image cannot be read from ``image_dir``.
        """
        # Exact integer split; float division + int() loses precision for huge
        # indices.  TODO: need a new random idx
        dataset_idx, sub_graph_id = divmod(int(idx), self.sampling_num)

        dataset_name = self.dataset_names[dataset_idx]
        frame_list = self.frame_list[dataset_name]
        edge_local_feat_cache = self.edge_local_feat_cache[dataset_name]
        covis_map = self.covis_map[dataset_name]
        sampling_node, sampling_edge, sampling_edge_label = self.edge_sampler[
            dataset_name]
        subgraph_nodes = sampling_node[sub_graph_id]
        subgraph_edges = sampling_edge[sub_graph_id]
        subgraph_label = sampling_edge_label[sub_graph_id]
        sub_graph_nodes = len(subgraph_nodes)

        # manual linkage add
        key = '%d_%d' % (dataset_idx, sub_graph_id)
        if self.manual_modification_dict is not None and key in self.manual_modification_dict:
            manual_linkage_list = self.manual_modification_dict[key]
        else:
            manual_linkage_list = None

        # read images and cameras of the sub-graph nodes
        imgs, img_keys, img_ori_dim = [], [], []
        cam_Es, cam_Cs, cam_Ks = [], [], []
        img_id2sub_id = {}
        sub_id2img_id = {}
        for i, imageID in enumerate(subgraph_nodes):
            img_key, ori_dim, img = self._load_image(dataset_name,
                                                     frame_list[imageID])
            img_keys.append(img_key)
            img_ori_dim.append(ori_dim)
            imgs.append(img)

            cam_Cs.append(torch.from_numpy(self.Cs[dataset_name][imageID]).float())
            cam_Es.append(torch.from_numpy(self.Es[dataset_name][imageID]).float())
            cam_Ks.append(torch.from_numpy(self.K[dataset_name][imageID]).float())

            img_id2sub_id[imageID] = i
            sub_id2img_id[i] = imageID

        cam_Cs = torch.stack(cam_Cs, dim=0)
        cam_Es = torch.stack(cam_Es, dim=0)
        cam_Ks = torch.stack(cam_Ks, dim=0)

        # edges -> adjacency / co-visibility matrices over sub-graph indices
        out_graph_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)
        out_covis_mat = np.zeros((sub_graph_nodes, sub_graph_nodes),
                                 dtype=np.float32)

        edge_local_matches_n1 = []
        edge_local_matches_n2 = []
        edge_subnode_idx = []
        num_sampled_edges = len(subgraph_edges)
        num_manual_add_edges = len(
            manual_linkage_list) if manual_linkage_list is not None else 0
        edge_type = torch.zeros(num_sampled_edges + num_manual_add_edges,
                                dtype=torch.long)

        for i, edge in enumerate(subgraph_edges):
            n1, n2 = edge[0], edge[1]
            reconnect_idx = (img_id2sub_id[n1], img_id2sub_id[n2])  # remap to sub-graph
            edge_subnode_idx.append(reconnect_idx)
            self._set_symmetric(out_graph_mat, reconnect_idx, subgraph_label[i])
            self._set_symmetric(out_covis_mat, reconnect_idx,
                                self._covis_value(covis_map, n1, n2))

            matches = self._lookup_edge_matches(edge_local_feat_cache, n1, n2)
            if matches is None:
                # Fail loudly rather than pair this edge with another edge's matches.
                raise KeyError(
                    'Edge %d-%d of dataset %s has no entry in the edge feature cache.'
                    % (n1, n2, dataset_name))
            pts1, pts2, etype = matches
            edge_type[i] = etype
            edge_local_matches_n1.append(pts1)
            edge_local_matches_n2.append(pts2)

        # manual added edges (endpoints are already sub-graph indices)
        if manual_linkage_list is not None:
            for e_i, sub_e in enumerate(manual_linkage_list):
                reconnect_idx = (sub_e[0], sub_e[1])
                edge_subnode_idx.append(reconnect_idx)
                n1 = sub_id2img_id[sub_e[0]]
                n2 = sub_id2img_id[sub_e[1]]
                self._set_symmetric(out_graph_mat, reconnect_idx, sub_e[2])
                self._set_symmetric(out_covis_mat, reconnect_idx,
                                    self._covis_value(covis_map, n1, n2))

                matches = self._lookup_edge_matches(edge_local_feat_cache, n1, n2)
                if matches is None:
                    # Best-effort path: the edge stays in edge_subnode_idx and the
                    # matrices but contributes no matches, so the match lists end
                    # up shorter than edge_subnode_idx.
                    # NOTE(review): confirm downstream code tolerates this.
                    print('ERROR ADD' + str(sub_e))
                    continue
                pts1, pts2, etype = matches
                edge_type[num_sampled_edges + e_i] = etype
                edge_local_matches_n1.append(pts1)
                edge_local_matches_n2.append(pts2)

        out_graph_mat = torch.from_numpy(out_graph_mat)
        out_covis_mat = torch.from_numpy(out_covis_mat)

        meta_dict = {'dataset_idx': dataset_idx, 'sub_graph_id': sub_graph_id}
        return img_keys, meta_dict, imgs, img_ori_dim, cam_Es, cam_Cs, cam_Ks, out_graph_mat, img_id2sub_id, sub_id2img_id, out_covis_mat, edge_subnode_idx, edge_type, edge_local_matches_n1, edge_local_matches_n2
""" Network ------------------------------------------------------------------------------------------------------------
"""
train_params = TrainParameters()
train_params.DEV_IDS = run_dev_ids
train_params.VERBOSE_MODE = True
prior_box = LocalGlobalGATTrainBox_Prior(train_params=train_params, ckpt_path_dict={
    'vlad': '/mnt/Exp_4/valid_cache/netvlad_vgg16.tar',
    'ckpt': '/mnt/Exp_4/valid_cache/iter_nogat.pth.tar'
})
prior_box._prepare_eval()


""" Pipeline -----------------------------------------------------------------------------------------------------------
"""
img_lmdb = LMDBModel(lmdb_path, lock=False, read_only=True)
img_lmdb_meta = pickle.load(open(lmdb_meta_path, 'rb'))

# load data
for dataset in dataset_list:
    dataset_name = dataset['name']
    bundle_prefix = dataset['bundle_prefix']

    print("Processing on %s" % dataset_name)

    # load edge cache
    with open(os.path.join(cap_res_dir, dataset_name, 'edge_feat_pos_cache.bin'), 'rb') as f:
        edge_cache = pickle.load(f)

    # load frame list
    frame_list = read_image_list(os.path.join(cap_res_dir, dataset_name, bundle_prefix + '.list.txt'))