コード例 #1
0
    def bootstrap(
        config: Config,
        tokenizer: ProofTraceTokenizer,
        model: Model,
        ground: ProofTraceActions,
        target: Thm,
    ):
        """Create the root Node of a search tree for `target`.

        A fresh trace is seeded with the actions of `ground` whose value is
        in INV_PREPARE_TOKENS and prepared in a new REPL. The model is then
        queried once on that trace and its first action/left/right
        predictions, moved to the CPU, are handed to the returned Node.
        """
        trace_name = 'TREE-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        )
        prepare_actions = []
        for action in ground.actions():
            if action.value in INV_PREPARE_TOKENS:
                prepare_actions.append(action)
        ptra = ProofTraceActions(trace_name, prepare_actions)

        repl = REPL(tokenizer)
        repl.prepare(ptra)

        sequence_length = config.get('prooftrace_sequence_length')
        pre_trc, pre_idx = Node.prepare(ptra, None, sequence_length)

        # Single-element batch; the fourth output (values) is unpacked but
        # not forwarded to the Node.
        prd_actions, prd_lefts, prd_rights, prd_values = \
            model.infer([pre_trc], [pre_idx])

        cpu = torch.device('cpu')
        return Node(
            config,
            None,
            model,
            ground,
            target,
            ptra,
            repl,
            prd_actions[0].to(cpu),
            prd_lefts[0].to(cpu),
            prd_rights[0].to(cpu),
        )
コード例 #2
0
def search():
    """Run the configured search (beam or MCTS) on every test trace.

    Command line: the positional `config_path` plus the optional
    `--dataset_size`, `--load_dir` and `--device` config overrides.
    Trace files are visited in increasing order of the first integer
    captured from their file name. For each one, the ground actions that
    follow the prepare section are replayed except for the last 4, and
    the search is then stepped until it reports completion or the depth
    budget is spent.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        'config_path', type=str, help="path to the config file",
    )
    for flag in ('--dataset_size', '--load_dir', '--device'):
        parser.add_argument(flag, type=str, help="config override")
    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.dataset_size is not None:
        config.override('prooftrace_dataset_size', args.dataset_size)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir', os.path.expanduser(args.load_dir),
        )

    dataset_root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    dataset_dir = os.path.join(dataset_root, 'test_traces')
    assert os.path.isdir(dataset_dir)

    files = []
    for entry in os.listdir(dataset_dir):
        candidate = os.path.join(dataset_dir, entry)
        if os.path.isfile(candidate):
            files.append(candidate)

    with gzip.open(os.path.join(dataset_root, 'traces.tokenizer'), 'rb') as f:
        tokenizer = pickle.load(f)

    # Keep only the files whose name carries the two integer fields; the
    # first one is used as the sort key below.
    cases = []
    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is not None:
            cases.append((p, int(match.group(1))))

    Log.out("Loaded ProofTraceActions", {'cases': len(cases)})

    model = SearchModel(config).load()

    cases.sort(key=lambda c: c[1])

    for trace_path, _ in cases:
        with gzip.open(trace_path, 'rb') as f:
            ground = pickle.load(f)

        run_name = 'BEAM-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        )
        prepare_positions = [
            k for k in range(ground.len())
            if ground.actions()[k].value in INV_PREPARE_TOKENS
        ]
        ptra = ProofTraceActions(
            run_name,
            [ground.actions()[k] for k in prepare_positions],
            [ground.arguments()[k] for k in prepare_positions],
        )
        repl = REPL(tokenizer)
        target = repl.prepare(ptra)

        offset = 0
        fixed_gamma = 4
        if fixed_gamma > 0:
            gamma_len = max(ground.action_len() - fixed_gamma, 0)
            offset = ground.prepare_len() + gamma_len

            # Replay the ground actions located between the prepare
            # section and `offset`.
            for pos in range(ground.prepare_len(), offset):
                assert pos < ground.len() - 1

                action = ground.actions()[pos]
                argument = ground.arguments()[pos]

                thm = repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                ptra.append(action, argument)

        Log.out("TARGET", {
            'name': ground.name(),
            'prepare_length': ground.prepare_len(),
            'length': ground.action_len(),
            'summary': ground.summary(offset),
        })

        search_type = config.get('prooftrace_search_type')

        search = None
        if search_type == 'beam':
            search = Beam(config, model, ptra, repl, target)
        elif search_type == 'mcts':
            search = MCTS(config, model, ptra, repl, target)
        assert search is not None

        depth = config.get('prooftrace_search_depth')
        if search_type == 'beam':
            depth = fixed_gamma * 2

        for _ in range(depth):
            done, ptra, proved = search.step(False, offset)
            if done:
                break
コード例 #3
0
def search():
    """Run a tree search on every test trace of length at most 64.

    Command line: the positional `config_path` plus the optional
    `--dataset_size`, `--load_dir` and `--device` config overrides.
    For each retained trace the ground proof is replayed to obtain the
    target, a tree is bootstrapped with `Node.bootstrap`, and the tree is
    expanded until `expand()` returns a truthy theorem or the tree's
    minimum value is `_MAX_VALUE`.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        'config_path', type=str, help="path to the config file",
    )
    for flag in ('--dataset_size', '--load_dir', '--device'):
        parser.add_argument(flag, type=str, help="config override")
    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.dataset_size is not None:
        config.override('prooftrace_dataset_size', args.dataset_size)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir', os.path.expanduser(args.load_dir),
        )

    device = config.get('device')
    if device != 'cpu':
        torch.cuda.set_device(torch.device(device))

    dataset_root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    dataset_dir = os.path.join(dataset_root, 'test_traces')
    assert os.path.isdir(dataset_dir)

    files = []
    for entry in os.listdir(dataset_dir):
        candidate = os.path.join(dataset_dir, entry)
        if os.path.isfile(candidate):
            files.append(candidate)

    with gzip.open(os.path.join(dataset_root, 'traces.tokenizer'), 'rb') as f:
        tokenizer = pickle.load(f)

    # Only traces whose first captured integer is at most `max_length`
    # are searched.
    max_length = 64
    cases = []
    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is None:
            continue
        ptra_len = int(match.group(1))
        if ptra_len <= max_length:
            cases.append((p, ptra_len))

    Log.out("Loaded ProofTraceActions", {
        'max_length': max_length,
        'cases': len(cases),
    })

    model = Model(config).load()

    cases.sort(key=lambda c: c[1])

    for trace_path, _ in cases:
        with gzip.open(trace_path, 'rb') as f:
            ground = pickle.load(f)

        repl = REPL(tokenizer)
        repl.prepare(ground)
        target = repl.replay(ground)

        Log.out("TARGET", {
            'name': ground.name(),
            'prepare_length': ground.prepare_len(),
            'length': ground.action_len(),
            'summary': ground.summary(),
        })

        tree = Node.bootstrap(config, tokenizer, model, ground, target)

        while True:
            # NOTE(review): identity comparison -- this only fires if
            # min_value() returns the very same object as _MAX_VALUE;
            # confirm that is intended.
            if tree.min_value() is _MAX_VALUE:
                Log.out("FAILED", {'name': ground.name()})
                break

            thm = tree.expand()
            if thm:
                Log.out("DEMONSTRATED", {
                    'name': ground.name(),
                    'theorem': thm.thm_string(),
                })
                break
