Example #1
0
def test_run(args):
    """Smoke-test the puzzle-generation pipeline end to end.

    Builds one Store/Ancestry/RelationBuilder chain, enumerates unique
    relation patterns across every gender flip of the family, then
    generates the full puzzle set and prints one randomly chosen puzzle.

    Args:
        args: configuration namespace forwarded to Store, Ancestry and
            RelationBuilder.
    """
    store = Store(args)
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    rb.num_rel = 3
    all_patterns = set()
    while True:
        # One build/reset attempt per family member; the index itself is
        # unused, the count only controls attempts per flip.
        for _ in range(len(anc.family_data)):
            rb.build()
            all_patterns.update(rb.unique_patterns())
            print(len(all_patterns))
            rb.reset_puzzle()
        # next_flip() returns falsy once every gender assignment is exhausted
        if not rb.anc.next_flip():
            break
    print("Number of unique puzzles : {}".format(len(all_patterns)))

    rb.add_facts()
    rb.generate_puzzles()
    print("Generated {} puzzles".format(len(rb.puzzles)))
    pid = random.choice(list(rb.puzzles.keys()))
    print(rb.puzzles[pid])
Example #2
0
def _load_json(path):
    """Load a JSON file, closing the file handle deterministically."""
    with open(path) as f:
        return json.load(f)


