def parser():
    parser = clevr_parser.Parser(backend='spacy', model='en_core_web_sm',
                                 has_spatial=True,
                                 has_matching=True).get_backend(identifier='spacy')
    return parser
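# A minimal usage sketch for the helper above; the question string is
# hypothetical, and the parse(..., return_doc=True) call mirrors how
# main() below uses the backend:
def _demo_parse():
    graph_parser = parser()
    Gs, s_doc = graph_parser.parse("Is there a red cube left of the sphere?",
                                   return_doc=True, is_directed_graph=True)
    return Gs, s_doc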
def main(args): """ Save nx.graph (Gss, Gts,...) and corresponding torch_geometric.data.PairData (via clevr_parse embedder api). """ if (args.input_vocab_json == '') and (args.output_vocab_json == ''): logger.info( 'Must give one of --input_vocab_json or --output_vocab_json') return graph_parser = clevr_parser.Parser( backend='spacy', model=args.parser_lm, has_spatial=True, has_matching=True).get_backend(identifier='spacy') embedder = clevr_parser.Embedder( backend='torch', parser=graph_parser).get_backend(identifier='torch') is_directed_graph = args.is_directed_graph # Parse graphs as nx.MultiDiGraph out_dir, out_f_prefix = _get_out_dir_and_file_prefix(args) checkpoint_dir = f"{out_dir}/checkpoints" utils.mkdirs(checkpoint_dir) questions, img_scenes = get_questions_and_parsed_scenes( args.input_questions_json, args.input_parsed_img_scenes_json) if args.is_debug: set_default_level(10) questions = questions[: 128] # default BSZ is 64 ensuring enought for batch iter logger.debug( f"In DEBUG mode, sampling {len(questions)} questions only..") # Process Vocab # vocab = _process_vocab(args, questions) # Encode all questions and programs logger.info('Encoding data') questions_encoded, programs_encoded, answers, image_idxs = [], [], [], [] question_families = [] orig_idxs = [] # Graphs and Embeddings # data_s_list = [] # List [torch_geometric.data.Data] data_t_list = [] # List [torch_geometric.data.Data] num_samples = 0 # Counter for keeping track of processed samples num_skipped = 0 # Counter for tracking num of samples skipped for orig_idx, q in enumerate(questions): # First See if Gss, Gts are possible to extract. # If not (for e.g., some edges cases like plurality, skip data sample img_idx = q['image_index'] img_fn = q['image_filename'] logger.debug(f"\tProcessing Image - {img_idx}: {img_fn} ...") # q_idx = q['question_index'] # q_fam_idx = q['question_family_index'] ## 1: Ensure both Gs,Gt is parseable for this question sample, o.w. skip img_scene = list( filter(lambda x: x['image_index'] == img_idx, img_scenes))[0] try: Gt, t_doc = graph_parser.get_doc_from_img_scene( img_scene, is_directed_graph=is_directed_graph) X_t, ei_t, e_attr_t = embedder.embed_t( img_idx, args.input_parsed_img_scenes_json) except AssertionError as ae: logger.warning(f"AssertionError Encountered: {ae}") logger.warning(f"[{img_fn}] Excluding images with > 10 objects") num_skipped += 1 continue if Gt is None and ("SKIP" in t_doc): # If the derendering pipeline failed, then just skip the # scene, don't process the labels (and text_scenes) for the image print(f"Got None img_doc at image_index: {img_idx}") print(f"Skipping all text_scenes for imgage idx: {img_idx}") num_skipped += 1 continue s = q['question'] orig_idx = q['question_index'] try: Gs, s_doc = graph_parser.parse(s, return_doc=True, is_directed_graph=is_directed_graph) X_s, ei_s, e_attr_s = embedder.embed_s(s) except ValueError as ve: logger.warning(f"ValueError Encountered: {ve}") logger.warning(f"Skipping question: {s} for {img_fn}") num_skipped += 1 continue if Gs is None and ("SKIP" in s_doc): logger.warning( "Got None as Gs and 'SKIP' in Gs_embd. 
(likely plural with CLEVR_OBJS label) " ) logger.warning( f"SKIPPING processing {s} for {img_fn} and at {img_idx}") num_skipped += 1 continue # Using ClevrData allows us a debug extension to Data data_s = ClevrData(x=X_s, edge_index=ei_s, edge_attr=e_attr_s) data_t = ClevrData(x=X_t, edge_index=ei_t, edge_attr=e_attr_t) data_s_list.append(data_s) data_t_list.append(data_t) question = q['question'] orig_idxs.append(orig_idx) image_idxs.append(img_idx) if 'question_family_index' in q: question_families.append(q['question_family_index']) question_tokens = preprocess_utils.tokenize(question, punct_to_keep=[';', ','], punct_to_remove=['?', '.']) question_encoded = preprocess_utils.encode( question_tokens, vocab['question_token_to_idx'], allow_unk=args.encode_unk == 1) questions_encoded.append(question_encoded) has_prog_seq = 'program' in q if has_prog_seq: program = q['program'] program_str = program_to_str(program, args.mode) program_tokens = preprocess_utils.tokenize(program_str) program_encoded = preprocess_utils.encode( program_tokens, vocab['program_token_to_idx']) programs_encoded.append(program_encoded) if 'answer' in q: ans = q['answer'] answers.append(vocab['answer_token_to_idx'][ans]) num_samples += 1 logger.info("-" * 50) logger.info(f"Samples processed count = {num_samples}") if has_prog_seq: logger.info(f"\n[{orig_idx}]: question: {question} \n" f"\tprog_str: {program_str} \n" f"\tanswer: {ans}") logger.info("-" * 50) # ---- CHECKPOINT ---- # if num_samples % args.checkpoint_every == 0: logger.info(f"Checkpointing at {num_samples}") checkpoint_fn_prefix = f"{out_f_prefix}_{num_samples}" _out_dir = f"{checkpoint_dir}/{out_f_prefix}_{num_samples}" utils.mkdirs(_out_dir) out_fpp = f"{_out_dir}/{checkpoint_fn_prefix}" # ------------ Checkpoint .H5 ------------# logger.info( f"CHECKPOINT: Saving checkpoint files at directory: {out_fpp}") save_h5(f"{out_fpp}.h5", vocab, questions_encoded, image_idxs, orig_idxs, programs_encoded, question_families, answers) # ------------ Checkpoint GRAPH DATA ------------# save_graph_pairdata(out_fpp, data_s_list, data_t_list, is_directed_graph=is_directed_graph) logger.info(f"-------------- CHECKPOINT: COMPLETED --------") if (args.max_sample > 0) and (num_samples >= args.max_sample): logger.info(f"len(questions_encoded = {len(questions_encoded)}") logger.info("args.max_sample reached: Completing ... ") break logger.debug(f"Total samples skipped = {num_skipped}") logger.debug(f"Total samples processed = {num_samples}") out_fpp = f"{out_dir}/{out_f_prefix}" ## SAVE .H5: Baseline {dataset}_h5.h5 file (q,p,ans,img_idx) as usual logger.info(f"Saving baseline (processed) data in: {out_fpp}.h5") save_h5(f"{out_fpp}.h5", vocab, questions_encoded, image_idxs, orig_idxs, programs_encoded, question_families, answers) ## ------------ SAVE GRAPH DATA ------------ ## ## N.b. Ensure the len of theses lists are all equals save_graph_pairdata(out_fpp, data_s_list, data_t_list, is_directed_graph=is_directed_graph) logger.info(f"Saved Graph Data in: {out_fpp}_*.[h5|.gpickle|.npz|.pt] ")
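# A minimal, hypothetical CLI entry point for the function above (a sketch,
# not the repo's actual argument parser, assuming this main() lives in its
# own script): it only wires up the args fields main() actually reads;
# flag names mirror those fields, defaults are assumptions.
if __name__ == '__main__':
    import argparse
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--input_questions_json', required=True)
    arg_parser.add_argument('--input_parsed_img_scenes_json', required=True)
    arg_parser.add_argument('--input_vocab_json', default='')
    arg_parser.add_argument('--output_vocab_json', default='')
    arg_parser.add_argument('--parser_lm', default='en_core_web_sm')
    arg_parser.add_argument('--mode', default='prefix')  # program_to_str mode (assumed default)
    arg_parser.add_argument('--is_directed_graph', action='store_true')
    arg_parser.add_argument('--is_debug', action='store_true')
    arg_parser.add_argument('--encode_unk', type=int, default=0)
    arg_parser.add_argument('--checkpoint_every', type=int, default=1000)
    arg_parser.add_argument('--max_sample', type=int, default=-1)
    main(arg_parser.parse_args())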
def main(args): """ Save nx.graph (Gss, Gts,...) and corresponding torch_geometric.data.PairData (via clevr_parse embedder api). """ if args.is_debug: set_default_level(10) is_directed_graph = args.is_directed_graph logger.debug(f"Parser flag is_directed_graph = {is_directed_graph}") graph_parser = clevr_parser.Parser( backend="spacy", model='en_core_web_sm', has_spatial=True, has_matching=True).get_backend(identifier='spacy') embedder = clevr_parser.Embedder( backend='torch', parser=graph_parser).get_backend(identifier='torch') raw_questions, img_scenes = get_questions_and_parsed_scenes( args.input_questions_json, args.input_parsed_img_scenes_json) logger.info('| importing questions from %s' % args.input_question_h5) input_questions = h5py.File(args.input_question_h5, 'r') #N = len(input_questions['questions']) # Baseline Entities # questions, programs, answers, question_families, orig_idxs, img_idxs = [], [], [], [], [], [] family_count = np.zeros(90) # Graphs and Embeddings # data_s_list = [] # List [torch_geometric.data.Data] data_t_list = [] # List [torch_geometric.data.Data] filename = get_output_filename(args) __all_question_families: np.ndarray = input_questions['question_families'][ ()] __all_enc_questions: np.ndarray = input_questions['questions'][()] __all_img_indices: np.ndarray = input_questions['image_idxs'][()] logger.debug(f"__all_question_families len {len(__all_question_families)}") # Sample N items for each 90 families # fam2indices = get_question_fam_to_indices(args) M = len(fam2indices.keys()) # 90 N = args.n_questions_per_family # 50 max_sample = N * M # 90 * 50 = 4500 family_count = np.zeros(M) # family_count = Counter() # TODO: accumulating values here need to be parallelized, and joined write ex-post num_skipped = 0 # Counter for tracking num of samples skipped for fam_idx, i_samples in enumerate(fam2indices): all_fam_samples = fam2indices[fam_idx] logger.debug( f"Question_family {fam_idx} has {len(all_fam_samples)} samples to choose {N} samples" ) N_question_sample_indices = np.random.choice( all_fam_samples, N, replace=False) # N.b seed is fixed assert len(N_question_sample_indices) == N # TODO: parallelize this iteration loop for i in N_question_sample_indices: try: img_idx = __all_img_indices[i] logger.debug( f"\tProcessing Image - {img_idx} from fam_idx {fam_idx}: {i} of {i_samples}" ) img_scene = list( filter(lambda x: x['image_index'] == img_idx, img_scenes))[0] except IndexError as ie: logger.warning(f"For {img_idx}: {ie}") num_skipped += 1 continue try: Gt, t_doc = graph_parser.get_doc_from_img_scene( img_scene, is_directed_graph=is_directed_graph) X_t, ei_t, e_attr_t = embedder.embed_t( img_idx, args.input_parsed_img_scenes_json) except AssertionError as ae: logger.warning(f"AssertionError Encountered: {ae}") logger.warning( f"[{img_idx}] Excluding images with > 10 objects") num_skipped += 1 continue if Gt is None and ("SKIP" in t_doc): # If the derendering pipeline failed, then just skip the # scene, don't process the labels (and text_scenes) for the image logger.warning(f"Got None img_doc at image_index: {img_idx}") print(f"Skipping all text_scenes for imgage idx: {img_idx}") num_skipped += 1 continue q_idx = input_questions['orig_idxs'][i] q_obj = list( filter(lambda x: x['question_index'] == q_idx, raw_questions))[0] assert q_obj['image_index'] == img_idx s = q_obj['question'] try: Gs, s_doc = graph_parser.parse( s, return_doc=True, is_directed_graph=is_directed_graph) X_s, ei_s, e_attr_s = embedder.embed_s(s) except ValueError as ve: 
logger.warning(f"ValueError Encountered: {ve}") logger.warning(f"Skipping question: {s} for {img_fn}") num_skipped += 1 continue if Gs is None and ("SKIP" in s_doc): logger.warning( "Got None as Gs and 'SKIP' in Gs_embd. (likely plural with CLEVR_OBJS label) " ) logger.warning(f"SKIPPING processing {s} at {img_idx}") num_skipped += 1 continue data_s = ClevrData(x=X_s, edge_index=ei_s, edge_attr=e_attr_s) data_t = ClevrData(x=X_t, edge_index=ei_t, edge_attr=e_attr_t) data_s_list.append(data_s) data_t_list.append(data_t) family_count[fam_idx] += 1 questions.append(input_questions['questions'][i]) programs.append(input_questions['programs'][i]) answers.append(input_questions['answers'][i]) question_families.append(input_questions['question_families'][i]) orig_idxs.append(input_questions['orig_idxs'][i]) img_idxs.append(img_idx) logger.info(f"\nCount = {family_count.sum()}\n") if family_count.sum() >= max_sample: break logger.debug( f"Total samples skipped (due to errors/exceptions) = {num_skipped}") # ---------------------------------------------------------------------------# ## SAVE .H5 if not os.path.isdir(args.output_dir): os.mkdir(args.output_dir) output_file = os.path.join(args.output_dir, filename) out_dir = args.output_dir out_f_prefix = filename.split('.')[0] out_fpp = f"{out_dir}/{out_f_prefix}" logger.debug(f"out_fpp = {out_fpp}") print('sampled question family distribution') print(family_count) print('| saving output file to %s' % output_file) with h5py.File(output_file, 'w') as f: f.create_dataset('questions', data=np.asarray(questions, dtype=np.int32)) f.create_dataset('programs', data=np.asarray(programs, dtype=np.int32)) f.create_dataset('answers', data=np.asarray(answers)) f.create_dataset('image_idxs', data=np.asarray(img_idxs)) f.create_dataset('orig_idxs', data=np.asarray(orig_idxs)) f.create_dataset('question_families', data=np.asarray(question_families)) ## ------------ SAVE GRAPH DATA ------------ ## save_graph_pairdata(out_fpp, data_s_list, data_t_list, is_directed_graph=is_directed_graph) logger.info(f"Saved Graph Data in: {out_fpp}_*.[h5|.gpickle|.npz|.pt] ") print('| done')