Exemplo n.º 1
0
def worker(args, queue, worker_id):
    """Expression-generation worker loop.

    Builds (or loads) a language model and a Constructor, then loops
    forever generating expressions and pushing encoded pred/gen training
    tasks onto ``queue``. Never returns.

    Args:
        args: parsed command-line namespace (gpu_list, num_gpus, task, ...).
        queue: multiprocessing queue receiving encoded task batches.
        worker_id: integer id, used for GPU assignment and log naming.
    """
    if args.partial_lm:
        lm = build_language_model(args.num_props, new=args.gen_lm, _all=False)
    else:
        lm = load_language_model(_all=False, new=args.gen_lm)
        if args.data_sampling > 0:
            load_proof_steps_into_lm(lm, ['train'], args.data_sampling)
    config = get_config(lm)
    # Spread workers across GPUs, filling from the last GPU backwards.
    device_idx = args.num_gpus - 1 - (worker_id // args.num_workers)
    if args.random_generate:
        device_idx = 0
    device_id = args.gpu_list[device_idx]
    print ('build a worker on gpu %d' % (device_id))
    torch.cuda.set_device(device_id)
    args.device = torch.device('cuda:'+str(device_id))
    # Random generation needs no learned model, hence no LM interface.
    interface = None if args.random_generate else interface_lm.LMInterface(args, lm)
    generator = Constructor(args, config, interface)
    generator.initialize()
    _logger = log.get_logger('worker%d'%(worker_id), args, append=True)
    _logger.info('worker %d initialize', worker_id)
    tt = 0   # accumulated generation time in seconds
    cnt = 0  # non-None expressions generated since the last timing log
    while True:
        t = time.time()
        if args.random_generate:
            expr = generator.random_generate()
        else:
            expr = generator.parameterized_generate()
        tt += time.time() - t
        if expr is not None:
            # 'pred' tasks need at least one essential hypothesis;
            # 'gen' tasks need at least one unconstrained variable.
            if args.task == 'pred' and len(expr.prop.e) > 0:
                queue.put(generator.encode_pred_tasks([expr]))
            if args.task == 'gen' and len(expr.unconstrained) > 0:
                queue.put(generator.encode_gen_tasks([expr]))
            # Drop accumulated expressions once the pool outgrows its cap,
            # keeping the worker's memory bounded.
            if len(generator.expressions_list) > args.num_cons_exprs+generator.num_initial_expr:
                generator.reinitialize_expressions()
                _logger.info('worker %d initialize', worker_id)
            # Log the average generation time every 5000 expressions.
            if cnt == 5000:
                _logger.info('worker %d generate time per expr %s seconds', worker_id, tt/cnt)
                cnt = 0
                tt = 0
            cnt += 1
Exemplo n.º 2
0
def DataLoader(args, outqueue, inqueue, batch_size=None):
    """Assemble training batches from worker-produced data.

    Pulls encoded task tuples from ``inqueue``, mixes in previously seen
    ("old") examples according to ``args.new_expr_ratio``, and pushes
    complete batches onto ``outqueue``. Runs forever.

    Args:
        args: parsed command-line namespace.
        outqueue: queue receiving assembled batches.
        inqueue: queue producing fresh encoded data from workers.
        batch_size: items per batch; defaults to ``args.batch_size``.
    """
    # Fixed: was `batch_size == None`; identity comparison is the correct idiom.
    if batch_size is None:
        batch_size = args.batch_size
    if args.partial_lm:
        lm = build_language_model(args.num_props, new=args.gen_lm, _all=False)
    else:
        lm = load_language_model(_all=False, new=args.gen_lm)
        if args.data_sampling > 0:
            load_proof_steps_into_lm(lm, ['train'], args.data_sampling)
    config = get_config(lm)
    args.device = torch.device('cuda:%d' % (args.gpu_list[0]))
    torch.cuda.set_device(args.gpu_list[0])
    # NOTE(review): `generator` is built but never used below — presumably
    # kept for construction side effects; confirm before removing.
    generator = Constructor(args, config)
    generator.initialize()
    # 'gen' tasks carry 6 parallel fields per example, 'pred' tasks 4.
    old_data = [[],[],[],[],[],[]] if args.task == 'gen' else [[],[],[],[]]
    _logger = log.get_logger('dataloader', args, append=True)
    _logger.info('dataloader is ready')
    while True:
        batch = [[],[],[],[],[],[]] if args.task == 'gen' else [[],[],[],[]]
        while len(batch[0]) < batch_size:
            if len(old_data[0]) < args.num_old_exprs or random.random() < args.new_expr_ratio:
                # Use fresh examples; also remember them in the replay buffer.
                data = inqueue.get()
                for i in range(len(data)):
                    batch[i] += data[i]
                    if args.task == 'gen':
                        old_data[i] += data[i]
                    else:
                        old_data[i].append(data[i])
                # Keep the replay buffer bounded (FIFO eviction).
                if len(old_data[0]) > args.num_old_exprs:
                    for d in old_data:
                        d.pop(0)
            else:
                # Replay one random old example.
                k = random.randrange(len(old_data[0]))
                for i in range(len(batch)):
                    if args.task == 'gen':
                        batch[i].append( old_data[i][k] )
                    else:
                        batch[i] += old_data[i][k]
        outqueue.put(batch)
Exemplo n.º 3
0
 def __init__(self, args, config):
     """Build a TrieTree searcher over all non-hypothesis expressions.

     Temporarily forces ``args.data_sampling`` to 1 so the Constructor
     sees the full data set, then restores the caller's value.
     """
     tmp = args.data_sampling
     args.data_sampling = 1
     try:
         generator = Constructor(args, config)
         generator.initialize()
     finally:
         # Restore even if construction raises, so the shared args
         # namespace is never left mutated (was unprotected before).
         args.data_sampling = tmp
     replacement_dict = config.lm.deterministic_replacement_dict_f(
         f=generator.all_f)
     self.searcher = TrieTree([])
     for e in generator.expressions_list:
         # is_hyps == 2 marks expressions excluded from the searcher —
         # TODO confirm the exact semantics of this flag.
         if e.is_hyps != 2:
             t, _ = generator.encode_expr(
                 e.tree,
                 [generator.expressions_list[i].tree
                  for i in e.hyps], replacement_dict)
             self.searcher.insert(t)
     return
Exemplo n.º 4
0
if __name__ == "__main__":
    args = params.get_args()
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    if args.partial_lm:
        lm = build_language_model(args.num_props, new=args.gen_lm, _all=False)
    else:
        lm = load_language_model(_all=False, new=args.gen_lm)
        load_proof_steps_into_lm(lm, ['train'])
    config = get_config(lm)
    args.vocab_size = len(config.encode) + 1

    interface = interface_rl.LMInterface(args, lm)
    generator = Constructor(args, config, interface)
    generator.initialize()
    print(generator.prop_dist)
    fake_reward_fn = FakeReward(args, config)
    reward_fn = LMReward(args, config)
    if args.fake_reward:
        reward_fn = fake_reward_fn
    opt = get_opt({
        'pred': interface.pred_model,
        'gen': interface.gen_model
    }, args)
    rewards_all = 0
    for e in range(args.epoch):
        rewards_all += train(args, generator, reward_fn, fake_reward_fn, opt,
                             e, _logger)
        print('acc rewards', rewards_all)
        aux = {'epoch': e, 'cur_iter': 1}
Exemplo n.º 5
0
def worker_pre(args, queue, batch_size, ii):
    """Pre-generated-expression worker.

    Loads one saved expression file (the ``ii``-th entry of the
    ``args.exprs_pre`` directory listing) and streams encoded pred/gen
    task batches onto ``queue``.

    Args:
        args: parsed command-line namespace.
        queue: multiprocessing queue receiving encoded batches.
        batch_size: number of expressions per batch.
        ii: index into the directory listing of ``args.exprs_pre``.
    """
    _logger = log.get_logger('worker_pre', args, append=True)
    _logger.info('worker_pre initialize')
    if args.partial_lm:
        lm = build_language_model(args.num_props, new=args.gen_lm, _all=False)
    else:
        lm = load_language_model(_all=False, new=args.gen_lm, iset=args.iset)
    config = get_config(lm)

    # NOTE(review): os.listdir order is arbitrary, so fl[ii] is not
    # deterministic across runs/filesystems — consider sorted(...).
    fl = os.listdir(args.exprs_pre)
    if args.data_sampling > 0:
        # Seed the generator with the sampled expression list from disk.
        exprs = load_exprs(args.expr_list[0], lm)
        generator = Constructor(args, config)
        generator.initialize_prop()
        generator.expressions_list = exprs
        for e in exprs:
            generator.expressions[e.id] = e
        generator.num_initial_expr = len(generator.expressions)
        generator.initialize_searcher()
    else:
        generator = Constructor(args, config)
        generator.initialize()
    print ('--loading pre exprs--')
    exprs_pre = load_exprs(os.path.join(args.exprs_pre, fl[ii]), lm)
    print ('--done--')
    generator.expressions_list += exprs_pre
    for e in exprs_pre:
        generator.expressions[e.id] = e
    _logger.info('load %d exprs' % (len(generator.expressions)))
    i = 0
    while True:
        # Sample a batch of expressions usable for the current task:
        # 'gen' needs unconstrained variables, 'pred' needs essential
        # hypotheses.
        _exprs = []
        while len(_exprs) < batch_size:
            expr = random.choice(exprs_pre)
            if args.task == 'gen' and len(expr.unconstrained) == 0:
                continue
            if args.task == 'pred' and len(expr.prop.e) == 0:
                continue
            _exprs.append(expr)
        if args.task == 'pred':
            data = generator.encode_pred_tasks(_exprs)
        else:
            data = generator.encode_gen_tasks(_exprs)
        queue.put(data)
        i += 1
        # Stop after roughly 1/5 of an epoch unless configured to feed
        # the queue indefinitely.
        if i >= len(exprs_pre)//batch_size//5 and (not args.train_from_queue) and (not args.cons_pre_one):
            break
    print ('finish processing current exprs')
Exemplo n.º 6
0
 def __init__(self, args, worker_id, split='train', task='pred'):
     """Dataset-style loader for proof steps, sharded across workers.

     When ``args.partial_lm`` is set, everything is built in memory from
     the language model; otherwise expressions and steps are loaded from
     pre-encoded files under ``args.data_path``.

     Args:
         args: parsed command-line namespace.
         worker_id: this worker's index; used to shard the file list.
         split: one of 'train' / 'valid' / 'test'.
         task: 'pred' or 'gen'; controls which proof steps are usable.
     """
     self.args = args
     self.split = split
     self.task = task
     #self.generator = generator
     # Only digit-named files in the split directory are data shards.
     self.file_list = [
         s for s in os.listdir(os.path.join(args.data_path, split))
         if s.isdigit()
     ]
     # Shard the file list evenly across workers (ceiling division).
     per_worker = len(self.file_list) // args.num_workers + 1
     self.file_list = self.file_list[worker_id * per_worker:min(
         (worker_id + 1) * per_worker, len(self.file_list))]
     self.length = 0
     self.steps = None
     self.step_hyps = None
     if self.args.partial_lm:
         # --- Partial-LM path: build steps in memory and return early ---
         self.lm = build_language_model(self.args.num_props,
                                        new=self.args.gen_lm,
                                        _all=False)
         self.cnt = 0
         self.cur_file = 0
         self.cur_epoch = 0
         self.config = get_config(self.lm)
         #self.exprs = load_exprs(self.lm)
         generator = Constructor(args, self.config)
         generator.initialize()
         self.exprs = generator.expressions_list
         self.step_expr_pos = generator.step_expr_pos
         self.prop_hyps_pos = generator.prop_hyps_pos
         self.all_f = generator.all_f
         if self.split == 'train':
             self.propositions = self.lm.training_propositions
             #self.steps = self.lm.training_proof_steps
         elif self.split == 'valid':
             self.propositions = self.lm.validation_propositions
             #self.steps = self.lm.validation_proof_steps
         else:
             self.propositions = self.lm.test_propositions
             #self.steps = self.lm.test_proof_steps
         self.load_forward_info()
         self.steps = []
         steps = []
         # data_sampling keeps i % 10 < num of every 10 training props;
         # when data_sampling <= 0 the short-circuit below never reads num.
         if self.args.data_sampling > 0:
             num = int(self.args.data_sampling * 10)
         for i, p in enumerate(self.propositions):
             if self.split != 'train' or (self.args.data_sampling > 0
                                          and i % 10 < num):
                 # Skip floating/essential-hypothesis steps.
                 steps += [
                     step for step in p.entails_proof_steps
                     if not (step.prop.type == 'f' or step.prop.type == 'e')
                 ]
                 for j, step in enumerate(p.entails_proof_steps):
                     step.pos_in_context = j
         # Keep only steps usable for the task: 'pred' needs essential
         # hypotheses, 'gen' needs goal variables (prop_goal_var is
         # presumably populated by load_forward_info — verify).
         for step in steps:
             if self.task == 'pred':
                 if len(self.lm.database.propositions[
                         step.prop_label].e) > 0:
                     self.steps.append(step)
             else:
                 if len(self.prop_goal_var[step.prop_label]) > 0:
                     self.steps.append(step)
         # 'pred' indexes steps directly; 'gen' indexes (step, goal-var) pairs.
         if self.task == 'pred':
             self.idxes = list(range(len(self.steps)))
         else:
             self.idxes = [
                 (i, j) for i in range(len(self.steps)) for j in range(
                     len(self.prop_goal_var[self.steps[i].prop_label]))
             ]
         if self.split == 'train':
             random.shuffle(self.idxes)
         return
     else:
         # --- Full-LM path: load expressions from disk ---
         self.lm = load_language_model(new=args.gen_lm, iset=args.iset)
     self.exprs = load_exprs(
         os.path.join(args.data_path,
                      'expressions_list_%.1f' % (self.args.data_sampling)),
         self.lm)
     self.config = get_config(self.lm)
     # Collect every floating hypothesis across all propositions,
     # keeping the first occurrence of each label.
     self.all_f = {}
     for l, p in self.lm.database.propositions.items():
         for f in p.f:
             if f not in self.all_f:
                 self.all_f[f] = p.f[f]
     self.load_forward_info()
     if self.split == 'train':
         self.propositions = self.lm.training_propositions
     elif self.split == 'valid':
         self.propositions = self.lm.validation_propositions
     else:
         self.propositions = self.lm.test_propositions
     # Same data_sampling filter as the partial-LM path, by label.
     self.props_filter = []
     num = int(self.args.data_sampling * 10)
     for i, p in enumerate(self.propositions):
         if self.split != 'train' or (self.args.data_sampling > 0
                                      and i % 10 < num):
             self.props_filter.append(p.label)
     print('%d props after filtering' % (len(self.props_filter)))
     #self.length = len(self.file_list) * 50000
     if self.args.short_file_list:
         self.file_list = self.file_list[:1]
     self.load_steps_all()
     self.length = len(self.idxes)
     self.cur_file = 0
     self.cur_epoch = 0
     self.cnt = 0