示例#1
0
    def generate(self, choice, args, num_rows=0, data_type='train', multi=False, split=None):
        """
        Choose the task and the relation length, generate the rows for it, and
        return the (deep-copied, task-adjusted) args that were used, for storing.

        :param choice: a ``"<task>.<relation_length>"`` string, or, when
            ``multi`` is true, an iterable of such strings.
        :param args: argument namespace; deep-copied, so the caller's object is
            not mutated.
        :param num_rows: stored on the copied args as ``args.num_rows``.
        :param data_type: stored on the copied args as ``args.data_type``
            and included in the log line (single-task mode).
        :param multi: if true, generate rows for every entry of ``choice`` and
            concatenate them.
        :param split: forwarded to ``generate_rows``; only used when ``multi``
            is false.
        :return: ``((columns, rows, puzzles), args)``.
        :raises NotImplementedError: if this object has no ``task_<task>``
            method for a requested task.
        """
        args = copy.deepcopy(args)
        args.num_rows = num_rows
        args.data_type = data_type
        if not multi:
            task, relation_length = choice.split('.')
            task_name = 'task_{}'.format(task)
            logger.info("mode : {}, task : {}, rel_length : {}".format(data_type, task_name, relation_length))
            task_method = getattr(self, task_name, None)
            if task_method is None:
                # The former zero-argument lambda fallback raised a confusing
                # TypeError when called with ``args``; fail explicitly instead.
                raise NotImplementedError("Task {} not implemented".format(choice))
            args = task_method(args)
            args.relation_length = int(relation_length)
            store = Store(args)
            # Patterns recorded by earlier calls (keyed by relation length)
            # are handed to generate_rows as ``prev_patterns``.
            columns, rows, all_puzzles, train_patterns, test_patterns = generate_rows(
                args, store, task_name + '.{}'.format(relation_length),
                split=split, prev_patterns=self.unique_patterns)
            # Record this relation length's patterns for subsequent calls.
            self.unique_patterns[int(relation_length)] = {
                'train': train_patterns,
                'test': test_patterns
            }
            return (columns, rows, all_puzzles), args

        rows = []
        columns = []
        puzzles = {}
        for ch in choice:
            task, relation_length = ch.split('.')
            task_name = 'task_{}'.format(task)
            logger.info("task : {}, rel_length : {}".format(task_name, relation_length))
            task_method = getattr(self, task_name, None)
            if task_method is None:
                # Report the offending entry, not the whole ``choice`` list.
                raise NotImplementedError("Task {} not implemented".format(ch))
            # NOTE: each task method receives the args returned by the
            # previous iteration, not a fresh copy.
            args = task_method(args)
            args.relation_length = int(relation_length)
            store = Store(args)
            # The single-task branch above unpacks a 5-item result from
            # generate_rows; only the first three items are needed here, and
            # slicing also tolerates a 3-item result.
            columns, r, pz = generate_rows(args, store, task_name + '.{}'.format(relation_length))[:3]
            rows.extend(r)
            puzzles.update(pz)
        # ``columns`` is whatever the last task produced.
        return ((columns, rows, puzzles), args)
示例#2
0
def main(args):
    """Generate a dataset and write a random train/test split to CSV.

    A row goes to the test file when its uniform draw from ``np.random.rand``
    is ``<= args.test``; otherwise it goes to the train file. For ``args.test``
    in [0, 1] this makes it the approximate fraction of test rows. Output
    paths are ``args.output + '_train.csv'`` and ``args.output + '_test.csv'``.
    """
    header, rows = generate_rows(args, Store(args))
    frame = pd.DataFrame(columns=header, data=rows)
    # Boolean mask: True selects the train split, False the test split.
    in_train = np.random.rand(len(frame)) > args.test
    splits = (
        ('_train.csv', frame[in_train]),
        ('_test.csv', frame[~in_train]),
    )
    for suffix, subset in splits:
        subset.to_csv(args.output + suffix)
