Пример #1
0
Файл: beam.py Проект: spolu/z3ta
    def __init__(
        self,
        config: Config,
        l_model: LModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
    ) -> None:
        """Initialize the beam with a single head built from one model query.

        The proof trace is preprocessed, the language model ``l_model`` is
        run once on it with autograd disabled, and the first entries of the
        three prediction outputs are moved to the CPU to form the sole
        initial ``Head``, which is given probability 1.0.

        Args:
            config: Search configuration, forwarded to the base class.
            l_model: Language model whose ``infer`` produces the predictions.
            ptra: Starting proof trace actions; a copy is kept.
            repl: Starting REPL; a copy is kept.
            target: Theorem being searched for, forwarded to the base class.
        """
        super(Beam, self).__init__(config, ptra, repl, target)

        self._l_model = l_model

        idx, acts, args = self.preprocess_ptra(ptra)

        # Pure inference: no gradient bookkeeping is needed here.
        with torch.no_grad():
            actions_out, lefts_out, rights_out = self._l_model.infer(
                [idx], [acts], [args],
            )

        self._ptras = [ptra.copy()]
        self._repls = [repl.copy()]

        # NOTE(review): [0][0] presumably selects the first batch entry and
        # its first prediction -- confirm against LModel.infer's output.
        first_head = Head(
            actions_out[0][0].cpu(),
            lefts_out[0][0].cpu(),
            rights_out[0][0].cpu(),
            1.0,  # probability assigned to the only starting head
        )
        self._heads = [first_head]
Пример #2
0
    def __init__(
        self,
        config: Config,
        l_model: LModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
    ) -> None:
        """Prepare a policy-sampling search.

        Keeps a reference to ``l_model`` and stores copies (made with
        ``.copy()``) of the starting proof trace and REPL.

        Args:
            config: Search configuration, forwarded to the base class.
            l_model: Language model used by this search.
            ptra: Starting proof trace actions; a copy is stored.
            repl: Starting REPL; a copy is stored.
            target: Theorem being searched for, forwarded to the base class.
        """
        super(PolicySample, self).__init__(config, ptra, repl, target)

        self._l_model = l_model

        ptra_copy = ptra.copy()
        repl_copy = repl.copy()
        self._ptra, self._repl = ptra_copy, repl_copy
Пример #3
0
    def __init__(
        self,
        config: Config,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
    ) -> None:
        """Prepare a random search over copies of the starting state.

        Stores ``.copy()``-made copies of ``ptra`` and ``repl``, leaves
        ``_last_thm`` unset, and builds a ``RandomSampler`` over the stored
        proof-trace copy.

        Args:
            config: Search configuration, forwarded to the base class.
            ptra: Starting proof trace actions; a copy is stored.
            repl: Starting REPL; a copy is stored.
            target: Theorem being searched for, forwarded to the base class.
        """
        super(Random, self).__init__(config, ptra, repl, target)

        own_ptra = ptra.copy()
        self._ptra = own_ptra
        self._repl = repl.copy()
        # Starts unset; not assigned anywhere else in this method.
        self._last_thm = None

        # The sampler works on the same copy that is stored on self._ptra.
        self._sampler = RandomSampler(own_ptra)
Пример #4
0
    def __init__(
        self,
        config: Config,
        l_model: LModel,
        v_model: VModel,
        ptra: ProofTraceActions,
        repl: REPL,
        target: Thm,
    ) -> None:
        """Prepare a particle-filter search.

        Stores both models and reads the filter and sample sizes from
        ``config``. It then creates ``filter_size`` particles, each a dict
        with its own ``.copy()`` of ``ptra`` and ``repl``.

        Args:
            config: Search configuration; also supplies the particle filter
                size and sample size.
            l_model: Language model used by this search.
            v_model: Value model used by this search.
            ptra: Starting proof trace actions; copied once per particle.
            repl: Starting REPL; copied once per particle.
            target: Theorem being searched for, forwarded to the base class.
        """
        super(ParticleFilter, self).__init__(config, ptra, repl, target)

        self._l_model, self._v_model = l_model, v_model

        self._filter_size = config.get(
            'prooftrace_search_particle_filter_size'
        )
        self._sample_size = config.get(
            'prooftrace_search_particle_filter_sample_size'
        )

        # Each particle gets independent copies so they can evolve apart.
        particles = []
        for _ in range(self._filter_size):
            particles.append({
                'ptra': ptra.copy(),
                'repl': repl.copy(),
            })
        self._particles = particles