def generate_rows(args, store, task_name, split=0.8, prev_patterns=None):
    """Generate puzzle rows for one task and split them train/test on
    the pattern level.

    Args:
        args: configuration namespace (relation_length, num_rows,
            template options, holdout/equal flags, ...).
        store: data store exposing `relations_store` and `question_store`.
        task_name: task label written into every generated row.
        split: fraction of patterns (or puzzles, in the overlap case)
            assigned to the train split.
        prev_patterns: optional mapping relation_length -> {'test': [...]}
            used to restrict a test task to previously seen patterns.

    Returns:
        Tuple of (columns, rows, all_puzzles, train_patterns, test_patterns).

    Raises:
        NotImplementedError: if the requested combination length is not
            supported by the selected templating backend.
        AssertionError: if a generated puzzle ends up in neither split.
    """
    # pre-flight checks
    combination_length = min(args.combination_length, args.relation_length)
    if not args.use_mturk_template:
        if combination_length > 1:
            raise NotImplementedError("combination of two or more relations not implemented in Synthetic templating")
    else:
        if combination_length > 3:
            raise NotImplementedError("combinations of > 3 not implemented in AMT Templating")
    # generate
    print(args.relation_length)
    print("Loading templates...")
    all_puzzles = {}
    if args.template_split:
        train_templates = _load_json(args.template_file + '.train.json')
        test_templates = _load_json(args.template_file + '.test.json')
    else:
        train_templates = _load_json(args.template_file + '.json')
        test_templates = _load_json(args.template_file + '.json')
    if args.use_mturk_template:
        templatorClass = TemplatorAMT
    else:
        # Synthetic templating: one sentence list per relation, shared
        # between train and test.
        synthetic_templates_per_rel = {}
        for key, val in store.relations_store.items():
            for gender, gv in val.items():
                synthetic_templates_per_rel[gv['rel']] = gv['p']
        templatorClass = TemplatorSynthetic
        train_templates = synthetic_templates_per_rel
        test_templates = synthetic_templates_per_rel

    # Build a mapping from ANY relation to the SAME list of sentences for asking queries
    query_templates = {}
    for key, val in store.relations_store.items():
        for gender, gv in val.items():
            query_templates[gv['rel']] = store.question_store['relational']
    query_templator_class = TemplatorSynthetic

    pb = tqdm(total=args.num_rows)
    num_stories = args.num_rows
    stories_left = num_stories
    columns = ['id', 'story', 'query', 'text_query', 'target', 'text_target', 'clean_story', 'proof_state', 'f_comb',
               'task_name','story_edges','edge_types','query_edge','genders', 'syn_story', 'node_mapping', 'task_split']
    f_comb_count = {}
    rows = []
    anc_num = 0
    anc_num += 1
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    while stories_left > 0:
        status = rb.build()
        if not status:
            # Dead end for this gender assignment: reset and flip.
            rb.reset_puzzle()
            rb.anc.next_flip()
            continue
        rb.add_facts()
        # keeping a count of generated patterns to make sure we have homogenous distribution
        if len(f_comb_count) > 0 and args.equal:
            min_c = min(f_comb_count.values())
            weight = {k: (min_c / v) for k, v in f_comb_count.items()}
            rb.generate_puzzles(weight)
        else:
            rb.generate_puzzles()
        # if unique_test_pattern flag is set, and split is 0 (which indicates the task is test),
        # only take the same test patterns as before
        # also assert that the relation - test is present
        # NOTE: test `prev_patterns` for truthiness first — it defaults to
        # None, and len(None) would raise a TypeError.
        if (args.unique_test_pattern and split == 0 and prev_patterns
                and len(prev_patterns[args.relation_length]['test']) > 0):
            # if all these conditions met, prune the puzzles
            todel = [pid for pid, puzzle in rb.puzzles.items()
                     if puzzle.relation_comb not in prev_patterns[args.relation_length]['test']]
            for pid in todel:
                del rb.puzzles[pid]
        # now we have got the puzzles, assign the templators
        for pid, puzzle in rb.puzzles.items():
            if puzzle.relation_comb not in f_comb_count:
                f_comb_count[puzzle.relation_comb] = 0
            f_comb_count[puzzle.relation_comb] += 1
            pb.update(1)
            stories_left -= 1
        # store the puzzles
        all_puzzles.update(rb.puzzles)
        rb.reset_puzzle()
        rb.anc.next_flip()
    pb.close()
    print("Puzzles created. Now splitting train and test on pattern level")
    print("Number of unique puzzles : {}".format(len(all_puzzles)))
    # Group puzzle ids by their relation pattern (f_comb).
    pattern_puzzles = {}
    for pid, pz in all_puzzles.items():
        if pz.relation_comb not in pattern_puzzles:
            pattern_puzzles[pz.relation_comb] = []
        pattern_puzzles[pz.relation_comb].append(pid)
    print("Number of unique patterns : {}".format(len(pattern_puzzles)))
    train_puzzles = []
    test_puzzles = []
    sp = int(len(pattern_puzzles) * split)
    all_patterns = list(pattern_puzzles.keys())

    no_pattern_overlap = not args.holdout
    # if k=2, then set no_pattern_overlap=True
    if args.relation_length == 2:
        no_pattern_overlap = True

    if not no_pattern_overlap:
        # for case > 3, strict no pattern overlap
        train_patterns = all_patterns[:sp]
        pzs = [pattern_puzzles[p] for p in train_patterns]
        pzs = [s for p in pzs for s in p]
        train_puzzles.extend(pzs)
        test_patterns = all_patterns[sp:]
        pzs = [pattern_puzzles[p] for p in test_patterns]
        pzs = [s for p in pzs for s in p]
        test_puzzles.extend(pzs)
    else:
        # for case of 2, pattern overlap but templators are different
        # In this case, we have overlapping patterns, first choose the overlapping patterns
        # we directly split on puzzle level
        train_patterns = all_patterns
        test_patterns = all_patterns[sp:]
        pzs_train = []
        pzs_test = []
        for pattern in all_patterns:
            pz = pattern_puzzles[pattern]
            if pattern in test_patterns:
                # now split - hacky way
                sz = int(len(pz) * (split - 0.2))
                pzs_train.extend(pz[:sz])
                pzs_test.extend(pz[sz:])
            else:
                pzs_train.extend(pz)
        train_puzzles.extend(pzs_train)
        test_puzzles.extend(pzs_test)

    print("# Train puzzles : {}".format(len(train_puzzles)))
    print("# Test puzzles : {}".format(len(test_puzzles)))
    # Sets for O(1) membership in the per-puzzle loop below; the lists can
    # hold tens of thousands of ids, making list membership O(n^2) overall.
    train_puzzle_set = set(train_puzzles)
    test_puzzle_set = set(test_puzzles)
    pb = tqdm(total=len(all_puzzles))
    # saving in csv
    for pid, puzzle in all_puzzles.items():
        task_split = ''
        if pid in train_puzzle_set:
            task_split = 'train'
            templator = templatorClass(templates=train_templates, family=puzzle.anc.family_data)
        elif pid in test_puzzle_set:
            task_split = 'test'
            templator = templatorClass(templates=test_templates, family=puzzle.anc.family_data)
        else:
            # Was previously constructed but never raised, which silently
            # fell through to a NameError on `templator` below.
            raise AssertionError("pid must be either in train or test")
        story_text = puzzle.generate_text(stype='story', combination_length=combination_length, templator=templator)
        fact_text = puzzle.generate_text(stype='fact', combination_length=combination_length, templator=templator)
        story = story_text + fact_text
        # Shuffle sentence order (random.sample of full length = shuffle).
        story = random.sample(story, len(story))
        story = ' '.join(story)
        clean_story = ' '.join(story_text)
        target_text = puzzle.generate_text(stype='target', combination_length=1, templator=templator)

        story_key_edges = puzzle.get_story_relations(stype='story') + puzzle.get_story_relations(stype='fact')
        # Build query text
        query_templator = query_templator_class(templates=query_templates, family=puzzle.anc.family_data)
        query_text = puzzle.generate_text(stype='query', combination_length=1, templator=query_templator)
        query_text = ' '.join(query_text)
        query_text = query_text.replace('?.', '?')  # remove trailing '.'
        puzzle.convert_node_ids(stype='story')
        puzzle.convert_node_ids(stype='fact')
        story_keys_changed_ids = puzzle.get_sorted_story_edges(stype='story') + puzzle.get_sorted_story_edges(stype='fact')
        query_edge = puzzle.get_sorted_query_edge()

        genders = puzzle.get_name_gender_string()

        rows.append([pid, story, puzzle.query_text, query_text, puzzle.target_edge_rel, target_text,
                     clean_story, puzzle.proof_trace, puzzle.relation_comb, task_name, story_keys_changed_ids,
                     story_key_edges, query_edge, genders, '', puzzle.story_sort_dict, task_split])
        pb.update(1)
    pb.close()

    print("{} ancestries created".format(anc_num))
    print("Number of unique patterns : {}".format(len(f_comb_count)))
    return columns, rows, all_puzzles, train_patterns, test_patterns
