def _build_category2label(self):
    if cfg.CONST.DATASET == 'shapenet':
        # ShapeNet synset IDs: '03001627' = chair, '04379243' = table
        category2label = {'03001627': 0, '04379243': 1}
    elif cfg.CONST.DATASET == 'primitives':
        train_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
        val_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
        test_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
        f = lambda inputs_dict: list(inputs_dict['category_matches'].keys())
        categories = f(train_inputs_dict) + f(val_inputs_dict) + f(test_inputs_dict)
        # Sort so the category -> label assignment is deterministic across runs
        categories = sorted(set(categories))
        category2label = {cat: idx for idx, cat in enumerate(categories)}
    else:
        raise ValueError('Please select a valid dataset.')
    return category2label
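# Minimal usage sketch (illustrative only; the sample categories below follow
# the ShapeNet branch, where '03001627' is the chair synset and '04379243'
# is the table synset):
#
#     category2label = self._build_category2label()
#     labels = [category2label[c] for c in ['03001627', '04379243', '03001627']]
#     # -> [0, 1, 0]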
def build_architecture(self, inputs_dict):
    x = inputs_dict['shape_batch']
    if cfg.CONST.DATASET == 'shapenet':
        num_classes = 2  # Chair/table classification
    elif cfg.CONST.DATASET == 'primitives':
        train_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
        val_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
        test_inputs_dict = open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
        f = lambda inputs_dict: list(inputs_dict['category_matches'].keys())
        categories = f(train_inputs_dict) + f(val_inputs_dict) + f(test_inputs_dict)
        # Sort to stay consistent with the labels from _build_category2label
        categories = sorted(set(categories))
        num_classes = len(categories)
    else:
        raise ValueError('Please select a valid dataset.')

    x = layers.conv3d(x, 64, 3, strides=2, padding='same', name='conv1', reuse=self.reuse)
    x = tf.layers.batch_normalization(x, training=self.is_training,
                                      name='conv1_batch_norm', reuse=self.reuse)
    x = layers.relu(x, name='conv1_relu')
    x = layers.conv3d(x, 128, 3, strides=2, padding='same', name='conv2', reuse=self.reuse)
    x = tf.layers.batch_normalization(x, training=self.is_training,
                                      name='conv2_batch_norm', reuse=self.reuse)
    x = layers.relu(x, name='conv2_relu')
    x = layers.conv3d(x, 256, 3, strides=2, padding='same', name='conv3', reuse=self.reuse)
    x = tf.layers.batch_normalization(x, training=self.is_training,
                                      name='conv3_batch_norm', reuse=self.reuse)
    x = layers.relu(x, name='conv3_relu')

    # Keep the pre-pooling activations as intermediate results
    intermediate_output = x

    x = layers.avg_pooling3d(x, name='avg_pool4')
    x = layers.dense(x, 128, name='fc5', reuse=self.reuse)
    encoder_output = x
    x = layers.dense(x, num_classes, name='fc6', reuse=self.reuse)
    prob = layers.softmax(x, name='softmax_layer')

    output_dict = {
        'logits': x,
        'probabilities': prob,
        'encoder_output': encoder_output,
        'intermediate_output': intermediate_output
    }
    return output_dict
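# Shape walkthrough for build_architecture (a sketch, assuming 32x32x32 voxel
# inputs with 4 RGBA channels as used elsewhere in this repo; actual sizes
# depend on the data). With 'same' padding, each stride-2 conv maps
# d -> ceil(d / 2):
#     shape_batch:       (B, 32, 32, 32, 4)
#     conv1 (stride 2):  (B, 16, 16, 16, 64)
#     conv2 (stride 2):  (B,  8,  8,  8, 128)
#     conv3 (stride 2):  (B,  4,  4,  4, 256)   <- intermediate_output
#     avg_pool4 + fc5:   (B, 128)               <- encoder_output
#     fc6:               (B, num_classes)       <- logits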
def get_inputs_dict(opts):
    """Gets the input dicts for the current model and dataset."""
    if opts.dataset == 'shapenet':
        pass
        # if (args.text_encoder is True) or (args.end2end is True) or (args.classifier is True):
        #     inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
        #     val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
        #     test_inputs_dict = utils.open_pickle(cfg.DIR.TEST_DATA_PATH)
        # else:  # Learned embeddings
        #     inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TRAIN)
        #     val_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_VAL)
        #     test_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TEST)
    elif opts.dataset == 'primitives':  # Primitives dataset
        if ((opts.synth_embedding is True) or (opts.text_encoder is True)
                or (opts.classifier is True)):
            if opts.classifier and not opts.reed_classifier:
                # Train on all splits for the classifier
                # tf.logging.info('Using all (train/val/test) splits for training')
                print('Using all (train/val/test) splits for training.')
                inputs_dict = utils.open_pickle(opts.primitives_all_splits_data_path)
            else:
                print('Training on the train split only.')
                inputs_dict = utils.open_pickle(opts.primitives_train_data_path)
            val_inputs_dict = utils.open_pickle(opts.primitives_val_data_path)
            test_inputs_dict = utils.open_pickle(opts.primitives_test_data_path)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_train)
            val_inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_val)
            test_inputs_dict = utils.open_pickle(opts.primitives_metric_embeddings_test)
    else:
        raise ValueError('Please use a valid dataset (shapenet, primitives).')

    # Select the validation/test split
    if opts.val_split == 'train':
        val_split_str = 'train'
        val_inputs_dict = inputs_dict
    elif (opts.val_split == 'val') or (opts.val_split is None):
        val_split_str = 'val'
    elif opts.val_split == 'test':
        val_split_str = 'test'
        val_inputs_dict = test_inputs_dict
    else:
        raise ValueError('Please select a valid split (train, val, test).')
    print('Validation/testing on {} split.'.format(val_split_str))

    if opts.dataset == 'shapenet' and opts.shapenet_ct_classifier is True:
        pass

    return inputs_dict, val_inputs_dict
def test_lba_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences, get_json_path

    cfg.CONST.BATCH_SIZE = 8
    cfg.CONST.DATASET = 'shapenet'
    cfg.CONST.SYNTH_EMBEDDING = False

    caption_data = open_pickle(cfg.DIR.VAL_DATA_PATH)
    data_queue = Queue(3)
    json_path = get_json_path()
    data_process = LBADataProcess(data_queue, caption_data, repeat=True)
    data_process.start()

    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']
    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('Key:', k)
            print('Value length:', len(v))
        elif isinstance(v, np.ndarray):
            print('Key:', k)
            print('Value shape:', v.shape)
        else:
            print('Other:', k)
        print('')

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        category = category_list[i]
        model_id = model_list[i]

        # Generate the sentence for each caption of this model
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j
            caption = caption_batch['raw_embedding_batch'][caption_idx]
            # print('Caption:', caption)
            # print('Converted caption:')
            data_list = [{'raw_caption_embedding': caption}]
            print_sentences(json_path, data_list)
            print('Label:', caption_batch['caption_label_batch'][caption_idx])

        print('Category:', category)
        print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
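# Expected layout of caption_batch (keys as read above; the shapes are
# illustrative assumptions, not verified against LBADataProcess):
#     'raw_embedding_batch':  array with one row per caption,
#                             batch_size * n_captions_per_model rows in total
#     'caption_label_batch':  one label per caption, aligned with the rows above
#     'category_list':        list of batch_size category IDs
#     'model_list':           list of batch_size model IDs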
def test_caption_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences

    cfg.CONST.DATASET = 'primitives'
    cfg.CONST.SYNTH_EMBEDDING = False

    asdf_captions = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
    data_queue = Queue(3)
    data_process = CaptionDataProcess(data_queue, asdf_captions, repeat=True)
    data_process.start()

    caption_batch = data_queue.get()
    captions_tensor, category_list, model_list = caption_batch
    assert captions_tensor.shape[0] == len(category_list)
    assert len(category_list) == len(model_list)

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        caption = captions_tensor[i]
        category = category_list[i]
        model_id = model_list[i]
        # print('Caption:', caption)
        # print('Converted caption:')
        # Generate the sentence from the caption indices:
        # data_list = [{'raw_caption_embedding': caption}]
        # print_sentences(json_path, data_list)
        print('Category:', category)
        # print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
def retrieval(text_encoder, shape_encoder, ret_dict, opts, ret_type='text_to_shape'):
    # text_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/txt_enc_acc.pth'))
    # shape_encoder.load_state_dict(torch.load('MODELS/METRIC_and_TST/shape_enc_acc.pth'))
    text_encoder.load_state_dict(torch.load('MODELS/METRIC_ONLY/txt_enc_acc.pth'))
    shape_encoder.load_state_dict(torch.load('MODELS/METRIC_ONLY/shape_enc_acc.pth'))
    text_encoder.eval()
    shape_encoder.eval()

    num_of_retrieval = 50
    n_neighbors = 20

    if ret_type == 'text_to_shape':
        # Embed the query captions and retrieve their nearest neighbors among
        # the precomputed shape embeddings.
        embeddings_trained = utils.open_pickle('shape_only.p')
        embeddings, model_ids = utils.create_embedding_tuples(embeddings_trained,
                                                              embedd_type='shape')
        length_trained = len(model_ids)
        queried_captions = []

        print('Starting retrieval.')
        num_of_captions = len(ret_dict['caption_tuples'])
        iteration = 0
        while iteration < num_of_retrieval:
            # rand_ind = np.random.randint(0, num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            caption = caption_tuple[0]
            model_id = caption_tuple[2]
            queried_captions.append(caption)

            input_caption = torch.from_numpy(caption).unsqueeze(0).long().cuda()
            txt_embedd_output = text_encoder(input_caption)

            # Append the query embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   txt_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)
            iteration += 1

        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }
        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        caption_file = open('Retrieval/text_to_shape/inp_captions.txt', 'w')
        counter = 0
        for q in range(num_of_retrieval):
            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]
            # Drop neighbors that are themselves query embeddings
            all_nn = [n for n in all_nn if n < length_trained]
            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print('Found correct match:', NN)
            if len(NN) < 1:
                NN = important_indices[q][0]
            else:
                counter += 1
                NN = NN[0][0]
                NN = important_indices[q][NN]

            q_caption = queried_captions[q]
            sentence = utils.convert_idx_to_words(q_caption)
            caption_file.write('{}\n'.format(sentence))

            try:
                os.mkdir('Retrieval/text_to_shape/{0}/'.format(q))
            except OSError:
                pass
            # Save a rendering of every retrieved neighbor for this query
            for ii, nn in enumerate(all_nn):
                q_model_id = embeddings_fused[nn][0]
                voxel_file = opts.png_dir % (q_model_id, q_model_id)
                img = mpimg.imread(voxel_file)
                plt.imshow(img)
                plt.savefig('Retrieval/text_to_shape/{0}/{1}.png'.format(q, ii))
                plt.clf()

            # q_caption = queried_captions[q]
            # q_model_id = embeddings_fused[NN][0]
            # sentence = utils.convert_idx_to_words(q_caption)
            # caption_file.write('{}\n'.format(sentence))
            # voxel_file = opts.png_dir % (q_model_id, q_model_id)
            # img = mpimg.imread(voxel_file)
            # plt.imshow(img)
            # plt.savefig('Retrieval/text_to_shape/{0}.png'.format(q))
            # plt.clf()

        caption_file.close()
        print('Accuracy:', counter / num_of_retrieval)

    elif ret_type == 'shape_to_text':
        # Embed the query shapes and retrieve their nearest neighbors among
        # the precomputed text embeddings.
        embeddings_trained = utils.open_pickle('text_only.p')
        embeddings, model_ids, raw_caption = utils.create_embedding_tuples(
            embeddings_trained, embedd_type='text')
        queried_shapes = []
        print('Starting retrieval.')
        num_of_captions = len(ret_dict['caption_tuples'])
        # num_of_retrieval = 10
        iteration = 0
        while iteration < num_of_retrieval:
            # rand_ind = np.random.randint(0, num_of_captions)
            rand_ind = iteration
            caption_tuple = ret_dict['caption_tuples'][rand_ind]
            # caption = caption_tuple[0]
            model_id = caption_tuple[2]
            cur_shape = utils.load_voxel(None, model_id, opts)
            queried_shapes.append(cur_shape)

            input_shape = torch.from_numpy(cur_shape).unsqueeze(0).permute(
                0, 4, 1, 2, 3).float().cuda()
            # input_caption = torch.from_numpy(caption).unsqueeze(0).long().cuda()
            # txt_embedd_output = text_encoder(input_caption)
            shape_embedd_output = shape_encoder(input_shape)

            # Append the query embedding to the trained embeddings
            embeddings = np.append(embeddings,
                                   shape_embedd_output.detach().cpu().numpy(),
                                   axis=0)
            model_ids.append(model_id)
            iteration += 1

        model_ids = np.array(model_ids)
        embeddings_fused = [(i, j) for i, j in zip(model_ids, embeddings)]
        outputs_dict = {
            'caption_embedding_tuples': embeddings_fused,
            'dataset_size': len(embeddings_fused)
        }
        (embeddings_matrix, labels, num_embeddings,
         label_counter) = ut.construct_embeddings_matrix(outputs_dict)
        indices = ut._compute_nearest_neighbors_cosine(embeddings_matrix,
                                                       embeddings_matrix,
                                                       n_neighbors, True)
        important_indices = indices[-num_of_retrieval:]
        important_model_id = model_ids[-num_of_retrieval:]

        counter = 0
        for q in range(num_of_retrieval):
            cur_model_id = important_model_id[q]
            all_nn = important_indices[q]
            # Keep only neighbors that are trained text embeddings
            all_nn = [n for n in all_nn if n < len(raw_caption)]
            NN = np.argwhere(model_ids[all_nn] == cur_model_id)
            print('Found correct match:', NN)
            if len(NN) < 1:
                NN = all_nn[0]
            else:
                counter += 1
                NN = NN[0][0]
                NN = all_nn[NN]

            # ---------------------------------------------------------
            q_shape = queried_shapes[q]
            # Write all retrieved captions for this query to its own file
            caption_file = open(
                'Retrieval/shape_to_text/inp_captions{0}.txt'.format(q), 'w')
            for ii, nn in enumerate(all_nn):
                q_caption = raw_caption[nn].data.numpy()
                # q_model_id = embeddings_fused[nn][0]
                sentence = utils.convert_idx_to_words(q_caption)
                caption_file.write('{}\n'.format(sentence))
            caption_file.close()
            # ---------------------------------------------------------
            # q_shape = queried_shapes[q]
            # q_caption = raw_caption[NN].data.numpy()
            # q_model_id = embeddings_fused[NN][0]
            # sentence = utils.convert_idx_to_words(q_caption)
            # caption_file.write('{}\n'.format(sentence))

            # Save a rendering of the query shape
            q_model_id = cur_model_id
            voxel_file = opts.png_dir % (q_model_id, q_model_id)
            img = mpimg.imread(voxel_file)
            plt.imshow(img)
            plt.savefig('Retrieval/shape_to_text/{0}.png'.format(q))
            plt.clf()

        print('Accuracy:', counter / num_of_retrieval)
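# For reference, a minimal numpy sketch of the cosine nearest-neighbor step
# used above. This is an illustration only; the actual implementation lives
# in ut._compute_nearest_neighbors_cosine and may differ in detail.
def _cosine_nearest_neighbors_sketch(query, reference, n_neighbors, exclude_self=False):
    q = query / np.linalg.norm(query, axis=1, keepdims=True)
    r = reference / np.linalg.norm(reference, axis=1, keepdims=True)
    sim = q.dot(r.T)  # rows are unit-norm, so the dot product is cosine similarity
    if exclude_self:
        # Assumes query is reference (square similarity matrix), as in the calls above
        np.fill_diagonal(sim, -np.inf)
    order = np.argsort(-sim, axis=1)  # most similar first
    return order[:, :n_neighbors]  # neighbor indices into `reference`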
def main():
    load = True
    parser = argparse.ArgumentParser(description='main text2voxel train/test file')
    parser.add_argument('--dataset', help='dataset', default='shapenet', type=str)
    parser.add_argument('--tensorboard', type=str, default='results')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--data_dir', type=str, default=cfg.DIR.RGB_VOXEL_PATH)
    parser.add_argument('--png_dir', type=str, default=cfg.DIR.RGB_PNG_PATH)
    parser.add_argument('--num_workers', type=int, default=cfg.CONST.NUM_WORKERS)
    parser.add_argument('--LBA_n_captions_per_model', type=int, default=2)
    parser.add_argument('--rho', type=float, default=0.5)
    parser.add_argument('--learning_rate', type=float, default=cfg.TRAIN.LEARNING_RATE)
    parser.add_argument('--probablematic_nrrd_path', type=str,
                        default=cfg.DIR.PROBLEMATIC_NRRD_PATH)
    parser.add_argument('--train', type=bool, default=True)
    parser.add_argument('--retrieval', type=bool, default=False)
    parser.add_argument('--tensorboard_name', type=str, default='test')
    opts = parser.parse_args()

    writer = SummaryWriter(os.path.join(opts.tensorboard, opts.tensorboard_name))
    inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
    val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
    # We basically neglect the problematic models later in the dataloader.
    # opts.probablematic_nrrd_path = cfg.DIR.PROBLEMATIC_NRRD_PATH

    opts_val = copy.deepcopy(opts)
    opts_val.batch_size = 256
    opts_val.val_inputs_dict = val_inputs_dict

    text_encoder = CNNRNNTextEncoder(vocab_size=inputs_dict['vocab_size']).cuda()
    shape_encoder = ShapeEncoder().cuda()

    shape_gen_raw, shape_mod_list = gn.SS_generator(opts_val.val_inputs_dict, opts)
    text_gen_raw, text_mod_list = gn.TS_generator(opts_val.val_inputs_dict, opts)

    parameter = {
        'shape_encoder': shape_encoder,
        'text_encoder': text_encoder,
        'shape_gen_raw': shape_gen_raw,
        'shape_mod_list': shape_mod_list,
        'text_gen_raw': text_gen_raw,
        'text_mod_list': text_mod_list,
        'writer': writer,
        'opts': opts,
        'inputs_dict': inputs_dict,
        'opts_val': opts_val
    }
    if opts.train:
        train(parameter)
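# Note on the boolean flags above: argparse's type=bool treats any non-empty
# string as True, so '--train False' still yields True. A safer pattern is a
# store_true/store_false pair (a sketch, not this repo's code):
#
#     parser.add_argument('--train', dest='train', action='store_true')
#     parser.add_argument('--no-train', dest='train', action='store_false')
#     parser.set_defaults(train=True)
#
# Typical invocation (assuming this file is saved as main.py):
#     python main.py --dataset shapenet --batch_size 64 --tensorboard_name run1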
            else:
                break
        cur_shape = load_voxel(cur_category, cur_model_id, self.opts)
        selected_tuples = [s[0] for s in selected_tuples]
        return selected_tuples, cur_shape, cur_model_id, cur_category
"""

# Test ---------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='main text2voxel train/test file')
    opts = parser.parse_args()
    opts.data_dir = cfg.DIR.RGB_VOXEL_PATH
    inputs_dict = ut.open_pickle(cfg.DIR.TRAIN_DATA_PATH)

    params = {'batch_size': 64,
              'shuffle': True,
              'num_workers': 6,
              'captions_per_model': 2}
    params_2 = {'batch_size': 64,
                'shuffle': True,
                'num_workers': 6,
                'collate_fn': collate_wrapper}

    dat_loader = ShapeNetDataset(inputs_dict, params, opts)
    training_generator = torch.utils.data.DataLoader(dat_loader, **params_2)
    # mask_ndarray = np.asarray([1., 0.] * 64)[:, np.newaxis]
def get_inputs_dict(args):
    """Gets the input dicts for the current model and dataset."""
    if cfg.CONST.DATASET == 'shapenet':
        if (args.text_encoder is True) or (args.end2end is True) or (args.classifier is True):
            inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
            val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
            test_inputs_dict = utils.open_pickle(cfg.DIR.TEST_DATA_PATH)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TRAIN)
            val_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_VAL)
            test_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TEST)
    elif cfg.CONST.DATASET == 'primitives':
        if ((cfg.CONST.SYNTH_EMBEDDING is True) or (args.text_encoder is True)
                or (args.classifier is True)):
            if args.classifier and not cfg.CONST.REED_CLASSIFIER:
                # Train on all splits for the classifier
                tf.logging.info('Using all (train/val/test) splits for training')
                inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_ALL_SPLITS_DATA_PATH)
            else:
                tf.logging.info('Using train split only for training')
                inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
            val_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
            test_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
        else:  # Learned embeddings
            inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TRAIN)
            val_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_VAL)
            test_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TEST)
    else:
        raise ValueError('Please use a valid dataset (shapenet, primitives).')

    if args.tiny_dataset is True:
        if ((cfg.CONST.DATASET == 'primitives' and cfg.CONST.SYNTH_EMBEDDING is True)
                or (args.text_encoder is True)):
            raise NotImplementedError('Tiny dataset not supported for synthetic embeddings.')
        ds = 5  # New dataset size
        if cfg.CONST.BATCH_SIZE > ds:
            raise ValueError('Please use a batch size of at most {}.'.format(ds))
        inputs_dict = utils.change_dataset_size(inputs_dict, new_dataset_size=ds)
        val_inputs_dict = utils.change_dataset_size(val_inputs_dict, new_dataset_size=ds)
        test_inputs_dict = utils.change_dataset_size(test_inputs_dict, new_dataset_size=ds)

    # Select the validation/test split
    if args.split == 'train':
        split_str = 'train'
        val_inputs_dict = inputs_dict
    elif (args.split == 'val') or (args.split is None):
        split_str = 'val'
    elif args.split == 'test':
        split_str = 'test'
        val_inputs_dict = test_inputs_dict
    else:
        raise ValueError('Please select a valid split (train, val, test).')
    print('Validation/testing on {} split.'.format(split_str))

    if (cfg.CONST.DATASET == 'shapenet') and (cfg.CONST.SHAPENET_CT_CLASSIFIER is True):
        category_model_list, class_labels = Classifier.set_up_classification(inputs_dict)
        val_category_model_list, val_class_labels = Classifier.set_up_classification(
            val_inputs_dict)
        assert class_labels == val_class_labels

        # Update the input dicts with the classification metadata
        inputs_dict['category_model_list'] = category_model_list
        inputs_dict['class_labels'] = class_labels
        val_inputs_dict['category_model_list'] = val_category_model_list
        val_inputs_dict['class_labels'] = val_class_labels

    return inputs_dict, val_inputs_dict
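# Usage sketch (hypothetical caller; `args` is assumed to carry the flags
# referenced above, e.g. args.split and args.tiny_dataset):
#
#     inputs_dict, val_inputs_dict = get_inputs_dict(args)
#     text_encoder = CNNRNNTextEncoder(vocab_size=inputs_dict['vocab_size'])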
def test_lba_process():
    from multiprocessing import Queue
    from lib.utils import print_sentences

    parser = argparse.ArgumentParser(description='test data process')
    parser.add_argument('--dataset', dest='dataset', help='dataset',
                        default='shapenet', type=str)
    opts = parser.parse_args()
    opts.batch_size = 8
    opts.LBA_n_captions_per_model = 5
    opts.synth_embedding = False
    opts.probablematic_nrrd_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/problematic_nrrds_shapenet_unverified_256_filtered_div_with_err_textures.p'
    opts.LBA_model_type = 'STS'
    opts.val_data_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/processed_captions_val.p'
    opts.data_dir = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/nrrd_256_filter_div_32_solid/%s/%s.nrrd'

    caption_data = open_pickle(opts.val_data_path)
    # A maxsize of 3 caps the queue; once full, puts block until items are consumed.
    data_queue = Queue(3)
    json_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/shapenet.json'
    pdb.set_trace()

    data_process = LBADataProcess(data_queue, caption_data, opts, repeat=True)
    data_process.start()

    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']
    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('key: ', k)
            print('value length: ', len(v))
        elif isinstance(v, np.ndarray):
            print('key: ', k)
            print('value shape: ', v.shape)
        else:
            print('Other: ', k)
        print('')
    pdb.set_trace()

    """
    for i in range(len(category_list)):
        print('-------%03d------' % i)
        category = category_list[i]
        model_id = model_list[i]
        # Generate the sentence for each caption
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j
            caption = caption_batch['raw_embedding_batch'][caption_idx]
            # print('caption:', caption)
            # print('converted caption: ')
            data_list = [{'raw_caption_embedding': caption}]
            print_sentences(json_path, data_list)
            print('label: ', caption_batch['caption_label_batch'][caption_idx].item())
        print('category: ', category)
        print('model id: ', model_id)
    """
    pdb.set_trace()
    kill_processes(data_queue, [data_process])