示例#3
0
 def generate(self, choice, args, num_rows=0, data_type='train', multi=False):
     """
     Choose the task and the relation length, generate the rows for it, and
     return the (deep-copied, task-adjusted) args that were used, for storing.

     :param choice: a ``"<task>.<relation_length>"`` string, or, when
         ``multi`` is true, an iterable of such strings.
     :param args: argument namespace; deep-copied, so the caller's object is
         not mutated.
     :param num_rows: stored on the copied args as ``args.num_rows``.
     :param data_type: stored on the copied args as ``args.data_type`` and
         included in the log lines.
     :param multi: if true, generate rows for every entry of ``choice`` and
         concatenate them.
     :return: ``(generate_rows(...), args)`` for a single task, or
         ``((columns, rows), args)`` when ``multi`` is true.
     :raises NotImplementedError: if this object has no ``task_<task>``
         method for a requested task.
     """
     args = copy.deepcopy(args)
     args.num_rows = num_rows
     args.data_type = data_type
     if not multi:
         task, relation_length = choice.split('.')
         task_name = 'task_{}'.format(task)
         logger.info("mode : {}, task : {}, rel_length : {}".format(data_type, task_name, relation_length))
         task_method = getattr(self, task_name, None)
         if task_method is None:
             # The former zero-argument lambda fallback raised a confusing
             # TypeError when called with ``args``; fail explicitly instead.
             raise NotImplementedError("Task {} not implemented".format(choice))
         args = task_method(args)
         args.relation_length = int(relation_length)
         store = Store(args)
         return (generate_rows(args, store, task_name + '.{}'.format(relation_length)), args)

     rows = []
     columns = []
     for ch in choice:
         task, relation_length = ch.split('.')
         task_name = 'task_{}'.format(task)
         logger.info("mode : {}, task : {}, rel_length : {}".format(data_type, task_name, relation_length))
         task_method = getattr(self, task_name, None)
         if task_method is None:
             # Report the offending entry, not the whole ``choice`` list.
             raise NotImplementedError("Task {} not implemented".format(ch))
         # NOTE: each task method receives the args returned by the
         # previous iteration, not a fresh copy.
         args = task_method(args)
         args.relation_length = int(relation_length)
         store = Store(args)
         columns, r = generate_rows(args, store, task_name + '.{}'.format(relation_length))
         rows.extend(r)
     # ``columns`` is whatever the last task produced.
     return ((columns, rows), args)
示例#4
0
def test_run(args):
    """Smoke-run the puzzle pipeline and print summary statistics.

    Builds a ``Store``, an ``Ancestry`` and a ``RelationBuilder`` from
    ``args``, with ``num_rel`` forced to 3. It then repeats the following
    until ``rb.anc.next_flip()`` returns a falsy value: call ``rb.build()``
    once per key of ``anc.family_data``, collect the patterns reported by
    ``rb.unique_patterns()``, and reset the puzzle. Finally it adds facts,
    generates puzzles, and prints one of them chosen at random.

    :param args: argument namespace forwarded to ``Store``, ``Ancestry`` and
        ``RelationBuilder``.
    """
    store = Store(args)
    anc = Ancestry(args, store)
    rb = RelationBuilder(args, store, anc)
    rb.num_rel = 3
    all_patterns = set()
    while True:
        # The iteration count is taken once per pass, before any build().
        for _ in range(len(anc.family_data.keys())):
            rb.build()
            all_patterns.update(rb.unique_patterns())
            print(len(all_patterns))  # running total as progress output
            rb.reset_puzzle()
        if not rb.anc.next_flip():
            break
    print("Number of unique puzzles : {}".format(len(all_patterns)))

    rb.add_facts()
    rb.generate_puzzles()
    print("Generated {} puzzles".format(len(rb.puzzles)))
    # random.choice raises IndexError on an empty sequence; bail out cleanly.
    if not rb.puzzles:
        print("No puzzles generated; nothing to sample")
        return
    pid = random.choice(list(rb.puzzles.keys()))
    print(rb.puzzles[pid])