import json
import random

from tqdm import tqdm

# Project-specific classes used below (Store, Ancestry, RelationBuilder,
# TemplatorAMT, TemplatorSynthetic, TemplateUser, comb_indexes) are assumed
# to be imported elsewhere in this module.


def test_run(args):
    """Smoke test: enumerate the unique relation patterns reachable from one
    ancestry, then generate and print a random puzzle."""
    store = Store(args)
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    rb.num_rel = 3
    all_patterns = set()
    while True:
        # One build pass per family member, collecting unique relation patterns.
        for j in range(len(anc.family_data.keys())):
            rb.build()
            up = rb.unique_patterns()
            all_patterns.update(up)
            print(len(all_patterns))
            rb.reset_puzzle()
        # Iterate through every gender flip of the ancestry before stopping.
        if not rb.anc.next_flip():
            break
    print("Number of unique puzzles : {}".format(len(all_patterns)))
    rb.add_facts()
    rb.generate_puzzles()
    print("Generated {} puzzles".format(len(rb.puzzles)))
    # Inspect one puzzle at random.
    pid = random.choice(list(rb.puzzles.keys()))
    print(rb.puzzles[pid])
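
# Sketch (not in the original file): driving test_run directly. The attribute
# names on the namespace are hypothetical -- match them to whatever this
# repo's argparse configuration actually defines before running.
#
#   from argparse import Namespace
#
#   args = Namespace(
#       relation_length=3,   # hypothetical flag: target relation chain length
#       max_levels=3,        # hypothetical flag: depth of the family tree
#       min_child=2,         # hypothetical flag: minimum children per node
#       max_child=3,         # hypothetical flag: maximum children per node
#   )
#   test_run(args)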
def generate_rows(args, store, task_name, split=0.8, prev_patterns=None):
    # pre-flight checks
    combination_length = min(args.combination_length, args.relation_length)
    if not args.use_mturk_template:
        if combination_length > 1:
            raise NotImplementedError("combination of two or more relations not implemented in Synthetic templating")
    else:
        if combination_length > 3:
            raise NotImplementedError("combinations of > 3 not implemented in AMT Templating")
    # generate
    print(args.relation_length)
    print("Loading templates...")
    all_puzzles = {}
    if args.template_split:
        train_templates = json.load(open(args.template_file + '.train.json'))
        test_templates = json.load(open(args.template_file + '.test.json'))
    else:
        train_templates = json.load(open(args.template_file + '.json'))
        test_templates = json.load(open(args.template_file + '.json'))
    if args.use_mturk_template:
        templatorClass = TemplatorAMT
    else:
        # Synthetic templating: map each relation to its template sentences.
        synthetic_templates_per_rel = {}
        for key, val in store.relations_store.items():
            for gender, gv in val.items():
                synthetic_templates_per_rel[gv['rel']] = gv['p']
        templatorClass = TemplatorSynthetic
        train_templates = synthetic_templates_per_rel
        test_templates = synthetic_templates_per_rel
    # Build a mapping from ANY relation to the SAME list of sentences for asking queries.
    query_templates = {}
    for key, val in store.relations_store.items():
        for gender, gv in val.items():
            query_templates[gv['rel']] = store.question_store['relational']
    query_templator_class = TemplatorSynthetic
    pb = tqdm(total=args.num_rows)
    num_stories = args.num_rows
    stories_left = num_stories
    columns = ['id', 'story', 'query', 'text_query', 'target', 'text_target',
               'clean_story', 'proof_state', 'f_comb', 'task_name', 'story_edges',
               'edge_types', 'query_edge', 'genders', 'syn_story', 'node_mapping',
               'task_split']
    f_comb_count = {}
    rows = []
    anc_num = 1  # a single ancestry is reused across all gender flips
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    while stories_left > 0:
        status = rb.build()
        if not status:
            rb.reset_puzzle()
            rb.anc.next_flip()
            continue
        rb.add_facts()
        # Keep a count of generated patterns to ensure a homogeneous distribution:
        # each pattern is weighted by min_count / count, so over-represented
        # patterns are sampled proportionally less often.
        if len(f_comb_count) > 0 and args.equal:
            min_c = min([v for k, v in f_comb_count.items()])
            weight = {k: (min_c / v) for k, v in f_comb_count.items()}
            rb.generate_puzzles(weight)
        else:
            rb.generate_puzzles()
        # If the unique_test_pattern flag is set and split == 0 (which indicates
        # this is the test task), keep only the test patterns seen before; this
        # also assumes test patterns exist for this relation length.
        if args.unique_test_pattern and split == 0 and prev_patterns and len(prev_patterns[args.relation_length]['test']) > 0:
            # all conditions met: prune the puzzles
            todel = []
            for pid, puzzle in rb.puzzles.items():
                if puzzle.relation_comb not in prev_patterns[args.relation_length]['test']:
                    todel.append(pid)
            for pid in todel:
                del rb.puzzles[pid]
        # Count the pattern of every surviving puzzle.
        for pid, puzzle in rb.puzzles.items():
            if puzzle.relation_comb not in f_comb_count:
                f_comb_count[puzzle.relation_comb] = 0
            f_comb_count[puzzle.relation_comb] += 1
            pb.update(1)
            stories_left -= 1
        # store the puzzles
        all_puzzles.update(rb.puzzles)
        rb.reset_puzzle()
        rb.anc.next_flip()
    pb.close()
    print("Puzzles created. Now splitting train and test on pattern level")
    print("Number of unique puzzles : {}".format(len(all_puzzles)))
    # Group puzzle ids by their relation pattern (f_comb).
    pattern_puzzles = {}
    for pid, pz in all_puzzles.items():
        if pz.relation_comb not in pattern_puzzles:
            pattern_puzzles[pz.relation_comb] = []
        pattern_puzzles[pz.relation_comb].append(pid)
    print("Number of unique patterns : {}".format(len(pattern_puzzles)))
    train_puzzles = []
    test_puzzles = []
    sp = int(len(pattern_puzzles) * split)
    all_patterns = list(pattern_puzzles.keys())
    # args.holdout requests a strict pattern-level holdout; for k = 2 there are
    # too few patterns to hold any out, so pattern overlap is forced instead.
    allow_pattern_overlap = not args.holdout
    if args.relation_length == 2:
        allow_pattern_overlap = True
    if not allow_pattern_overlap:
        # strict split: train and test patterns are disjoint
        train_patterns = all_patterns[:sp]
        pzs = [pattern_puzzles[p] for p in train_patterns]
        pzs = [s for p in pzs for s in p]
        train_puzzles.extend(pzs)
        test_patterns = all_patterns[sp:]
        pzs = [pattern_puzzles[p] for p in test_patterns]
        pzs = [s for p in pzs for s in p]
        test_puzzles.extend(pzs)
    else:
        # k == 2: patterns overlap between train and test, but the templators
        # differ; split at the puzzle level within each overlapping pattern.
        train_patterns = all_patterns
        test_patterns = all_patterns[sp:]
        pzs_train = []
        pzs_test = []
        for pattern in all_patterns:
            pz = pattern_puzzles[pattern]
            if pattern in test_patterns:
                # split within the pattern (admittedly hacky)
                sz = int(len(pz) * (split - 0.2))
                pzs_train.extend(pz[:sz])
                pzs_test.extend(pz[sz:])
            else:
                pzs_train.extend(pz)
        train_puzzles.extend(pzs_train)
        test_puzzles.extend(pzs_test)
    print("# Train puzzles : {}".format(len(train_puzzles)))
    print("# Test puzzles : {}".format(len(test_puzzles)))
    pb = tqdm(total=len(all_puzzles))
    # Render text and save each puzzle as a CSV row.
    train_puzzle_set = set(train_puzzles)  # sets give O(1) membership checks
    test_puzzle_set = set(test_puzzles)
    for pid, puzzle in all_puzzles.items():
        if pid in train_puzzle_set:
            task_split = 'train'
            templator = templatorClass(templates=train_templates, family=puzzle.anc.family_data)
        elif pid in test_puzzle_set:
            task_split = 'test'
            templator = templatorClass(templates=test_templates, family=puzzle.anc.family_data)
        else:
            raise AssertionError("pid must be either in train or test")
        story_text = puzzle.generate_text(stype='story', combination_length=combination_length, templator=templator)
        fact_text = puzzle.generate_text(stype='fact', combination_length=combination_length, templator=templator)
        story = story_text + fact_text
        story = random.sample(story, len(story))  # shuffle the sentences
        story = ' '.join(story)
        clean_story = ' '.join(story_text)
        target_text = puzzle.generate_text(stype='target', combination_length=1, templator=templator)
        story_key_edges = puzzle.get_story_relations(stype='story') + puzzle.get_story_relations(stype='fact')
        # Build the query text.
        query_templator = query_templator_class(templates=query_templates, family=puzzle.anc.family_data)
        query_text = puzzle.generate_text(stype='query', combination_length=1, templator=query_templator)
        query_text = ' '.join(query_text)
        query_text = query_text.replace('?.', '?')  # remove the stray '.' after '?'
        # Re-index node ids so each story uses contiguous ids (0, 1, 2, ...).
        puzzle.convert_node_ids(stype='story')
        puzzle.convert_node_ids(stype='fact')
        story_keys_changed_ids = puzzle.get_sorted_story_edges(stype='story') + puzzle.get_sorted_story_edges(stype='fact')
        query_edge = puzzle.get_sorted_query_edge()
        genders = puzzle.get_name_gender_string()
        rows.append([pid, story, puzzle.query_text, query_text, puzzle.target_edge_rel,
                     target_text, clean_story, puzzle.proof_trace, puzzle.relation_comb,
                     task_name, story_keys_changed_ids, story_key_edges, query_edge,
                     genders, '', puzzle.story_sort_dict, task_split])
        pb.update(1)
    pb.close()
    print("{} ancestries created".format(anc_num))
    print("Number of unique patterns : {}".format(len(f_comb_count)))
    return columns, rows, all_puzzles, train_patterns, test_patterns
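
# Sketch (not in the original file): one way the five return values might be
# consumed. pandas and the output file names are illustrative assumptions.
#
#   import pandas as pd
#
#   columns, rows, all_puzzles, train_patterns, test_patterns = generate_rows(
#       args, store, task_name, split=0.8, prev_patterns=None)
#   df = pd.DataFrame(rows, columns=columns)
#   df[df.task_split == 'train'].to_csv('train.csv', index=False)
#   df[df.task_split == 'test'].to_csv('test.csv', index=False)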
def generate_rows(args, store, task_name):
    # generate
    print(args.relation_length)
    print("Loading templates...")
    templates = json.load(open(args.template_file))
    pb = tqdm(total=args.num_rows)
    num_stories = args.num_rows
    stories_left = num_stories
    columns = ['id', 'story', 'query', 'text_query', 'target', 'text_target',
               'clean_story', 'proof_state', 'f_comb', 'task_name', 'story_edges',
               'edge_types', 'query_edge', 'genders', 'syn_story']
    f_comb_count = {}
    rows = []
    anc_num = 1  # a single ancestry is reused across all gender flips
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    while stories_left > 0:
        status = rb.build()
        if not status:
            rb.reset_puzzle()
            rb.anc.next_flip()
            continue
        rb.add_facts()
        # Keep a count of generated patterns to ensure a homogeneous distribution
        # (same min_count / count weighting as above).
        if len(f_comb_count) > 0 and args.equal:
            min_c = min([v for k, v in f_comb_count.items()])
            weight = {k: (min_c / v) for k, v in f_comb_count.items()}
            rb.generate_puzzles(weight)
        else:
            rb.generate_puzzles()
        # Now that we have the puzzles, assemble each story.
        for pid, puzzle in rb.puzzles.items():
            story_edges = puzzle['text_story']  # dict of edge -> text
            clean_story = ''.join([puzzle['text_story'][e] for e in story_edges])
            noise_edge_list = [v for k, v in puzzle.items() if 'text_fact' in k]
            for d in noise_edge_list:
                if not isinstance(d, dict):
                    print(d)
                    print(noise_edge_list)
                    raise AssertionError()
                story_edges.update(d)  # add the noise edge -> text
            # Shuffle the story and noise sentences together.
            story_keys = random.sample(list(story_edges.keys()), len(story_edges))
            story = ''.join([story_edges[k] for k in story_keys])
            story_key_edges = [rb.get_edge_relation(k) for k in story_keys]
            all_edge_rows = [puzzle['story']]
            all_edge_rows.extend(puzzle['all_noise'])
            # Templating logic: all_edge_rows is a list of edge sequences,
            # [story, noise, ...], each templated independently.
            templated_rows = []
            if args.use_mturk_template:
                for seq in all_edge_rows:
                    # Find all ways of grouping the sequence into template-sized chunks.
                    group_combs = comb_indexes(seq, args.template_length)
                    temp_rows = []
                    temp_user = TemplateUser(templates=templates, family=rb.anc.family_data)
                    for group in group_combs:
                        try:
                            fcombs = ['-'.join([rb.get_edge_relation(edge) for edge in edge_group])
                                      for edge_group in group]
                            fentities = [[ent for edge in edge_group for ent in edge]
                                         for edge_group in group]
                            prows = [temp_user.replace_template(edge_group, fentities[group_id])
                                     for group_id, edge_group in enumerate(fcombs)]
                            temp_rows.append((group, prows))
                        except Exception:
                            pass
                    if len(temp_rows) == 0:
                        # No grouping could be templated: re-run the last group
                        # verbosely to surface the failure.
                        print(group_combs)
                        print(all_edge_rows)
                        print(seq)
                        fcombs = ['-'.join([rb.get_edge_relation(edge) for edge in edge_group])
                                  for edge_group in group]
                        fentities = [[ent for edge in edge_group for ent in edge]
                                     for edge_group in group]
                        prows = [temp_user.replace_template(edge_group, fentities[group_id], verbose=True)
                                 for group_id, edge_group in enumerate(fcombs)]
                    chosen_row = random.choice(temp_rows)
                    templated_rows.append(chosen_row)
                # Keep only the templated sentences and flatten.
                templated_rows = [row[-1] for row in templated_rows]
                templated_rows = [xt for t in templated_rows for xt in t]
            ## The same logic as above, without the try/except block:
            '''
            random_combs = [choose_random_subsequence(ae, args.template_length)
                            for ae in all_edge_rows]
            random_f_combs = [['-'.join([rb.get_edge_relation(edge) for edge in cr])
                               for cr in row] for row in random_combs]
            random_entities = [[[e for edge in cr for e in edge] for cr in row]
                              for row in random_combs]
            placed_rows = [[replace_template(templates, cr, random_entities[row_id][c_id], rb.anc.family_data)
                            for c_id, cr in enumerate(row)]
                           for row_id, row in enumerate(random_f_combs)]
            '''
            # Convert the edge list to contiguous node ids e_0, e_1, ...
            node_ct = 0
            node_id_dict = {}
            for key in story_keys:
                if key[0] not in node_id_dict:
                    node_id_dict[key[0]] = node_ct
                    node_ct += 1
                if key[1] not in node_id_dict:
                    node_id_dict[key[1]] = node_ct
                    node_ct += 1
            story_keys_changed_id = [(node_id_dict[key[0]], node_id_dict[key[1]])
                                     for key in story_keys]
            # Express the query edge with respect to the same ids.
            query_edge = (node_id_dict[puzzle['query'][0]], node_id_dict[puzzle['query'][1]])
            text_question = rb.generate_question(puzzle['query'])
            # Also store the genders for postprocessing.
            genders = ','.join(['{}:{}'.format(rb.anc.family_data[node_id].name,
                                               rb.anc.family_data[node_id].gender)
                                for node_id in node_id_dict.keys()])
            if puzzle['f_comb'] not in f_comb_count:
                f_comb_count[puzzle['f_comb']] = 0
            f_comb_count[puzzle['f_comb']] += 1
            stories_left -= 1
            if stories_left < 0:
                break
            syn_story = ''
            if args.use_mturk_template:
                # Keep the synthetic story separately; the AMT-templated text becomes the story.
                syn_story = story
                story = ' '.join(templated_rows)
            rows.append([pid, story, puzzle['query_text'], text_question, puzzle['target'],
                         puzzle['text_target'], clean_story, puzzle['proof'], puzzle['f_comb'],
                         task_name, story_keys_changed_id, story_key_edges, query_edge,
                         genders, syn_story])
            pb.update(1)
        rb.reset_puzzle()
        rb.anc.next_flip()
    pb.close()
    print("{} ancestries created".format(anc_num))
    print("Number of unique patterns : {}".format(len(f_comb_count)))
    return columns, rows
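
# Sketch (not in the original file): the node-relabelling step above, isolated.
# Given story edges over arbitrary node ids, it assigns contiguous ids
# (0, 1, 2, ...) in order of first appearance, so downstream consumers see a
# canonical graph. The edge values here are made up for illustration.
def relabel_edges_example():
    story_keys = [(17, 4), (4, 9), (17, 9)]
    node_id_dict = {}
    for a, b in story_keys:
        for n in (a, b):
            if n not in node_id_dict:
                node_id_dict[n] = len(node_id_dict)
    relabelled = [(node_id_dict[a], node_id_dict[b]) for a, b in story_keys]
    assert relabelled == [(0, 1), (1, 2), (0, 2)]
    return relabelled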