Example #1
import pickle  # cfg and load_voxel are assumed to come from the surrounding project's modules

def SS_generator(val_inputs_dict, opts):
    
    new_tuples = []
    seen_shapes = []
    
    problematic_nrrd_path = cfg.DIR.PROBLEMATIC_NRRD_PATH
    with open(problematic_nrrd_path, 'rb') as f:
        bad_model_ids = pickle.load(f)

    for cur_tup in val_inputs_dict['caption_tuples']:
        if cur_tup[2] in bad_model_ids:  # skip models flagged as problematic
            continue

        #04379243 CLASS : TABLE
        #03001627 CLASS : CHAIR
        #if(cur_tup[1] == '04379243'):
            #continue
        cur_model_id = cur_tup[2]
        if cur_model_id not in seen_shapes:
            seen_shapes.append(cur_model_id)
            #cur_category = cur_tup[1]
            cur_shape = load_voxel(None, cur_model_id, opts)
            new_tuples.append((cur_model_id, cur_shape))

    caption_tuples = new_tuples
    #raw_caption_list = [tup[1] for tup in caption_tuples]
    raw_shape_list = [tup[1] for tup in caption_tuples]
    model_list = [tup[0] for tup in caption_tuples]

    return raw_shape_list, model_list
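
For orientation, here is a minimal driving sketch for SS_generator, assuming, as the indexing above implies, that val_inputs_dict is an unpickled dict whose 'caption_tuples' entries are (caption, category, model_id) triples; the pickle path and the opts object are hypothetical stand-ins, not names from the source:

import pickle

with open('data/processed_captions_val.p', 'rb') as f:  # hypothetical path
    val_inputs_dict = pickle.load(f)

raw_shape_list, model_list = SS_generator(val_inputs_dict, opts)  # opts comes from the caller's config
print(len(model_list), 'unique shapes loaded')  # one voxel grid per unique model id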
Example #2
    def get_caption_data(self, db_ind):
        """
        Get the caption data corresponding to the index specified by db_ind 
        Args:
            db_ind: the integet index corresponding to the index of the cpation in the dataset. 
        Returns: 
            cur_raw_embedding
            cur_category 
            cur_model_id 
            cur_voxel_tensor 
        """
        while True:  # This will wait until we successfully get the data
            caption_tuple = self.caption_tuples[db_ind]  # the caption tuple at index db_ind
            cur_raw_embedding = caption_tuple[0]
            cur_category = caption_tuple[1]
            cur_model_id = caption_tuple[2]

            if self.is_bad_model_id(cur_model_id):  # check whether the model_id is a bad model_id
                db_ind = np.random.randint(self.num_data)  # choose new caption
                continue

            try:
                cur_learned_embedding = self.get_learned_embedding(
                    caption_tuple)
                cur_voxel_tensor = load_voxel(cur_category, cur_model_id,
                                              self.opts)
                # do data augmentation on the voxel
                cur_voxel_tensor = augment_voxel_tensor(
                    cur_voxel_tensor, max_noise=self.opts.train_augment_max)

                if self.class_labels is not None:
                    cur_class_label = self.class_labels[cur_category]
                else:
                    cur_class_label = None

            except FileNotFoundError:  # Retry if we do not have binvoxes
                db_ind = np.random.randint(self.num_data)
                continue
            break
        caption_data = {
            'raw_embedding': cur_raw_embedding,
            'learned_embedding': cur_learned_embedding,
            'category': cur_category,
            'model_id': cur_model_id,
            'voxel_tensor': cur_voxel_tensor,
            'class_label': cur_class_label
        }
        return caption_data
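
A minimal sketch of how get_caption_data might be consumed to assemble a small batch; `loader` stands for an instance of the surrounding class (a hypothetical name), and stacking assumes all voxel grids share one shape:

import numpy as np

batch = [loader.get_caption_data(i) for i in range(4)]
voxel_batch = np.stack([b['voxel_tensor'] for b in batch])       # e.g. (4, 32, 32, 32, 4) channels-last
embedding_batch = np.stack([b['raw_embedding'] for b in batch])  # caption index arrays
model_ids = [b['model_id'] for b in batch]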
Example #3
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        db_ind = idx
        while True:
            cur_key = self.matches_keys[db_ind]
            caption_idxs = self.caption_matches[cur_key]
            if len(caption_idxs) < self.n_captions_per_model:
                # resample until we hit a model with enough captions
                db_ind = np.random.randint(self.num_data)
                continue
            

            selected_caption_idxs = random.sample(caption_idxs, k=self.n_captions_per_model)
            selected_tuples = [self.caption_tuples[i] for i in selected_caption_idxs]  # `i`, not `idx`, to avoid shadowing the argument
            cur_model_id = selected_tuples[0][2]
            cur_category = selected_tuples[0][1]

            #04379243 CLASS : TABLE
            #03001627 CLASS : CHAIR

            #if(cur_category == '04379243'):
            #    db_ind = np.random.randint(self.num_data)
            #    continue

            if cur_model_id in self.bad_model_ids:
                db_ind = np.random.randint(self.num_data)
                continue
            break
        cur_shape = load_voxel(cur_category, cur_model_id, self.opts)

        # assumes self.n_captions_per_model >= 2, since two captions are returned
        return selected_tuples[0][0], selected_tuples[1][0], cur_shape, cur_model_id, cur_category
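
Because __getitem__ returns two captions for the same model plus the voxel grid and its ids, the class plugs straight into a standard torch DataLoader; a sketch assuming `dataset` is an instance of this class (batch size and worker count are arbitrary choices):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for cap_a, cap_b, shape, model_id, category in loader:
    # cap_a and cap_b are two captions of the same model; shape is its voxel grid
    pass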
Example #4
# assumed module-level imports for this example: os, numpy as np, torch,
# matplotlib.image as mpimg, matplotlib.pyplot as plt; utils and ut are project-local modules
def retrieval(text_encoder,
              shape_encoder,
              ret_dict,
              opts,
              ret_type='text_to_shape'):
    #text_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/txt_enc_acc.pth'))
    #shape_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/shape_enc_acc.pth'))

    text_encoder.load_state_dict(
        torch.load('MODELS/METRIC_ONLY/txt_enc_acc.pth'))
    shape_encoder.load_state_dict(
        torch.load('MODELS/METRIC_ONLY/shape_enc_acc.pth'))

    text_encoder.eval()
    shape_encoder.eval()

    num_of_retrieval = 50
    n_neighbors = 20

    if ret_type == 'text_to_shape':
        embeddings_trained = utils.open_pickle('shape_only.p')
        embeddings, model_ids = utils.create_embedding_tuples(
            embeddings_trained, embedd_type='shape')
        length_trained = len(model_ids)
        queried_captions = []

        print("start retrieval")
        num_of_captions = len(ret_dict['caption_tuples'])

        iteration = 0
        while iteration < num_of_retrieval:

            #rand_ind = np.random.randint(0,num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            caption = caption_tuple[0]
            model_id = caption_tuple[2]
            queried_captions.append(caption)

            input_caption = torch.from_numpy(caption).unsqueeze(
                0).long().cuda()

            txt_embedd_output = text_encoder(input_caption)

            #add the embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   txt_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)

            iteration += 1

        #n_neighbors = 10
        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }

        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        caption_file = open('Retrieval/text_to_shape/inp_captions.txt', 'w')
        counter = 0
        for q in range(num_of_retrieval):

            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]
            # drop neighbors that are themselves queries, keeping only trained embeddings
            all_nn = [n for n in all_nn if n < length_trained]

            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print("found correct one:", NN)
            if len(NN) < 1:
                NN = all_nn[0]
            else:
                counter += 1
                NN = NN[0][0]
                NN = all_nn[NN]  # index back into the filtered neighbor list, as in the shape_to_text branch

            q_caption = queried_captions[q]
            sentence = utils.convert_idx_to_words(q_caption)
            caption_file.write('{}\n'.format(sentence))
            try:
                os.mkdir('Retrieval/text_to_shape/{0}/'.format(q))
            except OSError:
                pass
            for ii, nn in enumerate(all_nn):
                q_model_id = embeddings_fused[nn][0]
                voxel_file = opts.png_dir % (q_model_id, q_model_id)
                img = mpimg.imread(voxel_file)
                plt.imshow(img)

                plt.savefig('Retrieval/text_to_shape/{0}/{1}.png'.format(
                    q, ii))
                plt.clf()

            #q_caption = queried_captions[q]
            #q_model_id = embeddings_fused[NN][0]
            #sentence = utils.convert_idx_to_words(q_caption)
            #caption_file.write('{}\n'.format(sentence))

            #voxel_file = opts.png_dir % (q_model_id,q_model_id)
            #img = mpimg.imread(voxel_file)
            #imgplot = plt.imshow(img)
            #plt.savefig('Retrieval/text_to_shape/{0}.png'.format(q))
            #plt.clf()

        caption_file.close()
        print("ACC :", (counter / num_of_retrieval))

    elif ret_type == 'shape_to_text':

        embeddings_trained = utils.open_pickle('text_only.p')
        embeddings, model_ids, raw_caption = utils.create_embedding_tuples(
            embeddings_trained, embedd_type='text')

        queried_shapes = []

        print("start retrieval")
        num_of_captions = len(ret_dict['caption_tuples'])
        #num_of_retrieval = 10
        iteration = 0
        while iteration < num_of_retrieval:

            #rand_ind = np.random.randint(0,num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            #caption = caption_tuple[0]
            model_id = caption_tuple[2]

            cur_shape = utils.load_voxel(None, model_id, opts)
            queried_shapes.append(cur_shape)
            input_shape = torch.from_numpy(cur_shape).unsqueeze(0).permute(
                0, 4, 1, 2, 3).float().cuda()
            #input_caption = torch.from_numpy(caption).unsqueeze(0).long().cuda()

            #txt_embedd_output = text_encoder(input_caption)
            shape_embedd_output = shape_encoder(input_shape)

            #add the embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   shape_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)

            iteration += 1

        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }

        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        # per-query caption files are opened inside the loop below
        counter = 0
        for q in range(num_of_retrieval):

            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]

            # keep only trained text embeddings (the queried shapes sit at the end of the matrix)
            all_nn = [n for n in all_nn if n < len(raw_caption)]

            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print("found correct one:", NN)
            if len(NN) < 1:
                NN = all_nn[0]
            else:
                counter += 1
                NN = NN[0][0]
                NN = all_nn[NN]

            #---------------------------------------------------------
            q_shape = queried_shapes[q]
            with open('Retrieval/shape_to_text/inp_captions{0}.txt'.format(q), 'w') as caption_file:
                for ii, nn in enumerate(all_nn):
                    q_caption = raw_caption[nn].data.numpy()
                    #q_model_id = embeddings_fused[nn][0]
                    sentence = utils.convert_idx_to_words(q_caption)
                    caption_file.write('{}\n'.format(sentence))

            #---------------------------------------------------------

            #q_shape = queried_shapes[q]
            #q_caption = raw_caption[NN].data.numpy()
            #q_model_id = embeddings_fused[NN][0]
            #sentence = utils.convert_idx_to_words(q_caption)
            #caption_file.write('{}\n'.format(sentence))
            q_model_id = cur_model_id
            voxel_file = opts.png_dir % (q_model_id, q_model_id)
            img = mpimg.imread(voxel_file)
            plt.imshow(img)
            plt.savefig('Retrieval/shape_to_text/{0}.png'.format(q))
            plt.clf()

        print("ACC :", (counter / num_of_retrieval))

Example #5

    def run(self):
        """
        The category and model lists change size dynamically depending on whether the mode is STS or TST.
        """
        # run the loop until exit flag is set
        while not self.exit.is_set() and self.cur < self.num_data:
            # print('{0}/{1} samples'.format(self.cur, self.num_data))

            # Ensure that the network sees (almost) all the data per epoch
            db_inds = self.get_next_minibatch()

            shapes_list = []
            captions_list = []
            category_list = []
            model_id_list = []

            for db_ind in db_inds:  # Loop through each selected shape
                selected_shapes = []
                while True:
                    # cur_key is the model id for shapenet, category for primitives
                    cur_key = self.matches_keys[db_ind]
                    caption_idxs = self.caption_matches[cur_key]

                    # Ensure that len(caption_idxs) >= self.n_captions_per_model
                    if len(caption_idxs) < self.n_captions_per_model:
                        db_ind = np.random.randint(self.num_data)  # take a random index and retry
                        continue

                    # randomly sample self.n_captions_per_model captions from caption_idxs
                    selected_caption_idxs = random.sample(
                        caption_idxs, k=self.n_captions_per_model)
                    selected_tuples = [
                        self.caption_tuples[idx]
                        for idx in selected_caption_idxs
                    ]
                    # model id is cur_key
                    cur_category, cur_model_id = self.verify_batch(
                        selected_tuples)

                    # select shapes/models
                    if self.opts.dataset == 'shapenet':
                        selected_model_ids = [cur_model_id]
                    elif self.opts.dataset == 'primitives':
                        category_model_ids = self.category2modelid[
                            cur_category]
                        selected_model_ids = random.sample(
                            category_model_ids,
                            k=self.opts.LBA_n_primitive_shapes_per_category)
                    else:
                        raise ValueError('Please select a valid dataset')

                    # append cur_shape to selected_shapes
                    # for shapenet, selected_model_ids = [cur_model_id]
                    # for primitives, category_model_ids = self.category2modelid[cur_category], and
                    # we sample self.LBA_n_primitive_shapes_per_category models for this category
                    retry = False
                    for cur_model_id in selected_model_ids:
                        if self.is_bad_model_id(cur_model_id):
                            db_ind = np.random.randint(self.num_data)
                            retry = True
                            break
                        try:
                            cur_shape = load_voxel(cur_category, cur_model_id,
                                                   self.opts)
                        except FileNotFoundError:
                            print(
                                'Error: cannot find file with the following model id: ',
                                cur_key)
                            db_ind = np.random.randint(self.num_data)
                            retry = True
                            break
                        selected_shapes.append(cur_shape)
                    if retry:
                        # a bare `continue` inside the inner for loop would not restart the
                        # while loop, so retry explicitly with the newly drawn db_ind
                        selected_shapes = []
                        continue
                    break
                # each model has self.n_captions_per_model captions
                selected_captions = [tup[0] for tup in selected_tuples]
                captions_list.extend(selected_captions)
                # for each category (shapenet selects one), pick LBA_n_primitive_shapes_per_category models
                for selected_shape in selected_shapes:
                    shapes_list.append(selected_shape)

                if self.opts.LBA_model_type == 'STS':
                    category_list.append(cur_category)
                    model_id_list.append(cur_model_id)
                elif self.opts.LBA_model_type == 'TST' or self.opts.LBA_model_type == 'MM':
                    # replicate the label self.n_captions_per_model times
                    cur_categories = [cur_category for _ in selected_captions]
                    # replicate the model_id self.n_captions_per_model times
                    cur_model_ids = [cur_model_id for _ in selected_captions]
                    category_list.extend(cur_categories)
                    model_id_list.extend(cur_model_ids)
                else:
                    raise ValueError('Please select a valid LBA mode')

            # Length is the number of captions
            # Index/label indicates which captions comes from the same shape
            # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]
            label_list = [
                x for x in range(self.n_unique_shape_categories)
                for _ in range(self.n_captions_per_model)
            ]

            batch_captions = np.array(captions_list).astype(np.int32)
            batch_shapes = np.array(shapes_list).astype(np.float32)
            # convert dim
            batch_shapes = batch_shapes.transpose(
                (0, 4, 2, 3, 1))  # bz x 32 x 32 x 32 x 4 -> bz x 4 x 32 x 32 x 32
            batch_label = np.array(label_list).astype(np.int32)

            # item in the batch_data is pytorch Tensor
            # the following will wait until the queue frees
            batch_data = {
                "raw_embedding_batch": batch_captions,
                'voxel_tensor_batch': batch_shapes,
                'caption_label_batch': batch_label,
                'category_list': category_list,
                'model_list': model_id_list,
            }

            # kill_processes will run okay when the item in the batch_data is not tensor
            # batch_data = {
            #    "raw_embedding_batch": batch_captions.numpy(),
            #    'voxel_tensor_batch': batch_shapes.numpy(),
            #    'caption_label_batch': batch_label.numpy(),
            #    'category_list':category_list,
            #    'model_list':model_id_list,
            #}

            self.data_queue.put(batch_data, block=True)
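
On the consumer side, the training loop would block on the same queue and convert the numpy arrays to tensors; a hedged sketch, assuming `process` is the running worker that owns data_queue:

import torch

batch_data = process.data_queue.get(block=True)
captions = torch.from_numpy(batch_data['raw_embedding_batch']).long()
shapes = torch.from_numpy(batch_data['voxel_tensor_batch']).float()  # bz x 4 x 32 x 32 x 32
labels = torch.from_numpy(batch_data['caption_label_batch']).long()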

Example #6

    def run(self):
        # Run the loop until exit flag is set
        while not self.exit.is_set() and self.cur < self.num_data:
            # Ensure that the network sees (almost) all data per epoch
            # print('{0}/{1} samples'.format(self.cur, self.num_data))

            db_inds = self.get_next_minibatch()

            data_list = []
            category_list = []  # categories
            model_list = []  # models
            shapes_list = []

            continue_while_loop = False
            for db_ind in db_inds:
                if self.mode == 'text':
                    caption_tuple = self.caption_tuples[db_ind]
                elif self.mode == 'shape':
                    cur_key = self.matches_keys[db_ind]
                    caption_idxs = self.caption_matches[cur_key]

                    # Pick the first caption tuple in the matches keys list
                    caption_tuple = self.caption_tuples[caption_idxs[0]]
                else:
                    raise ValueError('Please enter a valid LBA test mode')

                cur_category = caption_tuple[1]
                cur_model_id = caption_tuple[2]
                try:
                    cur_shape = load_voxel(cur_category, cur_model_id,
                                           self.opts)
                except FileNotFoundError:
                    assert len(db_inds) == 1  # this retry path assumes a batch size of 1
                    print('File not found.')
                    print('Category:', cur_category)
                    print('Model ID:', cur_model_id)
                    print('Skipping.')
                    db_ind = np.random.randint(
                        self.num_data)  # Choose new caption
                    continue_while_loop = True
                    break

                data_list.append(
                    caption_tuple[0])  # 0th element is the caption
                category_list.append(cur_category)
                model_list.append(cur_model_id)
                shapes_list.append(cur_shape)

            if continue_while_loop is True:
                continue
            # Build the batch arrays
            batch_captions = np.array(data_list).astype(np.int32)
            batch_shapes = np.array(shapes_list).astype(np.float32)
            batch_shapes = batch_shapes.transpose(
                (0, 4, 2, 3, 1))  # bz x 32 x 32 x 32 x 4 -> bz x 4 x 32 x 32 x 32

            if self.opts.LBA_test_mode == 'text':
                # Length is number of captions
                # Index/label indicates which captions come from the same shape
                if self.opts.dataset == 'shapenet':
                    # Map IDs for each shape
                    ids = {}
                    next_id = 0
                    for model_id in model_list:
                        if model_id not in ids:
                            ids[model_id] = next_id
                            next_id += 1

                    label_list = [ids[model_id] for model_id in model_list]
                    batch_label = np.array(label_list).astype(np.int32)
                elif self.opts.dataset == 'primitives':
                    # Map IDs for each shape
                    ids = {}
                    next_id = 0
                    for category_id in category_list:
                        if category_id not in ids:
                            ids[category_id] = next_id
                            next_id += 1

                    label_list = [
                        ids[category_id] for category_id in category_list
                    ]
                    batch_label = np.array(label_list).astype(np.int32)
                else:
                    raise ValueError('Please select a valid dataset.')
            elif self.opts.LBA_test_mode == 'shape':
                batch_label = np.array(range(self.opts.batch_size))
            else:
                raise ValueError('Please select a valid LBA test phase mode.')

            batch_data = {
                'raw_embedding_batch': batch_captions,
                'voxel_tensor_batch': batch_shapes,
                'caption_label_batch': batch_label,
                'category_list': category_list,
                'model_list': model_list,
            }

            # The following will wait until the queue frees
            self.data_queue.put(batch_data, block=True)
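
The two ID-mapping loops in the 'text' branch implement the same first-occurrence labeling, once over model ids and once over category ids; a compact equivalent, shown purely as an illustration (the helper name is hypothetical, not from the project):

def first_occurrence_labels(keys):
    # assign 0, 1, 2, ... to keys in order of first appearance
    ids = {}
    return [ids.setdefault(k, len(ids)) for k in keys]

label_list = first_occurrence_labels(model_list)       # shapenet: label by model id
# label_list = first_occurrence_labels(category_list)  # primitives: label by category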