def get_init_progs(sampler, cur_samples, cmd_args):
    """Produce initial candidate programs for a batch.

    Args:
        sampler: model exposing ``forward_io`` (batch -> IO embedding) and
            ``forward_q0`` (embedding -> decoded programs).
        cur_samples: the batch handed to ``sampler.forward_io``. In search
            mode its first two entries are read as the per-sample input and
            output example lists (same layout as the batches unpacked in
            ``test_topk``; TODO confirm for all callers).
        cmd_args: config providing ``num_importance_samples``, ``inf_type``
            and, in search mode, ``numPublicIO``.

    Returns:
        ``(io_embed, init_progs)``. ``io_embed`` is the IO embedding repeated
        ``num_importance_samples`` times along dim 0. ``init_progs`` is the
        list of token sequences from ``forward_q0``. Alignment with the rows
        of ``io_embed`` presumably holds when ``forward_q0`` yields one entry
        per input row; verify.
    """
    raw_io_embed = sampler.forward_io(cur_samples)
    io_embed = raw_io_embed.repeat(cmd_args.num_importance_samples, 1)
    if cmd_args.inf_type == 'argmax':
        # Argmax decoding is deterministic, so decode the batch once and tile
        # the result the same way io_embed was tiled.
        _, init_progs, _, _ = sampler.forward_q0(raw_io_embed,
                                                 gen_method='argmax')
        init_progs = init_progs * cmd_args.num_importance_samples
    elif cmd_args.inf_type == 'sample':
        _, init_progs, _, _ = sampler.forward_q0(io_embed, gen_method='sample')
    else:
        # Other decoders return a flat candidate list plus run lengths
        # (`sizes`). For each run, keep the first candidate that reproduces
        # the public examples, falling back to the run's first candidate.
        _, raw_init_progs, sizes, _ = sampler.forward_q0(
            io_embed, gen_method=cmd_args.inf_type)
        # Previously these were read from undefined names; take them from
        # the batch instead.
        list_inputs, list_outputs = cur_samples[0], cur_samples[1]
        n_pub = cmd_args.numPublicIO
        public_inputs = [x[:n_pub] for x in list_inputs]
        public_outputs = [x[:n_pub] for x in list_outputs]
        batch_size = len(public_inputs)
        init_progs = []
        offset = 0
        for i, s in enumerate(sizes):
            # io_embed tiles the batch (Tensor.repeat), so run i belongs to
            # sample i % batch_size; identical to i when len(sizes) == batch.
            b = i % batch_size
            the_prog = raw_init_progs[offset]
            for j in range(s):
                prog = RFillNode.from_tokens(raw_init_progs[offset + j])
                assert prog is not None
                if test_passed(public_inputs[b], public_outputs[b], prog):
                    the_prog = raw_init_progs[offset + j]
                    break
            init_progs.append(the_prog)
            offset += s
    return io_embed, init_progs
def check_prog(pred_prog, gt_prog, list_inputs, list_outputs):
    """Score a predicted program against paired input/output examples.

    Args:
        pred_prog: token sequence, decoded with ``RFillNode.from_tokens``.
        gt_prog: ground-truth program. Currently unused because the
            exact-match comparison is disabled.
        list_inputs: example inputs, paired positionally with list_outputs.
        list_outputs: expected outputs.

    Returns:
        An int triple ``(parsed, passed, same)``:

        * ``(0, 0, 0)`` if pred_prog does not decode;
        * ``(1, 0, 0)`` if it decodes but some example mismatches;
        * ``(1, 1, 0)`` if every example matches.

        ``same`` is always 0. It was previously the bool ``False`` on the
        success path, which compares equal to 0.
    """
    prog = RFillNode.from_tokens(pred_prog)
    if prog is None:
        return 0, 0, 0
    # any() short-circuits, so evaluation stops at the first mismatching
    # example, matching the original early `break`.
    failed = any(y != evaluate_prog(prog, x)
                 for x, y in zip(list_inputs, list_outputs))
    if failed:
        return 1, 0, 0
    return 1, 1, 0
def _topk_hit(candidates, pub_in, pub_out, priv_in, priv_out, topk):
    """Return True if one of the first `topk` public-consistent candidates also passes the private examples."""
    tried = 0
    for tokens in candidates:
        prog = RFillNode.from_tokens(tokens)
        assert prog is not None
        # Only candidates that reproduce the public examples count toward top-k.
        if not test_passed(pub_in, pub_out, prog):
            continue
        tried += 1
        if test_passed(priv_in, priv_out, prog):
            return True
        if tried >= topk:
            break
    return False


def test_topk(test_db, eval_func, epoch_load=None):
    """Evaluate top-k generalization accuracy on `test_db`.

    The first ``cmd_args.numPublicIO`` examples of each sample are shown to
    ``eval_func``. Among the returned candidates, those consistent with these
    public examples are checked, in order, against the remaining private
    examples. At most ``cmd_args.eval_topk`` such candidates are tried.

    Args:
        test_db: dataset providing ``collate_fn`` and ``num_programs``.
        eval_func: callable ``(public_inputs, public_outputs)`` returning a
            4-tuple whose 2nd/3rd items are a flat candidate list and per-sample
            run lengths.
        epoch_load: if not None and >= 0, ``load_model(epoch_load)`` is called
            first.

    Returns:
        ``(accuracy, 1.0)``. Accuracy is the number of samples with a
        private-passing candidate divided by ``test_db.num_programs``
        (0.0 when the dataset is empty).
    """
    if epoch_load is not None and epoch_load >= 0:
        load_model(epoch_load)
    test_gen = DataLoader(test_db,
                          batch_size=cmd_args.batch_size,
                          shuffle=False,
                          collate_fn=test_db.collate_fn,
                          num_workers=cmd_args.num_proc,
                          drop_last=False)

    pbar = tqdm(test_gen)
    acc = 0.0
    num_done = 0
    for cur_samples in pbar:
        list_inputs, list_outputs, _, _ = cur_samples
        n_pub = cmd_args.numPublicIO
        public_inputs = [x[:n_pub] for x in list_inputs]
        private_inputs = [x[n_pub:] for x in list_inputs]
        public_outputs = [x[:n_pub] for x in list_outputs]
        private_outputs = [x[n_pub:] for x in list_outputs]

        _, list_progs, sizes, _ = eval_func(public_inputs, public_outputs)
        offset = 0
        for i, s in enumerate(sizes):
            # Lazy indexing (not slicing) so a short candidate list still
            # raises IndexError instead of being silently truncated.
            candidates = (list_progs[offset + j] for j in range(s))
            acc += _topk_hit(candidates, public_inputs[i], public_outputs[i],
                             private_inputs[i], private_outputs[i],
                             cmd_args.eval_topk)
            offset += s
        num_done += len(list_inputs)
        pbar.set_description('frac: %.2f, acc: %.2f' %
                             (num_done / test_db.num_programs, acc / num_done))
    total = test_db.num_programs
    return (acc / total if total else 0.0), 1.0