Example #3
0
def generate_rows(args, store, task_name):
    # generate
    print(args.relation_length)
    print("Loading templates...")
    templates = json.load(open(args.template_file))
    pb = tqdm(total=args.num_rows)
    num_stories = args.num_rows
    stories_left = num_stories
    columns = [
        'id', 'story', 'query', 'text_query', 'target', 'text_target',
        'clean_story', 'proof_state', 'f_comb', 'task_name', 'story_edges',
        'edge_types', 'query_edge', 'genders', 'syn_story'
    ]
    f_comb_count = {}
    rows = []
    anc_num = 0
    anc_num += 1
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    while stories_left > 0:
        status = rb.build()
        if not status:
            rb.reset_puzzle()
            rb.anc.next_flip()
            continue
        rb.add_facts()
        # keeping a count of generated patterns to make sure we have homogenous distribution
        if len(f_comb_count) > 0 and args.equal:
            min_c = min([v for k, v in f_comb_count.items()])
            weight = {k: (min_c / v) for k, v in f_comb_count.items()}
            rb.generate_puzzles(weight)
        else:
            rb.generate_puzzles()
        # now we have got the puzzles, add them to the story
        for pid, puzzle in rb.puzzles.items():
            story_edges = puzzle['text_story']  # dict of edge:text
            clean_story = ''.join(
                [puzzle['text_story'][e] for e in story_edges])
            noise_edge_list = [
                v for k, v in puzzle.items() if 'text_fact' in k
            ]
            for d in noise_edge_list:
                if type(d) != dict:
                    print(d)
                    print(noise_edge_list)
                    raise AssertionError()
                story_edges.update(d)  # adds the noise edge:text
            #noise = [y for x in noise for y in x] # flatten
            #story += noise
            story_keys = random.sample(list(story_edges.keys()),
                                       len(story_edges))
            story = ''.join([story_edges[k] for k in story_keys])
            story_key_edges = [rb.get_edge_relation(k) for k in story_keys]
            all_edge_rows = []
            all_edge_rows.append(puzzle['story'])
            all_edge_rows.extend(puzzle['all_noise'])

            # Templating Logic
            # all_edge_rows = list of two list : [story, noise]
            # where, story = list of edges, noise = list of edges
            # story and noise = sequence

            templated_rows = []
            if args.use_mturk_template:
                for seq in all_edge_rows:
                    #print(seq)
                    # find all grouping combinations
                    group_combs = comb_indexes(seq, args.template_length)
                    #print(group_combs)
                    temp_rows = []
                    temp_user = TemplateUser(templates=templates,
                                             family=rb.anc.family_data)
                    for group in group_combs:
                        try:
                            fcombs = [
                                '-'.join([
                                    rb.get_edge_relation(edge)
                                    for edge in edge_group
                                ]) for edge_group in group
                            ]
                            fentities = [[
                                ent for edge in edge_group for ent in edge
                            ] for edge_group in group]
                            prows = [
                                temp_user.replace_template(
                                    edge_group, fentities[group_id])
                                for group_id, edge_group in enumerate(fcombs)
                            ]
                            temp_rows.append((group, prows))
                        except:
                            pass
                    #print(len(temp_rows))
                    if len(temp_rows) == 0:
                        print(group_combs)
                        print(all_edge_rows)
                        print(seq)
                        fcombs = [
                            '-'.join([
                                rb.get_edge_relation(edge)
                                for edge in edge_group
                            ]) for edge_group in group
                        ]
                        fentities = [[
                            ent for edge in edge_group for ent in edge
                        ] for edge_group in group]
                        prows = [
                            temp_user.replace_template(edge_group,
                                                       fentities[group_id],
                                                       verbose=True)
                            for group_id, edge_group in enumerate(fcombs)
                        ]
                    chosen_row = random.choice(temp_rows)
                    #print('chosen row', chosen_row)
                    templated_rows.append(chosen_row)

                templated_rows = [row[-1] for row in templated_rows]
                # flatten
                templated_rows = [xt for t in templated_rows for xt in t]

            ## The same thing above without the try catch block
            '''
            random_combs = [choose_random_subsequence(ae, args.template_length) for ae in all_edge_rows]
            random_f_combs = [['-'.join([rb.get_edge_relation(edge) for edge in cr]) for cr in row] for row in random_combs]
            random_entities = [[[e for edge in cr for e in edge] for cr in row] for row in random_combs]
            placed_rows = [[replace_template(templates, cr, random_entities[row_id][c_id], rb.anc.family_data)
                            for c_id, cr in enumerate(row)] for row_id,row in enumerate(random_f_combs)]
            print('a',all_edge_rows)
            print(random_combs)
            print(random_f_combs)
            print(random_entities)
            '''
            # convert edge list into e_0, e_1
            node_ct = 0
            node_id_dict = {}
            for key in story_keys:
                if key[0] not in node_id_dict:
                    node_id_dict[key[0]] = node_ct
                    node_ct += 1
                if key[1] not in node_id_dict:
                    node_id_dict[key[1]] = node_ct
                    node_ct += 1
            story_keys_changed_id = [(node_id_dict[key[0]],
                                      node_id_dict[key[1]])
                                     for key in story_keys]
            # add the query edges with respect to the same id
            query_edge = (node_id_dict[puzzle['query'][0]],
                          node_id_dict[puzzle['query'][1]])
            text_question = rb.generate_question(puzzle['query'])
            # also store the gender for postprocessing
            genders = ','.join([
                '{}:{}'.format(rb.anc.family_data[node_id].name,
                               rb.anc.family_data[node_id].gender)
                for node_id in node_id_dict.keys()
            ])
            if puzzle['f_comb'] not in f_comb_count:
                f_comb_count[puzzle['f_comb']] = 0
            f_comb_count[puzzle['f_comb']] += 1
            stories_left -= 1
            if stories_left < 0:
                break

            syn_story = ''
            if args.use_mturk_template:
                syn_story = story
                story = ' '.join(templated_rows)
            rows.append([
                pid, story, puzzle['query_text'], text_question,
                puzzle['target'], puzzle['text_target'], clean_story,
                puzzle['proof'], puzzle['f_comb'], task_name,
                story_keys_changed_id, story_key_edges, query_edge, genders,
                syn_story
            ])
            pb.update(1)
        rb.reset_puzzle()
        rb.anc.next_flip()
    pb.close()
    print("{} ancestries created".format(anc_num))
    print("Number of unique patterns : {}".format(len(f_comb_count)))
    return columns, rows