Example #1
def _build_category2label(self):
    if cfg.CONST.DATASET == 'shapenet':
        # The ShapeNet subset used here has only chairs and tables
        category2label = {'03001627': 0, '04379243': 1}
    elif cfg.CONST.DATASET == 'primitives':
        train_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
        val_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
        test_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
        f = lambda inputs_dict: list(inputs_dict['category_matches'].keys())
        categories = f(train_inputs_dict) + f(val_inputs_dict) + f(test_inputs_dict)
        # Sort so the category -> label mapping is deterministic across runs
        categories = sorted(set(categories))
        category2label = {cat: idx for idx, cat in enumerate(categories)}
    else:
        raise ValueError('Please select a valid dataset.')
    return category2label
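Example #1 (and most snippets below) depends on an open_pickle helper from lib.utils that is not shown. A minimal sketch, assuming it is a plain wrapper around pickle.load; the project's actual version may add error handling:

import pickle

def open_pickle(pickle_file):
    # Load and return whatever object the pickle file contains
    with open(pickle_file, 'rb') as f:
        data = pickle.load(f)
    return data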
Example #2
    def build_architecture(self, inputs_dict):
        x = inputs_dict['shape_batch']
        if cfg.CONST.DATASET == 'shapenet':
            num_classes = 2  # Chair/table classification
        elif cfg.CONST.DATASET == 'primitives':
            train_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
            val_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
            test_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
            f = lambda inputs_dict: list(inputs_dict['category_matches'].keys())
            categories = f(train_inputs_dict) + f(val_inputs_dict) + f(test_inputs_dict)
            categories = list(set(categories))
            num_classes = len(categories)
        else:
            raise ValueError('Please select a valid dataset.')

        x = layers.conv3d(x, 64, 3, strides=2, padding='same', name='conv1', reuse=self.reuse)
        x = tf.layers.batch_normalization(x, training=self.is_training, name='conv1_batch_norm',
                                          reuse=self.reuse)
        x = layers.relu(x, name='conv1_relu')
        x = layers.conv3d(x, 128, 3, strides=2, padding='same', name='conv2', reuse=self.reuse)
        x = tf.layers.batch_normalization(x, training=self.is_training, name='conv2_batch_norm',
                                          reuse=self.reuse)
        x = layers.relu(x, name='conv2_relu')
        x = layers.conv3d(x, 256, 3, strides=2, padding='same', name='conv3', reuse=self.reuse)
        x = tf.layers.batch_normalization(x, training=self.is_training, name='conv3_batch_norm',
                                          reuse=self.reuse)
        x = layers.relu(x, name='conv3_relu')
        # Keep the pre-pooling feature map as an intermediate output
        intermediate_output = x
        x = layers.avg_pooling3d(x, name='avg_pool4')
        x = layers.dense(x, 128, name='fc5', reuse=self.reuse)
        encoder_output = x
        x = layers.dense(x, num_classes, name='fc6', reuse=self.reuse)
        prob = layers.softmax(x, name='softmax_layer')

        output_dict = {
            'logits': x,
            'probabilities': prob,
            'encoder_output': encoder_output,
            'intermediate_output': intermediate_output
        }

        return output_dict
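build_architecture calls a project-local layers module rather than tf.layers directly. A minimal sketch of what those wrappers could look like, assuming they are thin aliases over the TF 1.x API; the names and signatures are inferred from the call sites above, not taken from the project:

import tensorflow as tf

def conv3d(x, filters, kernel_size, strides=1, padding='same', name=None, reuse=None):
    return tf.layers.conv3d(x, filters, kernel_size, strides=strides,
                            padding=padding, name=name, reuse=reuse)

def relu(x, name=None):
    return tf.nn.relu(x, name=name)

def avg_pooling3d(x, name=None):
    # Global average pool over the three spatial dimensions
    return tf.reduce_mean(x, axis=[1, 2, 3], name=name)

def dense(x, units, name=None, reuse=None):
    return tf.layers.dense(x, units, name=name, reuse=reuse)

def softmax(x, name=None):
    return tf.nn.softmax(x, name=name)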
Example #3
def get_inputs_dict(opts):
    """
    Gets the input dict for the current model and dataset.
    """
    if opts.dataset == 'shapenet':
        pass
        # if (opts.text_encoder is True) or (opts.end2end is True) or (opts.classifier is True):
        #     inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
        #     val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
        #     test_inputs_dict = utils.open_pickle(cfg.DIR.TEST_DATA_PATH)
        # else:  # Learned embeddings
        #     inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TRAIN)
        #     val_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_VAL)
        #     test_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TEST)
    
    elif opts.dataset == 'primitives':
        # Primitives dataset
        if (opts.synth_embedding is True) or (opts.text_encoder is True) or (opts.classifier is True):
            if opts.classifier and not opts.reed_classifier:
                # Train on all splits for the classifier
                print('Using all (train/val/test) splits for training.')
                inputs_dict = utils.open_pickle(opts.primitives_all_splits_data_path)
            else:
                print('Training using train split only.')
                inputs_dict = utils.open_pickle(opts.primitives_train_data_path)
            val_inputs_dict = utils.open_pickle(opts.primitives_val_data_path)
            test_inputs_dict = utils.open_pickle(opts.primitives_test_data_path)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_train)
            val_inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_val)
            test_inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_test)
    else:
        raise ValueError('Please use a valid dataset (shapenet, primitives).')

    # Select the validation/test split
    if opts.val_split == 'train':
        val_split_str = 'train'
        val_inputs_dict = inputs_dict
    elif (opts.val_split == 'val') or (opts.val_split is None):
        val_split_str = 'val'  # val_inputs_dict already holds the val split
    elif opts.val_split == 'test':
        val_split_str = 'test'
        val_inputs_dict = test_inputs_dict
    else:
        raise ValueError('Please select a valid split (train, val, test).')

    print('Validation/testing on {} split.'.format(val_split_str))

    if opts.dataset == 'shapenet' and opts.shapenet_ct_classifier is True:
        pass  # Classifier-specific setup omitted in this snippet (see Example #9)

    return inputs_dict, val_inputs_dict
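A hypothetical call sketching the opts attributes the primitives branch reads; the path values are placeholders, not real files:

import argparse

opts = argparse.Namespace(
    dataset='primitives',
    synth_embedding=False,
    text_encoder=True,
    classifier=False,
    reed_classifier=False,
    val_split='val',
    primitives_train_data_path='path/to/primitives_train.p',  # placeholder
    primitives_val_data_path='path/to/primitives_val.p',      # placeholder
    primitives_test_data_path='path/to/primitives_test.p')    # placeholder
inputs_dict, val_inputs_dict = get_inputs_dict(opts)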
Example #4
def test_lba_process():
    from multiprocessing import Queue
    import numpy as np  # used below for the isinstance check
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences, get_json_path
    # LBADataProcess and kill_processes come from the enclosing module

    cfg.CONST.BATCH_SIZE = 8
    cfg.CONST.DATASET = 'shapenet'
    cfg.CONST.SYNTH_EMBEDDING = False

    caption_data = open_pickle(cfg.DIR.VAL_DATA_PATH)
    data_queue = Queue(3)
    json_path = get_json_path()

    data_process = LBADataProcess(data_queue, caption_data, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']

    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('Key:', k)
            print('Value length:', len(v))
        elif isinstance(v, np.ndarray):
            print('Key:', k)
            print('Value shape:', v.shape)
        else:
            print('Other:', k)
    print('')

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        category = category_list[i]
        model_id = model_list[i]

        # Generate sentence
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j
            caption = caption_batch['raw_embedding_batch'][caption_idx]
            # print('Caption:', caption)
            # print('Converted caption:')
            data_list = [{'raw_caption_embedding': caption}]
            print_sentences(json_path, data_list)
            print('Label:', caption_batch['caption_label_batch'][caption_idx])

        print('Category:', category)
        print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
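kill_processes is used by several tests here but never shown. A rough sketch of the usual shutdown pattern, assuming each data process exposes a shutdown() method as in the text2shape codebase; details may differ:

def kill_processes(queue, processes):
    # Signal every producer to stop generating batches
    for p in processes:
        p.shutdown()
    # Drain the queue so producers blocked on put() can exit
    while any(p.is_alive() for p in processes):
        try:
            queue.get(timeout=0.5)
        except Exception:
            pass
    for p in processes:
        p.join()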
Example #5
def test_caption_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences
    # CaptionDataProcess and kill_processes come from the enclosing module

    cfg.CONST.DATASET = 'primitives'
    cfg.CONST.SYNTH_EMBEDDING = False

    asdf_captions = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)

    data_queue = Queue(3)

    data_process = CaptionDataProcess(data_queue, asdf_captions, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    captions_tensor, category_list, model_list = caption_batch

    assert captions_tensor.shape[0] == len(category_list)
    assert len(category_list) == len(model_list)

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        caption = captions_tensor[i]
        category = category_list[i]
        model_id = model_list[i]
        # print('Caption:', caption)
        # print('Converted caption:')

        # Generate sentence
        # data_list = [{'raw_caption_embedding': caption}]
        # print_sentences(json_path, data_list)

        print('Category:', category)
        # print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
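print_sentences (imported in the tests above) presumably maps token indices back to words using the dataset's JSON vocabulary. A hypothetical sketch only; the assumed 'idx_to_word' layout and padding index are guesses, not the actual lib.utils implementation:

import json

def print_sentences(json_path, data_list):
    with open(json_path, 'r') as f:
        idx_to_word = json.load(f)['idx_to_word']  # assumed layout
    for item in data_list:
        indices = item['raw_caption_embedding']
        words = [idx_to_word[str(int(i))] for i in indices if int(i) != 0]  # assume 0 = padding
        print(' '.join(words))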
Example #6
def retrieval(text_encoder,
              shape_encoder,
              ret_dict,
              opts,
              ret_type='text_to_shape'):
    # Assumes module-level imports: os, numpy as np, torch,
    # matplotlib.image as mpimg, matplotlib.pyplot as plt, plus the
    # project helpers utils and ut
    # Alternative checkpoints (metric + TST training) kept for reference:
    # text_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/txt_enc_acc.pth'))
    # shape_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/shape_enc_acc.pth'))

    text_encoder.load_state_dict(
        torch.load('MODELS/METRIC_ONLY/txt_enc_acc.pth'))
    shape_encoder.load_state_dict(
        torch.load('MODELS/METRIC_ONLY/shape_enc_acc.pth'))

    text_encoder.eval()
    shape_encoder.eval()

    num_of_retrieval = 50
    n_neighbors = 20

    if ret_type == 'text_to_shape':
        embeddings_trained = utils.open_pickle('shape_only.p')
        embeddings, model_ids = utils.create_embedding_tuples(
            embeddings_trained, embedd_type='shape')
        length_trained = len(model_ids)
        queried_captions = []

        print("start retrieval")
        num_of_captions = len(ret_dict['caption_tuples'])

        iteration = 0
        while (iteration < num_of_retrieval):

            #rand_ind = np.random.randint(0,num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            caption = caption_tuple[0]
            model_id = caption_tuple[2]
            queried_captions.append(caption)

            input_caption = torch.from_numpy(caption).unsqueeze(
                0).long().cuda()

            txt_embedd_output = text_encoder(input_caption)

            #add the embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   txt_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)

            iteration += 1

        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }

        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        # The queries were appended last, so their rows are at the end
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        caption_file = open('Retrieval/text_to_shape/inp_captions.txt', 'w')
        counter = 0
        for q in range(num_of_retrieval):
            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]
            # Keep only neighbours that index into the trained shape
            # embeddings, i.e. kick out neighbours that are themselves queries
            all_nn = [n for n in all_nn if n < length_trained]

            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print('Found correct one:', NN)
            if len(NN) < 1:
                NN = all_nn[0]
            else:
                counter += 1
                NN = all_nn[NN[0][0]]

            q_caption = queried_captions[q]
            sentence = utils.convert_idx_to_words(q_caption)
            caption_file.write('{}\n'.format(sentence))
            try:
                os.mkdir('Retrieval/text_to_shape/{0}/'.format(q))
            except OSError:
                pass
            # Save a rendering of every retrieved neighbour
            for ii, nn in enumerate(all_nn):
                q_model_id = embeddings_fused[nn][0]
                voxel_file = opts.png_dir % (q_model_id, q_model_id)
                img = mpimg.imread(voxel_file)
                plt.imshow(img)
                plt.savefig('Retrieval/text_to_shape/{0}/{1}.png'.format(q, ii))
                plt.clf()

        caption_file.close()
        print('ACC:', counter / num_of_retrieval)

    elif ret_type == 'shape_to_text':

        embeddings_trained = utils.open_pickle('text_only.p')
        embeddings, model_ids, raw_caption = utils.create_embedding_tuples(
            embeddings_trained, embedd_type='text')

        queried_shapes = []

        print("start retrieval")
        num_of_captions = len(ret_dict['caption_tuples'])
        #num_of_retrieval = 10
        iteration = 0
        while (iteration < num_of_retrieval):

            #rand_ind = np.random.randint(0,num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            #caption = caption_tuple[0]
            model_id = caption_tuple[2]

            cur_shape = utils.load_voxel(None, model_id, opts)
            queried_shapes.append(cur_shape)
            input_shape = torch.from_numpy(cur_shape).unsqueeze(0).permute(
                0, 4, 1, 2, 3).float().cuda()
            #input_caption = torch.from_numpy(caption).unsqueeze(0).long().cuda()

            #txt_embedd_output = text_encoder(input_caption)
            shape_embedd_output = shape_encoder(input_shape)

            #add the embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   shape_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)

            iteration += 1

        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }

        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        counter = 0
        for q in range(num_of_retrieval):
            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]
            # Keep only neighbours that index into the trained text embeddings
            all_nn = [n for n in all_nn if n < len(raw_caption)]

            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print('Found correct one:', NN)
            if len(NN) < 1:
                NN = all_nn[0]
            else:
                counter += 1
                NN = all_nn[NN[0][0]]

            # Write out every retrieved caption for this query shape
            caption_file = open(
                'Retrieval/shape_to_text/inp_captions{0}.txt'.format(q), 'w')
            for ii, nn in enumerate(all_nn):
                q_caption = raw_caption[nn].data.numpy()
                sentence = utils.convert_idx_to_words(q_caption)
                caption_file.write('{}\n'.format(sentence))
            caption_file.close()

            # Save a rendering of the query shape
            voxel_file = opts.png_dir % (cur_model_id, cur_model_id)
            img = mpimg.imread(voxel_file)
            plt.imshow(img)
            plt.savefig('Retrieval/shape_to_text/{0}.png'.format(q))
            plt.clf()

        print('ACC:', counter / num_of_retrieval)
Example #7
def main():

    parser = argparse.ArgumentParser(
        description='main text2voxel train/test file')
    parser.add_argument('--dataset',
                        help='dataset',
                        default='shapenet',
                        type=str)
    parser.add_argument('--tensorboard', type=str, default='results')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--data_dir', type=str, default=cfg.DIR.RGB_VOXEL_PATH)
    parser.add_argument('--png_dir', type=str, default=cfg.DIR.RGB_PNG_PATH)
    parser.add_argument('--num_workers',
                        type=int,
                        default=cfg.CONST.NUM_WORKERS)
    parser.add_argument('--LBA_n_captions_per_model', type=int, default=2)
    parser.add_argument('--rho', type=float, default=0.5)
    parser.add_argument('--learning_rate',
                        type=float,
                        default=cfg.TRAIN.LEARNING_RATE)
    parser.add_argument('--probablematic_nrrd_path',
                        type=str,
                        default=cfg.DIR.PROBLEMATIC_NRRD_PATH)
    # Note: argparse type=bool is fragile (bool('False') is True); see the
    # store_true sketch after this example
    parser.add_argument('--train', type=bool, default=True)
    parser.add_argument('--retrieval', type=bool, default=False)
    parser.add_argument('--tensorboard_name', type=str, default='test')

    opts = parser.parse_args()

    writer = SummaryWriter(
        os.path.join(opts.tensorboard, opts.tensorboard_name))
    inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
    val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
    # We basically neglect the problematic models later in the dataloader
    # opts.probablematic_nrrd_path = cfg.DIR.PROBLEMATIC_NRRD_PATH
    opts_val = copy.deepcopy(opts)
    opts_val.batch_size = 256
    opts_val.val_inputs_dict = val_inputs_dict

    text_encoder = CNNRNNTextEncoder(
        vocab_size=inputs_dict['vocab_size']).cuda()
    shape_encoder = ShapeEncoder().cuda()

    shape_gen_raw, shape_mod_list = gn.SS_generator(opts_val.val_inputs_dict,
                                                    opts)
    text_gen_raw, text_mod_list = gn.TS_generator(opts_val.val_inputs_dict,
                                                  opts)

    parameter = {
        'shape_encoder': shape_encoder,
        'text_encoder': text_encoder,
        'shape_gen_raw': shape_gen_raw,
        'shape_mod_list': shape_mod_list,
        'text_gen_raw': text_gen_raw,
        'text_mod_list': text_mod_list,
        'writer': writer,
        'opts': opts,
        'inputs_dict': inputs_dict,
        'opts_val': opts_val
    }

    if opts.train:
        train(parameter)
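The --train and --retrieval flags above use type=bool, a classic argparse pitfall: bool('False') is True, so passing any non-empty string enables the flag. A safer pattern (a sketch, not the project's actual CLI):

import argparse

parser = argparse.ArgumentParser()
# store_true gives a real on/off switch: flag absent -> False, present -> True
parser.add_argument('--train', action='store_true')
parser.add_argument('--retrieval', action='store_true')

opts = parser.parse_args(['--train'])
print(opts.train, opts.retrieval)  # True False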
Example #8
File: DataLoader.py  Project: sycz00/DL_Lab
            # (Excerpt begins mid-method of the dataset class)
            else:
                break
        cur_shape = load_voxel(cur_category, cur_model_id, self.opts)

        selected_tuples = [s[0] for s in selected_tuples]
        return selected_tuples, cur_shape, cur_model_id, cur_category

"""

#Test ---------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='main text2voxel train/test file')
    opts = parser.parse_args()
    opts.data_dir = cfg.DIR.RGB_VOXEL_PATH

    inputs_dict = ut.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
    params = {'batch_size': 64,
              'shuffle': True,
              'num_workers': 6,
              'captions_per_model': 2}

    params_2 = {'batch_size': 64,
                'shuffle': True,
                'num_workers': 6,
                'collate_fn': collate_wrapper}
    dat_loader = ShapeNetDataset(inputs_dict, params, opts)

    training_generator = torch.utils.data.DataLoader(dat_loader, **params_2)

    #mask_ndarray = np.asarray([1., 0.] * (64))[:, np.newaxis]
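collate_wrapper is referenced above but not shown. A hypothetical collate_fn consistent with the dataset returning (caption tuples, shape, model_id, category) per item, as in the excerpt's return statement; the real DL_Lab implementation may differ:

import numpy as np
import torch

def collate_wrapper(batch):
    # batch: list of (captions, shape, model_id, category) tuples
    captions = torch.as_tensor(
        np.stack([c for item in batch for c in item[0]]))
    shapes = torch.as_tensor(np.stack([item[1] for item in batch])).float()
    model_ids = [item[2] for item in batch]
    categories = [item[3] for item in batch]
    return captions, shapes, model_ids, categories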
Example #9
File: main.py  Project: zehuiw/text2shape
def get_inputs_dict(args):
    """Gets the input dict for the current model and dataset.
    """
    if cfg.CONST.DATASET == 'shapenet':
        if (args.text_encoder is True) or (args.end2end is True) or (args.classifier is True):
            inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
            val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
            test_inputs_dict = utils.open_pickle(cfg.DIR.TEST_DATA_PATH)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(
                cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TRAIN)
            val_inputs_dict = utils.open_pickle(
                cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_VAL)
            test_inputs_dict = utils.open_pickle(
                cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TEST)
    elif cfg.CONST.DATASET == 'primitives':
        if ((cfg.CONST.SYNTH_EMBEDDING is True) or (args.text_encoder is True)
                or (args.classifier is True)):
            if args.classifier and not cfg.CONST.REED_CLASSIFIER:  # Train on all splits for classifier
                tf.logging.info(
                    'Using all (train/val/test) splits for training')
                inputs_dict = utils.open_pickle(
                    cfg.DIR.PRIMITIVES_ALL_SPLITS_DATA_PATH)
            else:
                tf.logging.info('Using train split only for training')
                inputs_dict = utils.open_pickle(
                    cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
            val_inputs_dict = utils.open_pickle(
                cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
            test_inputs_dict = utils.open_pickle(
                cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(
                cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TRAIN)
            val_inputs_dict = utils.open_pickle(
                cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_VAL)
            test_inputs_dict = utils.open_pickle(
                cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TEST)
    else:
        raise ValueError('Please use a valid dataset (shapenet, primitives).')

    if args.tiny_dataset is True:
        if ((cfg.CONST.DATASET == 'primitives'
             and cfg.CONST.SYNTH_EMBEDDING is True)
                or (args.text_encoder is True)):
            raise NotImplementedError(
                'Tiny dataset not supported for synthetic embeddings.')

        ds = 5  # New dataset size
        if cfg.CONST.BATCH_SIZE > ds:
            raise ValueError(
                'Please use a smaller batch size than {}.'.format(ds))
        inputs_dict = utils.change_dataset_size(inputs_dict,
                                                new_dataset_size=ds)
        val_inputs_dict = utils.change_dataset_size(val_inputs_dict,
                                                    new_dataset_size=ds)
        test_inputs_dict = utils.change_dataset_size(test_inputs_dict,
                                                     new_dataset_size=ds)

    # Select the validation/test split
    if args.split == 'train':
        split_str = 'train'
        val_inputs_dict = inputs_dict
    elif (args.split == 'val') or (args.split is None):
        split_str = 'val'  # val_inputs_dict already holds the val split
    elif args.split == 'test':
        split_str = 'test'
        val_inputs_dict = test_inputs_dict
    else:
        raise ValueError('Please select a valid split (train, val, test).')
    print('Validation/testing on {} split.'.format(split_str))

    if (cfg.CONST.DATASET == 'shapenet') and (cfg.CONST.SHAPENET_CT_CLASSIFIER is True):
        category_model_list, class_labels = Classifier.set_up_classification(
            inputs_dict)
        val_category_model_list, val_class_labels = Classifier.set_up_classification(
            val_inputs_dict)
        assert class_labels == val_class_labels

        # Update inputs dicts
        inputs_dict['category_model_list'] = category_model_list
        inputs_dict['class_labels'] = class_labels
        val_inputs_dict['category_model_list'] = val_category_model_list
        val_inputs_dict['class_labels'] = val_class_labels

    return inputs_dict, val_inputs_dict
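utils.change_dataset_size (used for the tiny-dataset path above) is not shown. Presumably it truncates an inputs dict to the requested number of caption tuples; a hypothetical sketch only:

def change_dataset_size(inputs_dict, new_dataset_size=5):
    # Shallow-copy the dict and keep only the first few caption tuples
    small_dict = dict(inputs_dict)
    small_dict['caption_tuples'] = inputs_dict['caption_tuples'][:new_dataset_size]
    return small_dict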
def test_lba_process():
    from multiprocessing import Queue
    from lib.utils import print_sentences
    parser = argparse.ArgumentParser(description='test data process')
    parser.add_argument('--dataset',
                        dest='dataset',
                        help='dataset',
                        default='shapenet',
                        type=str)
    opts = parser.parse_args()
    opts.batch_size = 8
    opts.LBA_n_captions_per_model = 5
    opts.synth_embedding = False
    opts.probablematic_nrrd_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/problematic_nrrds_shapenet_unverified_256_filtered_div_with_err_textures.p'
    opts.LBA_model_type = 'STS'
    opts.val_data_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/processed_captions_val.p'
    opts.data_dir = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/nrrd_256_filter_div_32_solid/%s/%s.nrrd'

    caption_data = open_pickle(opts.val_data_path)
    data_queue = Queue(3)  # Capacity 3: once full, put() blocks until items are consumed
    json_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/shapenet.json'

    data_process = LBADataProcess(data_queue, caption_data, opts, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']

    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('Key:', k)
            print('Value length:', len(v))
        elif isinstance(v, np.ndarray):
            print('Key:', k)
            print('Value shape:', v.shape)
        else:
            print('Other:', k)
    print('')
    """
    for i in range(len(category_list)):
        print('-------%03d------'%i)
        category = category_list[i] 
        model_id = model_list[i] 

        # generate sentence
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j 
            caption = caption_batch['raw_embedding_batch'][caption_idx] 

            # print('caption:', caption)
            # print('converted caption: ')
            data_list = [{'raw_caption_embedding': caption}] 
            print_sentences(json_path, data_list)
            print('label: ', caption_batch['caption_label_batch'][caption_idx].item()) 

        print('category: ', category) 
        print('model id: ', model_id) 
    """

    kill_processes(data_queue, [data_process])