コード例 #4
0
    def run_once(self, ):
        """Run one search rollout from a stored rollout file and merge it back.

        A random sub-directory of `self._rollout_dir` is picked and the
        `.rollout` file that sorts last by name is loaded. All ground
        actions after the prepare section except the last `gamma` (drawn
        from GAMMAS, capped at the ground action length) are replayed, then
        the configured search ('beam' or 'mcts') is stepped. When the
        resulting trace is longer than the replayed prefix, statistics are
        published through `self._wrk`, the outcome is merged into the
        loaded rollout, the result is written to a new `.rollout` file and
        the older `.rollout` files that were listed are removed.

        Returns None. Returns early when there is no rollout directory, no
        rollout file, or when the search ran no step at all.
        """
        info = self._wrk.fetch(self._device, False)
        if info is not None:
            self.update(info['config'])

        for m in self._model.modules():
            self._model.modules()[m].eval()

        assert os.path.isdir(self._rollout_dir)

        rdirs = [
            os.path.join(self._rollout_dir, d)
            for d in os.listdir(self._rollout_dir)
            if os.path.isdir(os.path.join(self._rollout_dir, d))
        ]
        if not rdirs:
            # random.choice raises IndexError on an empty sequence.
            return

        rdir = random.choice(rdirs)
        # The dot is escaped so that only a literal ".rollout" suffix matches.
        rfiles = sorted(
            [
                os.path.join(rdir, f)
                for f in os.listdir(rdir) if re.search(r"\.rollout$", f)
            ],
            reverse=True,
        )
        if not rfiles:
            return

        path = rfiles[0]
        # NOTE: pickle.load can execute arbitrary code; rollout files must
        # come from a trusted location.
        with gzip.open(path, 'rb') as f:
            base = pickle.load(f)

        gamma = random.choice(GAMMAS)

        ground = base.positive()
        name = base.name()

        ptra = ProofTraceActions(
            'BEAM-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(self._tokenizer)
        target = repl.prepare(ptra)

        gamma = min(ground.action_len(), gamma)
        gamma_len = ground.action_len() - gamma
        offset = ground.prepare_len() + gamma_len

        # Replay the ground actions located before `offset`.
        for i in range(gamma_len):
            assert ground.prepare_len() + i < ground.len() - 1
            pos = ground.prepare_len() + i

            action = ground.actions()[pos]
            argument = ground.arguments()[pos]

            thm = repl.apply(action)

            action._index = thm.index()
            argument._index = thm.index()

            ptra.append(action, argument)

        search_type = self._config.get('prooftrace_search_type')

        search = None
        if search_type == 'beam':
            search = Beam(self._config, self._model, ptra, repl, target)
        if search_type == 'mcts':
            search = MCTS(self._config, self._model, ptra, repl, target)
        assert search is not None

        Log.out(
            "ROLLOUT START", {
                'name': name,
                'gamma': gamma,
                'prepare_length': ground.prepare_len(),
                'action_length': ground.action_len(),
                'length': ground.len(),
            })

        Log.out("TARGET", {
            'name': name,
            'summary': ground.summary(offset),
        })

        rollout = None
        proven = False
        ptra = None

        depth = self._config.get('prooftrace_search_depth')
        if search_type == 'beam':
            depth = gamma * 2

        for i in range(depth):
            step_start = time.time()
            done, ptra, proven = search.step(i == (depth - 1), offset)
            step_end = time.time()
            Log.out(
                'STEP', {
                    'i': i,
                    'done': done,
                    'proven': proven,
                    'gamma': gamma,
                    'time': "{:.2f}".format(step_end - step_start),
                })
            if done:
                if proven:
                    rollout = Rollout(name, [ptra], [])
                else:
                    rollout = Rollout(name, [], [ptra])
                break
            if (step_end - step_start) > \
                    self._config.get('prooftrace_search_step_timeout'):
                rollout = Rollout(name, [], [ptra])
                break

        if ptra is None:
            # depth was 0 (e.g. gamma == 0 with beam search): no step ran,
            # so there is no trace to measure or record.
            return

        if rollout is None:
            # The step budget ran out without `done` or a timeout; classify
            # the last trace the same way the `done` branch does so that
            # merge() never receives None.
            if proven:
                rollout = Rollout(name, [ptra], [])
            else:
                rollout = Rollout(name, [], [ptra])

        demo_length = (ptra.len() - (ground.prepare_len() + gamma_len))

        Log.out(
            "ROLLOUT END", {
                'name': name,
                'proven': proven,
                'gamma': gamma,
                'demo_length': demo_length,
            })

        Log.out("PTRA", {
            'name': name,
            'summary': ptra.summary(offset),
        })

        if demo_length > 0:
            stats = {
                'rll_cnt': 1,
                'pos_cnt': 1 if proven else 0,
                'neg_cnt': 0 if proven else 1,
            }
            if proven:
                stats['demo_len'] = demo_length

            # Publish the statistics.
            self._wrk.publish(stats)

            # Finally merge and store the new rollout
            base.merge(rollout)

            now = datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f")
            # Integer bound: randint rejects floats such as 10e9 on
            # Python >= 3.12.
            rnd = random.randint(0, 10 ** 10)

            tmp_path = os.path.join(rdir, "{}_{}.tmp".format(now, rnd))
            fnl_path = os.path.join(rdir, "{}_{}.rollout".format(now, rnd))

            # Write under a temporary name, then rename, so the `.rollout`
            # name only ever refers to a fully written file.
            with gzip.open(tmp_path, 'wb') as f:
                pickle.dump(base, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.rename(tmp_path, fnl_path)

            del base
            del rollout

            # Remove the older rollout files that were listed; a missing
            # file is ignored.
            for p in rfiles[1:]:
                try:
                    os.remove(p)
                except FileNotFoundError:
                    pass

            Log.out("MERGE WRITE", {
                'name': name,
                'path': fnl_path,
            })
コード例 #5
0
ファイル: lm_rollout.py プロジェクト: spolu/z3ta
    def run_once(self, ):
        """Run one search rollout from a stored rollout file and merge it back.

        A random sub-directory of `self._rollout_dir` is picked and the
        `.rollout` file that sorts last by name is loaded. A fresh trace is
        seeded with the prepare actions of the stored positive trace and the
        configured search ('beam' or 'policy_sample') is stepped for at most
        `depth` steps. The outcome is then published through `self._wrk`,
        merged into the loaded rollout and written to a new `.rollout` file;
        the older `.rollout` files that were listed are removed.

        Returns None. Returns early when there is no rollout directory, no
        rollout file, or when the search ran no step at all.
        """
        info = self._wrk.fetch(self._device, False)
        if info is not None:
            self.update(info['config'])

        for m in self._model.modules():
            self._model.modules()[m].eval()

        assert os.path.isdir(self._rollout_dir)

        rdirs = [
            os.path.join(self._rollout_dir, d)
            for d in os.listdir(self._rollout_dir)
            if os.path.isdir(os.path.join(self._rollout_dir, d))
        ]
        if not rdirs:
            # random.choice raises IndexError on an empty sequence.
            return

        rdir = random.choice(rdirs)
        # The dot is escaped so that only a literal ".rollout" suffix matches.
        rfiles = sorted(
            [
                os.path.join(rdir, f)
                for f in os.listdir(rdir) if re.search(r"\.rollout$", f)
            ],
            reverse=True,
        )
        if not rfiles:
            return

        path = rfiles[0]
        # NOTE: pickle.load can execute arbitrary code; rollout files must
        # come from a trusted location.
        with gzip.open(path, 'rb') as f:
            base = pickle.load(f)

        ground = base.positive()
        name = base.name()

        ptra = ProofTraceActions(
            'ROLLOUT-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(self._tokenizer)
        target = repl.prepare(ptra)

        search_type = self._config.get('prooftrace_search_type')

        search = None
        if search_type == 'beam':
            search = Beam(
                self._config,
                self._model,
                ptra,
                repl,
                target,
            )
        if search_type == 'policy_sample':
            search = PolicySample(
                self._config,
                self._model,
                ptra,
                repl,
                target,
            )
        assert search is not None

        # Step budget: what is left of the sequence length after the prepare
        # section, capped at twice the ground action length.
        depth = self._config.get('prooftrace_sequence_length') - \
            ground.prepare_len()

        if 2 * ground.action_len() < depth:
            depth = 2 * ground.action_len()

        Log.out(
            "ROLLOUT START", {
                'name': name,
                'prepare_length': ground.prepare_len(),
                'action_length': ground.action_len(),
                'depth': depth,
            })

        rollout = None
        proved = False
        ptra = None

        for i in range(depth):
            step_start = time.time()
            done, ptra, proved = search.step()
            step_end = time.time()
            Log.out(
                'STEP', {
                    'i': i,
                    'done': done,
                    'proved': proved,
                    'time': "{:.2f}".format(step_end - step_start),
                })
            if done:
                break
            # Hard-coded limit of 20 on a single step's duration; the config
            # key 'prooftrace_search_step_timeout' is not consulted here.
            if (step_end - step_start) > 20:
                break

        if ptra is None:
            # depth <= 0: no step ran, so there is no trace to record and
            # ptra.action_len() below would fail on None.
            return

        if proved:
            rollout = Rollout(name, [ptra], [])
        else:
            rollout = Rollout(name, [], [ptra])

        demo_length = ptra.action_len()
        demo_delta = ptra.action_len() - ground.action_len()

        Log.out(
            "ROLLOUT END", {
                'name': name,
                'proved': proved,
                'demo_length': demo_length,
                'demo_delta': demo_delta
            })

        if proved:
            Log.out("PTRA", {
                'name': name,
                'summary': ptra.summary(),
            })

        if demo_length > 0:
            stats = {
                'rll_cnt': 1,
                'pos_cnt': 1 if proved else 0,
                'neg_cnt': 0 if proved else 1,
            }
            if proved:
                stats['demo_len'] = demo_length
                stats['demo_dlt'] = demo_delta

            # Publish the statistics.
            self._wrk.publish(stats)

            # Finally merge and store the new rollout
            base.merge(rollout)

            now = datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f")
            # Integer bound: randint rejects floats such as 10e9 on
            # Python >= 3.12.
            rnd = random.randint(0, 10 ** 10)

            tmp_path = os.path.join(rdir, "{}_{}.tmp".format(now, rnd))
            fnl_path = os.path.join(rdir, "{}_{}.rollout".format(now, rnd))

            # Write under a temporary name, then rename, so the `.rollout`
            # name only ever refers to a fully written file.
            with gzip.open(tmp_path, 'wb') as f:
                pickle.dump(base, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.rename(tmp_path, fnl_path)

            del base
            del rollout

            # Remove the older rollout files that were listed; a missing
            # file is ignored.
            for p in rfiles[1:]:
                try:
                    os.remove(p)
                except FileNotFoundError:
                    pass

            Log.out("MERGE WRITE", {
                'name': name,
                'path': fnl_path,
            })
コード例 #6
0
ファイル: run.py プロジェクト: spolu/z3ta
def search():
    """Run the configured search (beam or policy sample) on every trace.

    Command line: the positional `config_path`, the optional
    `--dataset_size`, `--load_dir` and `--device` config overrides, and
    `--train` to search the training traces instead of the test traces.
    Trace files are visited in increasing order of the first integer
    captured from their file name. When the config value
    'prooftrace_search_fixed_gamma' is positive, all ground actions after
    the prepare section except that many trailing ones are replayed
    before the search is stepped.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        'config_path', type=str, help="path to the config file",
    )
    for flag in ('--dataset_size', '--load_dir', '--device'):
        parser.add_argument(flag, type=str, help="config override")
    parser.add_argument(
        '--train', type=str2bool, help="search training set",
    )
    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)
    if args.dataset_size is not None:
        config.override('prooftrace_dataset_size', args.dataset_size)
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir', os.path.expanduser(args.load_dir),
        )

    train = args.train if args.train is not None else False
    split = 'train_traces' if train else 'test_traces'

    dataset_root = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    dataset_dir = os.path.join(dataset_root, split)
    assert os.path.isdir(dataset_dir)

    files = []
    for entry in os.listdir(dataset_dir):
        candidate = os.path.join(dataset_dir, entry)
        if os.path.isfile(candidate):
            files.append(candidate)

    with gzip.open(os.path.join(dataset_root, 'traces.tokenizer'), 'rb') as f:
        tokenizer = pickle.load(f)

    # Keep only the files whose name carries the two integer fields; the
    # first one is used as the sort key below.
    cases = []
    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is not None:
            cases.append((p, int(match.group(1))))

    Log.out("Loaded ProofTraceActions", {
        'cases': len(cases),
    })

    l_model = LModel(config).load()

    cases.sort(key=lambda c: c[1])

    for trace_path, _ in cases:
        with gzip.open(trace_path, 'rb') as f:
            ground = pickle.load(f)

        run_name = 'SEARCH-{}-{}'.format(
            datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
            random.randint(0, 9999),
        )
        prepare_positions = [
            k for k in range(ground.len())
            if ground.actions()[k].value in INV_PREPARE_TOKENS
        ]
        ptra = ProofTraceActions(
            run_name,
            [ground.actions()[k] for k in prepare_positions],
            [ground.arguments()[k] for k in prepare_positions],
        )
        repl = REPL(tokenizer)
        target = repl.prepare(ptra)

        offset = 0
        fixed_gamma = config.get('prooftrace_search_fixed_gamma')
        if fixed_gamma > 0:
            gamma_len = max(ground.action_len() - fixed_gamma, 0)
            offset = ground.prepare_len() + gamma_len

            # Replay the ground actions located between the prepare
            # section and `offset`.
            for pos in range(ground.prepare_len(), offset):
                assert pos < ground.len() - 1

                action = ground.actions()[pos]
                argument = ground.arguments()[pos]

                thm = repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                ptra.append(action, argument)

        Log.out("TARGET", {
            'name': ground.name(),
            'prepare_length': ground.prepare_len(),
            'action_length': ground.action_len(),
            'summary': ground.summary(offset),
            'theorem': target.thm_string(False, True),
        })

        search_type = config.get('prooftrace_search_type')

        # Only 'beam' and 'policy_sample' are wired up here; any other
        # value trips the assertion below.
        search = None
        if search_type == 'beam':
            search = Beam(config, l_model, ptra, repl, target)
        elif search_type == 'policy_sample':
            search = PolicySample(config, l_model, ptra, repl, target)
        assert search is not None

        # Step budget: what is left of the sequence length after the
        # prepare section, capped at twice the gamma (fixed-gamma mode) or
        # twice the ground action length.
        depth = config.get('prooftrace_sequence_length') - \
            ground.prepare_len()
        if fixed_gamma != 0:
            cap = fixed_gamma * 2
        else:
            cap = 2 * ground.action_len()
        if cap < depth:
            depth = cap

        for step in range(depth):
            # NOTE(review): in fixed-gamma mode `depth` never exceeds
            # 2 * fixed_gamma, so this comparison is never true inside
            # the loop -- confirm the intended threshold.
            if fixed_gamma != 0:
                conclusion = (step >= fixed_gamma * 2)
            else:
                conclusion = (step >= ground.action_len())

            step_start = time.time()
            done, ptra, proved = search.step(offset, conclusion)
            step_end = time.time()

            Log.out('STEP', {
                'i': step,
                'done': done,
                'proved': proved,
                'time': "{:.2f}".format(step_end - step_start),
                'summary': ptra.summary(offset),
            })

            if not done:
                continue
            if proved:
                Log.out("DEMONSTRATED", {
                    'theorem': target.thm_string(False, True),
                })
            break

        Log.out("FINISH", {
            'summary': ptra.summary(offset),
        })

        # NOTE(review): no search object is built above for the 'random'
        # type, so with assertions enabled this branch cannot be reached.
        if search_type == 'random' and search.last_thm() is not None:
            Log.out("GENERATED",
                    {'theorem': search.last_thm().thm_string(False, True)})
コード例 #7
0
class Env:
    """Proof-trace environment advanced one agent action at a time.

    `reset` loads a pickled ground-truth trace, seeds a fresh run with the
    ground trace's PREPARE actions and optionally replays a prefix of its
    remaining actions ("gamma" initialization). `step` applies an agent action
    through the REPL, computes rewards and restarts the episode when it ends.
    """

    # Trace files are expected to be named `..._<length>_<n>.actions`; the
    # first number is compared against the configured sequence length.
    _TRACE_NAME_RE = re.compile(r"_(\d+)_(\d+)\.actions$")

    def __init__(
        self,
        config: Config,
        test: bool,
    ) -> None:
        """Index the trace files and load the tokenizer.

        Args:
            config: source of `prooftrace_sequence_length`, `device`,
                `prooftrace_dataset_dir` and `prooftrace_dataset_size`.
            test: read traces from `test_traces` when true, `train_traces`
                otherwise.
        """
        self._sequence_length = config.get('prooftrace_sequence_length')

        self._device = torch.device(config.get('device'))

        dataset_root = os.path.join(
            os.path.expanduser(config.get('prooftrace_dataset_dir')),
            config.get('prooftrace_dataset_size'),
        )
        dataset_dir = os.path.join(
            dataset_root, 'test_traces' if test else 'train_traces',
        )
        assert os.path.isdir(dataset_dir)

        self._trace_files = [
            os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir)
            if (os.path.isfile(os.path.join(dataset_dir, f))
                and re.search("\\.actions$", f) is not None)
        ]

        # Only traces whose name carries a length that fits the sequence
        # length can be selected by `reset`. Filtering once here avoids both
        # a crash on unexpected file names and an endless retry loop when
        # nothing fits.
        self._eligible_files = []
        for path in self._trace_files:
            trace_len = self._trace_length(path)
            if trace_len is not None and trace_len <= self._sequence_length:
                self._eligible_files.append(path)

        # NOTE(review): pickle executes arbitrary code on load; the dataset
        # directory must only contain trusted files.
        with gzip.open(os.path.join(dataset_root, 'traces.tokenizer'),
                       'rb') as f:
            self._tokenizer = pickle.load(f)

        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        # Also reset by `reset`; initialised here so the attributes always
        # exist.
        self._gamma_len = 0
        self._match_count = 0

    @staticmethod
    def _trace_length(path: str) -> typing.Optional[int]:
        """Return the length encoded in a trace file name, or None."""
        match = Env._TRACE_NAME_RE.search(path)
        if match is None:
            return None
        return int(match.group(1))

    def _log_payload(self) -> typing.Dict[str, typing.Any]:
        """Build the fields logged whenever an episode ends."""
        return {
            'ground_length': self._ground.action_len(),
            'gamma_length': self._gamma_len,
            'run_length': self._run.action_len() - self._gamma_len,
            'name': self._ground.name(),
        }

    def reset(
        self,
        gamma: float,
        fixed_gamma: int,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action]]:
        """Start a new episode and return its first observation.

        Args:
            gamma: probability of replaying a prefix of the ground-truth
                actions before handing control to the agent.
            fixed_gamma: when positive, at most this many ground-truth
                actions are left un-replayed; otherwise the replayed prefix
                length is drawn uniformly.

        Raises:
            RuntimeError: if no trace file fits the sequence length.
        """
        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        self._gamma_len = 0

        self._match_count = 0

        if not self._eligible_files:
            raise RuntimeError(
                "no trace file fits prooftrace_sequence_length={}".format(
                    self._sequence_length,
                )
            )

        path = random.choice(self._eligible_files)
        # NOTE(review): pickle.load on dataset files; see note in __init__.
        with gzip.open(path, 'rb') as f:
            self._ground = pickle.load(f)

        ground_actions = self._ground.actions()
        ground_arguments = self._ground.arguments()
        prepare_positions = [
            i for i in range(self._ground.len())
            if ground_actions[i].value in INV_PREPARE_TOKENS
        ]

        self._run = ProofTraceActions(
            'REPL-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [ground_actions[i] for i in prepare_positions],
            [ground_arguments[i] for i in prepare_positions],
        )

        self._repl = REPL(self._tokenizer)
        self._target = self._repl.prepare(self._run)

        # GAMMA Initialization. A ground trace without any action has nothing
        # to replay (and would make randrange raise on an empty range).
        action_len = self._ground.action_len()
        if gamma > 0.0 and random.random() < gamma and action_len > 0:
            if fixed_gamma > 0:
                self._gamma_len = action_len - random.randrange(
                    1, min(fixed_gamma, action_len) + 1,
                )
            else:
                self._gamma_len = random.randrange(0, action_len)

            prepare_len = self._ground.prepare_len()
            for i in range(self._gamma_len):
                pos = prepare_len + i
                assert pos < self._ground.len() - 1
                action = ground_actions[pos]
                argument = ground_arguments[pos]

                thm = self._repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                self._run.append(action, argument)

        return self.observation()

    def observation(
        self,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action]]:
        """Return `(run length, actions, arguments)` padded to a fixed size.

        Both lists are copies padded with EMPTY actions up to the sequence
        length; an EXTRACT action is inserted after the run's actions when
        there is room for it.
        """
        actions = self._run.actions().copy()
        arguments = self._run.arguments().copy()

        # If the len match this is a final observation, so no extract will be
        # appended and that's fine because this observation won't make it to
        # the agent.
        if len(actions) < self._sequence_length:
            actions.append(Action.from_action('EXTRACT', None, None))

        # Finally we always return actions with the same length.
        empty = Action.from_action('EMPTY', None, None)
        actions.extend([empty] * (self._sequence_length - len(actions)))
        arguments.extend([empty] * (self._sequence_length - len(arguments)))

        return (self._run.len(), actions, arguments)

    def alpha_oracle(
        self,
    ) -> typing.Tuple[typing.Optional[torch.Tensor], int]:
        """Suggest the next ground-truth action applicable to the run.

        Returns a `[[action, left, right]]` int64 tensor on the configured
        device for the first ground action not yet in the run whose two
        operands are already in the run, or None when there is none. The
        second element (frame count) is always 0.
        """
        # `explore` only consults this oracle while `_alpha` is 0, so it is
        # used at most once per episode.
        self._alpha += 1
        for i in range(self._ground.prepare_len(), self._ground.len()):
            a = self._ground.actions()[i]
            if (not self._run.seen(a)) and \
                    self._run.seen(a.left) and \
                    self._run.seen(a.right):
                assert 0 <= a.value - len(PREPARE_TOKENS)
                assert a.value < len(PROOFTRACE_TOKENS)
                hashes = self._run.hashes()
                actions = torch.tensor(
                    [[
                        a.value - len(PREPARE_TOKENS),
                        hashes[a.left.hash()],
                        hashes[a.right.hash()],
                    ]],
                    dtype=torch.int64,
                ).to(self._device)
                return actions, 0

        # We may reach this point as final actions are sometime repeated at the
        # end of prooftraces.
        return None, 0

    def beta_oracle(
        self,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        beta_width: int,
        beta_size: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Pick the best candidates among the top predictions.

        The `beta_width` most likely actions, lefts and rights are combined;
        combinations the REPL reports as valid get their probability boosted
        by 1.0 so that they sort first. Returns the (at most) `beta_size` best
        `[action, left, right]` triples as an int64 tensor on the configured
        device, and the number of REPL validity checks performed.

        Assumes the three predictions are 1-D log-probability tensors (they
        are exponentiated and indexed per element) -- TODO confirm.
        """
        def top(prediction):
            # Clamp k so a width larger than the distribution does not raise.
            k = min(beta_width, prediction.size(-1))
            probs, indices = torch.exp(prediction).topk(k)
            # One transfer per tensor instead of one `.item()` per element.
            return probs.tolist(), indices.tolist()

        action_probs, action_ids = top(prd_actions)
        left_probs, left_ids = top(prd_lefts)
        right_probs, right_ids = top(prd_rights)

        out = []
        frame_count = 0

        for action, p_action in zip(action_ids, action_probs):
            assert action >= 0
            assert action < len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS)

            for left, p_left in zip(left_ids, left_probs):
                for right, p_right in zip(right_ids, right_probs):
                    prob = p_action * p_left * p_right
                    candidate = [action, left, right]

                    if left >= self._run.len() or right >= self._run.len():
                        out.append((candidate, prob))
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._run.arguments()[left],
                        self._run.arguments()[right],
                    )

                    if self._run.seen(a):
                        out.append((candidate, prob))
                        continue

                    frame_count += 1
                    if not self._repl.valid(a):
                        out.append((candidate, prob))
                        continue

                    # Valid and unseen: rank above every other candidate.
                    out.append((candidate, prob + 1.0))

        out = sorted(out, key=lambda o: o[1], reverse=True)

        # Slicing tolerates `beta_size` exceeding the number of candidates.
        actions = [candidate for candidate, _ in out[:beta_size]]

        return \
            torch.tensor(actions, dtype=torch.int64).to(self._device), \
            frame_count

    def explore(
        self,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        alpha: float,
        beta: float,
        beta_width: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Choose the next action: alpha oracle, beta oracle or sampling.

        Args:
            prd_actions, prd_lefts, prd_rights: predictions; they are
                exponentiated before use.
            alpha: probability of consulting the alpha oracle (only while it
                has not been consulted yet in this episode).
            beta: probability of consulting the beta oracle.
            beta_width: width passed to the beta oracle.

        Returns:
            `(actions, frame_count)` where `actions` holds one
            `[action, left, right]` row.
        """
        # ALPHA Oracle.
        if alpha > 0.0 and random.random() < alpha and self._alpha == 0:
            actions, frame_count = self.alpha_oracle()
            if actions is not None:
                return actions, frame_count

        # BETA Oracle.
        if beta > 0.0 and random.random() < beta:
            return self.beta_oracle(
                prd_actions,
                prd_lefts,
                prd_rights,
                beta_width,
                1,
            )

        # Sampling.
        samples = [
            Categorical(torch.exp(prd)).sample().unsqueeze(0).unsqueeze(1)
            for prd in (prd_actions, prd_lefts, prd_rights)
        ]

        return torch.cat(samples, dim=1), 0

    def step(
        self,
        action: typing.Tuple[int, int, int],
        step_reward_prob: float,
        match_reward_prob: float,
        gamma: float,
        fixed_gamma: int,
    ) -> typing.Tuple[typing.Tuple[int, typing.List[Action]], typing.Tuple[
            float, float, float], bool, typing.Dict[str, int], ]:
        """Apply one agent action and report the outcome.

        Args:
            action: `(action, left, right)` indices.
            step_reward_prob: probability of granting the step reward for a
                legal action.
            match_reward_prob: probability of granting the match reward when
                the action also appears in the ground trace (this replaces
                the step reward).
            gamma, fixed_gamma: forwarded to `reset` when the episode ends.

        Returns:
            `(observation, (step, match, final) rewards, done, info)`. When
            `done` is true the observation is the first one of the next
            episode. Illegal actions (operand out of range, already seen, or
            rejected by the REPL) end the episode with zero rewards.
        """
        assert self._ground is not None
        assert self._run is not None

        def finish(rewards, done, info):
            if done:
                observation = self.reset(gamma, fixed_gamma)
            else:
                observation = self.observation()
            return observation, rewards, done, info

        def illegal(message):
            # The info dict is built before `finish` resets the episode.
            Log.out(message, self._log_payload())
            return finish(
                (0.0, 0.0, 0.0), True, {
                    'match_count': self._match_count,
                    'run_length': self._run.action_len() - self._gamma_len,
                })

        if action[1] >= self._run.len() or action[2] >= self._run.len():
            return illegal("DONE ILLEGAL[overflow]")

        candidate = Action.from_action(
            INV_PROOFTRACE_TOKENS[action[0] + len(PREPARE_TOKENS)],
            self._run.arguments()[action[1]],
            self._run.arguments()[action[2]],
        )

        if self._run.seen(candidate):
            return illegal("DONE ILLEGAL[seen]")

        try:
            thm = self._repl.apply(candidate)
        except (FusionException, REPLException, TypeException):
            return illegal("DONE ILLEGAL[fusion]")

        candidate._index = thm.index()
        argument = self._run.build_argument(
            thm.concl(),
            thm.hyp(),
            thm.index(),
        )
        self._run.append(candidate, argument)

        step_reward = 0.0
        match_reward = 0.0
        final_reward = 0.0
        done = False
        info = {}

        if step_reward_prob > 0.0 and random.random() < step_reward_prob:
            step_reward = 1.0

        if self._ground.seen(candidate):
            self._match_count += 1
            if match_reward_prob > 0.0 and random.random() < match_reward_prob:
                match_reward = 1.0
                step_reward = 0.0

        if self._target.thm_string(True) == thm.thm_string(True):
            final_reward = 10.0
            done = True
            info['demo_length'] = min(
                self._run.action_len(),
                self._ground.action_len(),
            ) - self._gamma_len
            info['demo_delta'] = \
                self._run.action_len() - self._ground.action_len()
            Log.out("DEMONSTRATED", self._log_payload())

        if self._run.len() >= self._sequence_length:
            done = True
            Log.out("DONE LENGTH ", self._log_payload())

        if done:
            info['match_count'] = self._match_count
            info['run_length'] = self._run.action_len() - self._gamma_len

        return finish((step_reward, match_reward, final_reward), done, info)
コード例 #8
0
ファイル: search_test.py プロジェクト: spolu/z3ta
    def run_once(self, ):
        """Run one evaluation pass of beam search and publish the results.

        For every gamma in `GAMMAS`, a random sample of trace files is
        loaded; all but the last `gamma` ground-truth actions are replayed
        and beam search then gets up to `gamma` steps to finish the proof.
        Publishes, through `self._tst`, the proven rate per gamma
        (`gamma_<n>`) and the mean demonstration length (`demo_len`).
        """
        info = self._tst.fetch(self._device, False)
        if info is not None:
            self.update(info['config'])

        for module in ('E', 'T', 'PH', 'VH'):
            self._modules[module].eval()

        model = BeamModel(self._config, self._modules)

        assert os.path.isdir(self._dataset_dir)
        trace_name = re.compile("_(\\d+)_(\\d+)\\.actions$")
        traces = []
        for name in os.listdir(self._dataset_dir):
            path = os.path.join(self._dataset_dir, name)
            if os.path.isfile(path) and trace_name.search(path) is not None:
                traces.append(path)

        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than are available.
        sample_size = min(self._test_gamma_size, len(traces))
        total_cases = sample_size * len(GAMMAS)

        info = {
            'demo_len': 0.0,
        }
        cases = {}
        for gamma in GAMMAS:
            cases[gamma] = random.sample(traces, sample_size)
            info['gamma_{}'.format(gamma)] = 0.0

        for gamma in GAMMAS:
            for path in cases[gamma]:
                # NOTE(review): pickle executes arbitrary code on load; the
                # dataset directory must only contain trusted files.
                with gzip.open(path, 'rb') as f:
                    ground = pickle.load(f)

                ground_actions = ground.actions()
                ground_arguments = ground.arguments()
                prepare_positions = [
                    k for k in range(ground.len())
                    if ground_actions[k].value in INV_PREPARE_TOKENS
                ]

                ptra = ProofTraceActions(
                    'BEAM-{}-{}'.format(
                        datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                        random.randint(0, 9999),
                    ),
                    [ground_actions[k] for k in prepare_positions],
                    [ground_arguments[k] for k in prepare_positions],
                )
                repl = REPL(self._tokenizer)
                target = repl.prepare(ptra)

                # Replay everything but the last `gamma` ground actions.
                gamma_len = max(ground.action_len() - gamma, 0)
                offset = ground.prepare_len() + gamma_len

                for k in range(gamma_len):
                    pos = ground.prepare_len() + k
                    assert pos < ground.len() - 1

                    action = ground_actions[pos]
                    argument = ground_arguments[pos]

                    thm = repl.apply(action)

                    action._index = thm.index()
                    argument._index = thm.index()

                    ptra.append(action, argument)

                Log.out(
                    "TARGET",
                    {
                        'name': ground.name(),
                        'prepare_length': ground.prepare_len(),
                        'length': ground.action_len(),
                    })

                beam = Beam(self._config, model, ptra, repl, target)

                proven = False
                # Fall back to the prepared trace so that the length below is
                # defined even if no search step runs.
                result = ptra

                for step in range(gamma):
                    step_start = time.time()
                    done, result, proven = \
                        beam.step(step == (gamma - 1), offset)
                    step_end = time.time()
                    if done:
                        break
                    if (step_end - step_start) > \
                            self._config.get('prooftrace_search_step_timeout'):
                        break

                demo_length = result.len() - offset

                Log.out(
                    "DONE", {
                        'name': ground.name(),
                        'proven': proven,
                        'gamma': gamma,
                        'demo_length': demo_length,
                    })

                # Normalise by the number of cases actually run; this equals
                # `_test_gamma_size` whenever enough traces are available.
                if proven:
                    info['gamma_{}'.format(gamma)] += 1.0 / sample_size
                info['demo_len'] += demo_length / total_cases

        self._tst.publish(info)