Ejemplo n.º 1
0
    def sample(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        tries: int,
        conclusion: bool = False,
    ) -> typing.Optional[Action]:
        """Randomly sample an unseen, REPL-valid action over `ptra`.

        Each try draws a rule token uniformly, then samples its operands as
        indices into `ptra.arguments()` with the matching sampler
        (`sample_term`, `sample_theorem`, `sample_subst`,
        `sample_subst_type`).

        Args:
            ptra: trace whose arguments supply the operands.
            repl: REPL used to reject invalid actions.
            tries: maximum number of sampling attempts.
            conclusion: draw from CONCLUSION_TOKENS instead of
                NON_PREPARE_TOKENS.

        Returns:
            The first sampled action that `ptra` has not seen and `repl`
            accepts, or None if `tries` attempts all fail.
        """
        tokens = CONCLUSION_TOKENS if conclusion else NON_PREPARE_TOKENS
        # Built once: the candidate set does not change between tries.
        candidates = list(tokens.values())

        for _ in range(tries):
            action = random.choice(candidates)
            token = INV_PROOFTRACE_TOKENS[action]

            # NOTE(review): 0 is used as the "no operand" argument index for
            # single-operand rules; confirm it designates the EMPTY argument.
            if token in ('REFL', 'BETA', 'ASSUME'):
                left = self.sample_term()
                right = 0
            elif token in ('TRANS', 'MK_COMB', 'EQ_MP', 'DEDUCT_ANTISYM_RULE'):
                left = self.sample_theorem(ptra)
                right = self.sample_theorem(ptra)
            elif token == 'ABS':
                left = self.sample_theorem(ptra)
                right = self.sample_term()
            elif token == 'INST':
                left = self.sample_theorem(ptra)
                right = self.sample_subst()
            elif token == 'INST_TYPE':
                left = self.sample_theorem(ptra)
                right = self.sample_subst_type()
            else:
                # No operand sampler for this token: without this guard
                # `left`/`right` would be unbound or stale from a prior try.
                continue

            arguments = ptra.arguments()
            a = Action.from_action(token, arguments[left], arguments[right])

            if ptra.seen(a) or not repl.valid(a):
                continue

            return a

        return None
Ejemplo n.º 2
0
    def prepare(
        ptra: ProofTraceActions,
        a: Action,
        sequence_length: int,
    ) -> typing.Tuple[typing.List[Action], int]:
        """Build a padded model input from `ptra`'s actions plus candidate `a`.

        Args:
            ptra: trace whose actions form the prefix (the list is copied,
                `ptra` itself is not mutated).
            a: candidate action appended after the prefix; skipped if None.
            sequence_length: length to pad to with EMPTY actions. Inputs
                that are already longer are not truncated.

        Returns:
            `(trc, idx)` where `trc` is the padded action list and `idx` is
            the position of the EXTRACT marker within `trc`.
        """
        trc = ptra.actions().copy()
        if a is not None:
            trc.append(a)

        # EXTRACT follows the prefix (and candidate); its position is handed
        # back so the caller can pass it to the model alongside `trc`.
        idx = len(trc)
        trc.append(Action.from_action('EXTRACT', None, None))

        empty = Action.from_action('EMPTY', None, None)
        trc.extend([empty] * (sequence_length - len(trc)))

        return trc, idx
Ejemplo n.º 3
0
    def beta_oracle(
        self,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        beta_width: int,
        beta_size: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Rank top-k (action, left, right) triples, favouring valid ones.

        Every combination of the top action / left / right predictions is
        scored by its joint probability. Combinations that are in range,
        unseen by `self._run` and accepted by `self._repl.valid` get a +1.0
        bonus, so they always outrank invalid ones.

        Args:
            prd_actions, prd_lefts, prd_rights: predictions that are
                exponentiated before `topk` (presumably log-probabilities;
                assumed 1-D — TODO confirm).
            beta_width: top-k taken per head (actions are additionally
                capped at the number of non-prepare tokens).
            beta_size: maximum number of triples returned.

        Returns:
            `(actions, frame_count)`: an int64 tensor of at most `beta_size`
            `[action, left, right]` rows on `self._device`, best first, and
            the number of candidates that reached the REPL validity check.
        """
        a_count = min(
            beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )
        top_actions = torch.exp(prd_actions).topk(a_count)
        top_lefts = torch.exp(prd_lefts).topk(beta_width)
        top_rights = torch.exp(prd_rights).topk(beta_width)

        # Plain Python lists: avoid a tensor `.item()` per inner iteration.
        act_probs, act_ids = top_actions[0].tolist(), top_actions[1].tolist()
        lft_probs, lft_ids = top_lefts[0].tolist(), top_lefts[1].tolist()
        rgt_probs, rgt_ids = top_rights[0].tolist(), top_rights[1].tolist()

        run_len = self._run.len()
        run_args = self._run.arguments()
        token_shift = len(PREPARE_TOKENS)

        out = []
        frame_count = 0

        for p_act, action in zip(act_probs, act_ids):
            assert 0 <= action < len(PROOFTRACE_TOKENS) - token_shift
            for p_lft, left in zip(lft_probs, lft_ids):
                for p_rgt, right in zip(rgt_probs, rgt_ids):
                    triple = [action, left, right]
                    prob = p_act * p_lft * p_rgt

                    if left >= run_len or right >= run_len:
                        out.append((triple, prob))
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + token_shift],
                        run_args[left],
                        run_args[right],
                    )

                    if self._run.seen(a):
                        out.append((triple, prob))
                        continue

                    frame_count += 1
                    if not self._repl.valid(a):
                        out.append((triple, prob))
                        continue

                    out.append((triple, prob + 1.0))

        out.sort(key=lambda o: o[1], reverse=True)

        # Slicing (instead of indexing range(beta_size)) tolerates
        # beta_size larger than the number of candidates.
        actions = [triple for triple, _ in out[:beta_size]]

        return \
            torch.tensor(actions, dtype=torch.int64).to(self._device), \
            frame_count
Ejemplo n.º 4
0
    def observation(
        self,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action], ]:
        """Snapshot the current run as fixed-length model inputs.

        Returns `(run_len, actions, arguments)`. Both lists are copies of
        the run's lists, padded with a single shared EMPTY action up to
        `self._sequence_length`. `actions` also gets an EXTRACT marker
        whenever there is room for one.
        """
        target_len = self._sequence_length
        actions = self._run.actions().copy()
        arguments = self._run.arguments().copy()

        # A run that already fills the sequence is final: that observation
        # never reaches the agent, so omitting EXTRACT there is harmless.
        if len(actions) < target_len:
            actions.append(Action.from_action('EXTRACT', None, None))

        # Pad both lists so every observation has the same length.
        filler = Action.from_action('EMPTY', None, None)
        actions.extend([filler] * (target_len - len(actions)))
        arguments.extend([filler] * (target_len - len(arguments)))

        return self._run.len(), actions, arguments
Ejemplo n.º 5
0
    def __getitem__(
        self,
        idx: int,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action], Action,
                      float, ]:
        """Sample one value-training example from the `idx`-th rollout dir.

        The lexicographically greatest `*.rollout` file in the directory is
        loaded. `rollout.random()` yields a trace and its outcome, and a
        position `index` is drawn uniformly from the non-prepare part of
        the trace.

        Returns:
            `(index, actions, arguments, truth, value)`:
            `actions`/`arguments` are the trace prefix before `index`
            (actions followed by EXTRACT), both padded with EMPTY to
            `self._sequence_length`; `truth` is the action at `index`;
            `value` is 1.0 if the outcome is truthy, else 0.0.
        """
        rdir = self._rdirs[idx]

        rfiles = sorted(
            [
                os.path.join(rdir, f)
                for f in os.listdir(rdir) if re.search(r"\.rollout$", f)
            ],
            reverse=True,
        )

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this dataset at trusted rollout directories.
        with gzip.open(rfiles[0], 'rb') as f:
            rollout = pickle.load(f)

        ptra, outcome = rollout.random()

        index = random.randrange(ptra.prepare_len(), ptra.len())

        # Check the sampled position, not the dataset index `idx`: the
        # `index` prefix plus EXTRACT must fit in the fixed-length sequence.
        assert index < self._sequence_length

        truth = ptra.actions()[index]
        actions = ptra.actions()[:index]
        arguments = ptra.arguments()[:index]

        value = 1.0 if outcome else 0.0

        actions.append(Action.from_action('EXTRACT', None, None))

        empty = Action.from_action('EMPTY', None, None)
        actions.extend([empty] * (self._sequence_length - len(actions)))
        arguments.extend([empty] * (self._sequence_length - len(arguments)))

        return (index, actions, arguments, truth, value)
Ejemplo n.º 6
0
Archivo: beam.py Proyecto: spolu/z3ta
    def apply(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        beta_width: int,
        head_width: int,
    ) -> typing.List[typing.Tuple[float, Action], ]:
        """Expand this beam entry into its best `head_width` successors.

        Every combination of the top action / left / right predictions
        becomes an `Action` over `ptra`'s arguments. A combination is
        dropped if it points past the end of the trace, was already seen,
        or is rejected by `repl.valid`. Survivors are scored as
        `self._value` times the three predicted probabilities.

        Returns:
            Up to `head_width` `(score, action)` pairs, highest score first.
        """
        n_kinds = len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS)
        act_vals, act_ids = \
            torch.exp(self._prd_actions.cpu()).topk(min(beta_width, n_kinds))
        lft_vals, lft_ids = torch.exp(self._prd_lefts.cpu()).topk(beta_width)
        rgt_vals, rgt_ids = torch.exp(self._prd_rights.cpu()).topk(beta_width)

        acts = list(zip(act_vals.tolist(), act_ids.tolist()))
        lefts = list(zip(lft_vals.tolist(), lft_ids.tolist()))
        rights = list(zip(rgt_vals.tolist(), rgt_ids.tolist()))

        shift = len(PREPARE_TOKENS)
        scored = []

        for p_act, act_id in acts:
            for p_lft, lft_id in lefts:
                for p_rgt, rgt_id in rights:
                    if lft_id >= ptra.len() or rgt_id >= ptra.len():
                        continue

                    candidate = Action.from_action(
                        INV_PROOFTRACE_TOKENS[act_id + shift],
                        ptra.arguments()[lft_id],
                        ptra.arguments()[rgt_id],
                    )

                    if ptra.seen(candidate) or not repl.valid(candidate):
                        continue

                    score = self._value * p_act * p_lft * p_rgt
                    scored.append((score, candidate))

        scored.sort(key=lambda entry: entry[0], reverse=True)
        return scored[:head_width]
Ejemplo n.º 7
0
    def __getitem__(
        self,
        idx: int,
    ) -> typing.Tuple[typing.List[Action], typing.List[Action],
                      typing.List[Action], ]:
        """Build a next-action training example from the `idx`-th rollout dir.

        The lexicographically greatest `*.rollout` file is loaded and its
        positive trace is used.

        Returns:
            `(actions, arguments, truth)`, each of length
            `self._sequence_length`:
            - `actions`/`arguments` cover positions 0 .. L-2 of the trace,
              where L is the trace length capped at the sequence length.
              The last action is dropped. `actions` is padded with EXTRACT,
              `arguments` with the trace's EMPTY action.
            - `truth[j]` is the action at position j+1 (positions inside the
              prepare section are replaced by EXTRACT), padded with EXTRACT.
        """
        rdir = self._rdirs[idx]

        rfiles = sorted(
            [
                os.path.join(rdir, f)
                for f in os.listdir(rdir) if re.search(r"\.rollout$", f)
            ],
            reverse=True,
        )

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this dataset at trusted rollout directories.
        with gzip.open(rfiles[0], 'rb') as f:
            rollout = pickle.load(f)

        ptra = rollout.positive()
        assert ptra.action_len() > 0

        ptra_len = min(ptra.len(), self._sequence_length)

        actions = ptra.actions()[:ptra_len - 1]
        arguments = ptra.arguments()[:ptra_len - 1]

        # Position 1 of every trace is expected to hold the EMPTY action.
        empty = ptra.actions()[1]
        assert empty.value == PREPARE_TOKENS['EMPTY']

        extract = Action.from_action('EXTRACT', empty, empty)

        # Shifted-by-one targets: same length as `actions` (ptra_len - 1).
        truth = [extract] * (ptra.prepare_len() - 1) + \
            ptra.actions()[ptra.prepare_len():ptra_len]

        actions.extend([extract] * (self._sequence_length - len(actions)))
        arguments.extend([empty] * (self._sequence_length - len(arguments)))
        truth.extend([extract] * (self._sequence_length - len(truth)))

        return (actions, arguments, truth)
Ejemplo n.º 8
0
Archivo: search.py Proyecto: spolu/z3ta
    def preprocess_ptra(
        self,
        ptra: ProofTraceActions,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action], ]:
        """Turn a trace into fixed-length language-model inputs.

        Returns `(index, actions, arguments)`. `index` is the position of
        the trace's last action. `actions` is padded with an EXTRACT action
        built from the trace's EMPTY action (position 1). `arguments` is
        padded with that EMPTY action. Both are padded up to the configured
        `prooftrace_sequence_length`.
        """
        seq_len = self._config.get('prooftrace_sequence_length')

        actions = ptra.actions().copy()
        arguments = ptra.arguments().copy()

        last = len(actions) - 1
        assert last < seq_len

        empty = ptra.actions()[1]
        assert empty.value == PREPARE_TOKENS['EMPTY']
        extract = Action.from_action('EXTRACT', empty, empty)

        actions.extend(extract for _ in range(seq_len - len(actions)))
        arguments.extend(empty for _ in range(seq_len - len(arguments)))

        return last, actions, arguments
Ejemplo n.º 9
0
    def step(
            self,
            offset: int = 0,
            conclusion: bool = False,
    ) -> typing.Tuple[
        bool, typing.Optional[ProofTraceActions], bool,
    ]:
        """Greedily extend the trace with its most likely valid action.

        The language model scores the current trace once. Every combination
        of its top-k action / left / right predictions becomes an `Action`,
        and unseen REPL-valid ones are scored by joint probability. The best
        one is applied to `self._repl` and appended to `self._ptra`.
        `offset` and `conclusion` are accepted but not used.

        Returns:
            `(done, ptra, proved)`: `(True, ptra, False)` when no valid
            action exists, `(True, ptra, True)` when the new theorem matches
            the target, `(False, ptra, False)` otherwise.
        """
        index, actions, arguments = self.preprocess_ptra(self._ptra)

        with torch.no_grad():
            prd_actions, prd_lefts, prd_rights = \
                self._l_model.infer([index], [actions], [arguments])

        width = \
            self._config.get('prooftrace_search_policy_sample_beta_width')
        kinds = len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS)

        act_vals, act_ids = \
            torch.exp(prd_actions[0].cpu()).topk(min(width, kinds))
        lft_vals, lft_ids = torch.exp(prd_lefts[0].cpu()).topk(width)
        rgt_vals, rgt_ids = torch.exp(prd_rights[0].cpu()).topk(width)

        acts = list(zip(act_vals.tolist(), act_ids.tolist()))
        lefts = list(zip(lft_vals.tolist(), lft_ids.tolist()))
        rights = list(zip(rgt_vals.tolist(), rgt_ids.tolist()))

        shift = len(PREPARE_TOKENS)
        scored = []

        for p_act, act_id in acts:
            for p_lft, lft_id in lefts:
                for p_rgt, rgt_id in rights:
                    if lft_id >= self._ptra.len() or \
                            rgt_id >= self._ptra.len():
                        continue

                    candidate = Action.from_action(
                        INV_PROOFTRACE_TOKENS[act_id + shift],
                        self._ptra.arguments()[lft_id],
                        self._ptra.arguments()[rgt_id],
                    )

                    if self._ptra.seen(candidate):
                        continue
                    if not self._repl.valid(candidate):
                        continue

                    scored.append((p_act * p_lft * p_rgt, candidate))

        if not scored:
            return True, self._ptra, False

        # max() keeps the first of equally-scored entries, matching a stable
        # descending sort followed by taking the head.
        best = max(scored, key=lambda entry: entry[0])[1]

        thm = self._repl.apply(best)
        best._index = thm.index()
        self._ptra.append(
            best,
            self._ptra.build_argument(thm.concl(), thm.hyp(), thm.index()),
        )

        proved = self._target.thm_string(True) == thm.thm_string(True)
        return proved, self._ptra, proved
Ejemplo n.º 10
0
    def __init__(
        self,
        config: Config,
        parent,
        model: Model,
        ground: ProofTraceActions,
        target: Thm,
        ptra: ProofTraceActions,
        repl: REPL,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        # value: float,
    ):
        """Create a search node and score its candidate child actions.

        The top-k entries of the three prediction heads are combined into
        (action, left, right) candidates over `ptra`. Unseen, REPL-valid
        candidates are batched through `model.infer` and stored in
        `self._queue`, sorted (ascending) by `self.potential`.

        Args:
            config: source of `prooftrace_sequence_length` and
                `prooftrace_lm_search_beta_width`.
            parent: parent node, stored as `self._parent`.
            model: model whose `infer(trc, idx)` returns per-candidate
                action / left / right / value predictions.
            ground: stored as `self._ground` (not otherwise used here).
            target: stored as `self._target` (not otherwise used here).
            ptra: current proof trace that candidates extend.
            repl: REPL used to filter candidates via `valid`.
            prd_actions, prd_lefts, prd_rights: this node's predictions;
                they are exponentiated before `topk`, so presumably
                log-probabilities (assumed 1-D — TODO confirm).
        """
        self._config = config

        self._parent = parent

        self._model = model
        self._ground = ground
        self._target = target
        self._ptra = ptra
        self._repl = repl

        self._sequence_length = config.get('prooftrace_sequence_length')
        self._beta_width = config.get('prooftrace_lm_search_beta_width')

        # Never request more actions than there are non-prepare tokens.
        a_count = min(
            self._beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )
        top_actions = torch.exp(prd_actions).topk(a_count)
        top_lefts = torch.exp(prd_lefts).topk(self._beta_width)
        top_rights = torch.exp(prd_rights).topk(self._beta_width)

        # List of (candidate Action, joint probability) pairs.
        actions = []

        for ia in range(a_count):
            for il in range(self._beta_width):
                for ir in range(self._beta_width):

                    action = top_actions[1][ia].item()
                    left = top_lefts[1][il].item()
                    right = top_rights[1][ir].item()

                    # Skip operand indices that point past the trace.
                    if left >= self._ptra.len() or right >= self._ptra.len():
                        continue

                    # NOTE(review): operands come from `actions()` here,
                    # whereas other candidate builders index `arguments()`;
                    # confirm this is intended for this Node version.
                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._ptra.actions()[left],
                        self._ptra.actions()[right],
                    )

                    if self._ptra.seen(a):
                        continue

                    if not self._repl.valid(a):
                        continue

                    # The probability is kept as a 0-d tensor (no `.item()`).
                    actions.append((a, top_actions[0][ia] * top_lefts[0][il] *
                                    top_rights[0][ir]))

        if len(actions) > 0:
            # Batch every candidate through the model in a single call.
            trc = []
            idx = []
            for a, p in actions:
                pre_trc, pre_idx = \
                    Node.prepare(self._ptra, a, self._sequence_length)
                trc.append(pre_trc)
                idx.append(pre_idx)

            prd_actions, prd_lefts, prd_rights, prd_values = \
                self._model.infer(trc, idx)

            # Queue entries: (action, prd_actions, prd_lefts, prd_rights,
            # value, probability), predictions moved to CPU.
            self._queue = sorted(
                [(
                    actions[i][0],
                    prd_actions[i].to(torch.device('cpu')),
                    prd_lefts[i].to(torch.device('cpu')),
                    prd_rights[i].to(torch.device('cpu')),
                    prd_values[i].item(),
                    actions[i][1],
                ) for i in range(len(actions))],
                key=lambda t: self.potential(t),
            )
            # Must run after `self._queue` is assigned.
            self._min_value = self.queue_value()
        else:
            # No valid candidate: empty queue and sentinel value.
            self._queue = []
            self._min_value = _MAX_VALUE

        self._children = []
Ejemplo n.º 11
0
    def apply(
            self,
            action: Action,
            fake: bool = False,
    ) -> Thm:
        """Replay `action` through the fusion kernel and return the theorem.

        `action.value` selects the inference rule. `action.left` and
        `action.right` carry the operands: theorem references resolved via
        `.index()`, terms via `.left.value`, and substitutions as
        SUBST / SUBST_TYPE trees. `fake` is forwarded to every fusion call.

        Raises:
            REPLException: unknown rule token or malformed operands.
            AssertionError: re-raised from the fusion layer after logging
                "Action replay failure" with the offending rule token.
        """
        action_token = INV_PROOFTRACE_TOKENS[action.value]

        # Guard the actual rule application so assertion failures raised
        # while building the theorem are logged with the rule involved.
        try:
            return self._apply_rule(action_token, action, fake)
        except AssertionError:
            Log.out("Action replay failure", {
                'action_token': action_token,
            })
            raise

    def _apply_rule(
            self,
            action_token: str,
            action: Action,
            fake: bool,
    ) -> Thm:
        """Dispatch `action` to the fusion rule named by `action_token`."""
        def expect(operand, token):
            # Reject an operand whose token is not the one the rule needs.
            if operand.value != PROOFTRACE_TOKENS[token]:
                raise REPLException

        def flatten_subst(node, container):
            # Flatten a binary tree of `container` nodes whose leaves are
            # SUBST_PAIR nodes into a list of [left.value, right.value].
            if node is None:
                return []
            token = INV_PROOFTRACE_TOKENS[node.value]
            if token == 'SUBST_PAIR':
                return [[
                    node.left.value,
                    node.right.value,
                ]]
            if token == container:
                return (
                    flatten_subst(node.left, container) +
                    flatten_subst(node.right, container)
                )
            raise REPLException()

        if action_token == 'PREMISE':
            thm = Thm(
                action.index(),
                self.build_hypothesis(action.left),
                action.right.value,
            )
            return self._fusion.PREMISE(thm, fake)

        if action_token in ('REFL', 'BETA', 'ASSUME'):
            # Single-term rules: left is a TERM, right must be EMPTY.
            expect(action.left, 'TERM')
            expect(action.right, 'EMPTY')
            rule = getattr(self._fusion, action_token)
            return rule(action.left.left.value, fake)

        if action_token in ('TRANS', 'MK_COMB', 'EQ_MP',
                            'DEDUCT_ANTISYM_RULE'):
            # Two-theorem rules.
            rule = getattr(self._fusion, action_token)
            return rule(action.left.index(), action.right.index(), fake)

        if action_token == 'ABS':
            expect(action.right, 'TERM')
            return self._fusion.ABS(
                action.left.index(),
                action.right.left.value,
                fake,
            )

        if action_token == 'INST':
            expect(action.right, 'SUBST')
            return self._fusion.INST(
                action.left.index(),
                flatten_subst(action.right, 'SUBST'),
                fake,
            )

        if action_token == 'INST_TYPE':
            expect(action.right, 'SUBST_TYPE')
            return self._fusion.INST_TYPE(
                action.left.index(),
                flatten_subst(action.right, 'SUBST_TYPE'),
                fake,
            )

        raise REPLException()
Ejemplo n.º 12
0
Archivo: mcts.py Proyecto: spolu/z3ta
    def expand(
        self,
        beta_width: int,
        sequence_length: int,
        offset: int,
        l_model: LModel,
        v_model: VModel,
        target: Thm,
        step: int,
        expand_width: int = 8,
    ) -> typing.Tuple[float, bool, ProofTraceActions, ]:
        """Expand this node by applying its most likely valid actions.

        The trace is padded with EMPTY to `sequence_length` and scored by
        `l_model` (action / left / right heads) and `v_model` (value). Every
        combination of the top `beta_width` predictions becomes a candidate.
        Unseen, REPL-valid candidates are ranked by joint probability, and
        the best `expand_width` are applied, each on copies of the REPL and
        trace, to create child nodes.

        Args:
            beta_width: top-k per prediction head (actions are additionally
                capped at the number of non-prepare tokens).
            sequence_length: padded model input length.
            offset: forwarded to `ptra.summary` for logging.
            l_model, v_model: language and value models.
            target: theorem whose derivation ends the search.
            step: step counter, used for logging only.
            expand_width: maximum number of children created (default 8).

        Returns:
            `(value, proved, ptra)`. `value` is the value-model output divided
            by the trace's action length. When a child derives `target`, this
            returns `(value, True, child_ptra)` immediately. Otherwise it
            marks the node expanded and returns `(value, False, self._ptra)`.
        """
        actions = self._ptra.actions().copy()
        arguments = self._ptra.arguments().copy()

        index = len(actions)

        empty = Action.from_action('EMPTY', None, None)
        actions.extend([empty] * (sequence_length - len(actions)))
        arguments.extend([empty] * (sequence_length - len(arguments)))

        with torch.no_grad():
            prd_actions, prd_lefts, prd_rights = \
                l_model.infer([index], [actions], [arguments])
            prd_values = \
                v_model.infer([index], [actions], [arguments])

        a_count = min(
            beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )
        top_actions = torch.exp(prd_actions[0].cpu()).topk(a_count)
        top_lefts = torch.exp(prd_lefts[0].cpu()).topk(beta_width)
        top_rights = torch.exp(prd_rights[0].cpu()).topk(beta_width)

        # NOTE(review): raises ZeroDivisionError if the trace has no actions
        # yet; confirm callers never expand such a node.
        value = prd_values[0].item() / self._ptra.action_len()

        Log.out(
            "EXPAND",
            {
                'step': step,
                'value': "{:.3f}".format(value),
                'length': self._ptra.len(),
                'summary': self._ptra.summary(offset),
            })

        act_probs, act_ids = top_actions[0].tolist(), top_actions[1].tolist()
        lft_probs, lft_ids = top_lefts[0].tolist(), top_lefts[1].tolist()
        rgt_probs, rgt_ids = top_rights[0].tolist(), top_rights[1].tolist()

        candidates = []

        for p_act, action in zip(act_probs, act_ids):
            for p_lft, left in zip(lft_probs, lft_ids):
                for p_rgt, right in zip(rgt_probs, rgt_ids):
                    if left >= self._ptra.len() or right >= self._ptra.len():
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._ptra.arguments()[left],
                        self._ptra.arguments()[right],
                    )

                    if self._ptra.seen(a):
                        continue

                    if not self._repl.valid(a):
                        continue

                    candidates.append((p_act * p_lft * p_rgt, a))

        candidates = sorted(
            candidates, key=lambda c: c[0], reverse=True,
        )[:expand_width]

        for p, action in candidates:
            # Each child works on its own copies of the REPL and trace.
            repl = self._repl.copy()
            ptra = self._ptra.copy()

            thm = repl.apply(action)
            action._index = thm.index()

            argument = ptra.build_argument(
                thm.concl(),
                thm.hyp(),
                thm.index(),
            )
            ptra.append(action, argument)

            if target.thm_string(True) == thm.thm_string(True):
                Log.out(
                    "DEMONSTRATED", {
                        'theorem': thm.thm_string(True),
                        'summary': ptra.summary(offset),
                    })
                return value, True, ptra

            self._children.append(Node(
                self,
                p,
                repl,
                ptra,
                thm,
            ))

        self._expanded = True

        return value, False, self._ptra
Ejemplo n.º 13
0
    def step(
        self,
        offset: int = 0,
        conclusion: bool = False,
    ) -> typing.Tuple[bool, typing.Optional[ProofTraceActions], bool, ]:
        """Advance the particle filter by one step.

        1. Score all particles with the language model.
        2. Propose children in rank-major order (for each (action, left,
           right) rank combination, every particle). Keep unseen, REPL-valid
           ones, deduplicated on (particle's last action hash + action hash),
           until `self._sample_size` samples are collected.
        3. Score the samples with the value model and resample
           `self._filter_size` particles from the softmax of the
           standardized values.

        `offset` and `conclusion` are accepted but not used.

        Returns:
            `(done, ptra, proved)`: `(True, ptra, True)` as soon as a
            proposal derives the target, `(True, first_particle_ptra, False)`
            when no proposal is valid, else `(False, first_particle_ptra,
            False)` after resampling.
        """
        idx = []
        act = []
        arg = []

        for p in self._particles:
            index, actions, arguments = self.preprocess_ptra(p['ptra'])
            idx.append(index)
            act.append(actions)
            arg.append(arguments)

        with torch.no_grad():
            prd_actions, prd_lefts, prd_rights = \
                self._l_model.infer(idx, act, arg)

        beta_width = 16
        samples = {}

        a_count = min(
            beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )
        top_actions = torch.exp(prd_actions.cpu()).topk(a_count)
        top_lefts = torch.exp(prd_lefts.cpu()).topk(beta_width)
        top_rights = torch.exp(prd_rights.cpu()).topk(beta_width)

        # Rank-major enumeration: each (ia, il, ir) rank is tried for every
        # particle before moving on to the next rank combination.
        proposals = (
            (ia, il, ir, i, p)
            for ia in range(a_count)
            for il in range(beta_width)
            for ir in range(beta_width)
            for i, p in enumerate(self._particles)
        )

        for ia, il, ir, i, p in proposals:
            if len(samples) >= self._sample_size:
                break

            action = top_actions[1][i][ia].item()
            left = top_lefts[1][i][il].item()
            right = top_rights[1][i][ir].item()

            if left >= p['ptra'].len() or right >= p['ptra'].len():
                continue

            a = Action.from_action(
                INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                p['ptra'].arguments()[left],
                p['ptra'].arguments()[right],
            )

            if p['ptra'].seen(a):
                continue

            if not p['repl'].valid(a):
                continue

            h = p['ptra'].actions()[-1].hash() + a.hash()

            if h not in samples:
                repl = p['repl'].copy()
                ptra = p['ptra'].copy()

                thm = repl.apply(a)
                a._index = thm.index()

                argument = ptra.build_argument(
                    thm.concl(),
                    thm.hyp(),
                    thm.index(),
                )
                ptra.append(a, argument)

                if self._target.thm_string(True) == thm.thm_string(True):
                    return True, ptra, True

                samples[h] = {
                    'repl': repl,
                    'ptra': ptra,
                }

        # Resampling based on value.
        samples = list(samples.values())

        if len(samples) == 0:
            return True, self._particles[0]['ptra'], False

        idx = []
        act = []
        arg = []

        for p in samples:
            index, actions, arguments = self.preprocess_ptra(p['ptra'])
            idx.append(index)
            act.append(actions)
            arg.append(arguments)

        with torch.no_grad():
            prd_values = \
                self._v_model.infer(idx, act, arg)

        # Assumes prd_values is (num_samples, 1) — TODO confirm.
        values = prd_values.squeeze(1)
        if values.numel() > 1:
            scale = values.std() + 1e-7
        else:
            # std() of a single element is NaN; with one sample there is
            # nothing to normalize and the softmax is trivially [1.0].
            scale = 1.0

        weights = F.softmax((values - values.mean()) / scale, dim=0)

        # `weights` already are probabilities: pass them as `probs`, not as
        # `logits` (which would apply a second softmax and flatten them).
        m = D.Categorical(probs=weights)
        picks = m.sample((self._filter_size, )).cpu().numpy()

        self._particles = [samples[j] for j in picks]

        return False, self._particles[0]['ptra'], False
Ejemplo n.º 14
0
    def step(
        self,
        action: typing.Tuple[int, int, int],
        step_reward_prob: float,
        match_reward_prob: float,
        gamma: float,
        fixed_gamma: int,
    ) -> typing.Tuple[typing.Tuple[int, typing.List[Action], typing.List[
            Action]], typing.Tuple[float, float, float], bool, typing.Dict[
                str, int], ]:
        """Apply one agent action to the current run.

        Args:
            action: `(action_id, left, right)`. `action_id` is offset by the
                number of prepare tokens; `left`/`right` index the run's
                arguments.
            step_reward_prob: probability of a 1.0 step reward.
            match_reward_prob: probability of a 1.0 match reward (replacing
                the step reward) when the action also appears in the ground
                trace.
            gamma, fixed_gamma: forwarded to `self.reset` when the episode
                ends.

        Returns:
            `(observation, (step, match, final) rewards, done, info)`. The
            observation comes from `self.reset` when done, else from
            `self.observation()` (a `(run_len, actions, arguments)` tuple).
            Illegal actions (out-of-range operands, already-seen actions,
            fusion/REPL/type errors) end the episode with zero rewards.
        """
        assert self._ground is not None
        assert self._run is not None

        def run_length():
            # Run length excluding the gamma prefix.
            return self._run.action_len() - self._gamma_len

        def log_state(message):
            Log.out(
                message, {
                    'ground_length': self._ground.action_len(),
                    'gamma_length': self._gamma_len,
                    'run_length': run_length(),
                    'name': self._ground.name(),
                })

        def finish(rewards, done, info):
            if done:
                observation = self.reset(gamma, fixed_gamma)
            else:
                observation = self.observation()
            return observation, rewards, done, info

        def illegal(message):
            # Every illegal action ends the episode with zero rewards.
            log_state(message)
            return finish(
                (0.0, 0.0, 0.0), True, {
                    'match_count': self._match_count,
                    'run_length': run_length(),
                })

        if action[1] >= self._run.len() or action[2] >= self._run.len():
            return illegal("DONE ILLEGAL[overflow]")

        a = Action.from_action(
            INV_PROOFTRACE_TOKENS[action[0] + len(PREPARE_TOKENS)],
            self._run.arguments()[action[1]],
            self._run.arguments()[action[2]],
        )

        if self._run.seen(a):
            return illegal("DONE ILLEGAL[seen]")

        try:
            thm = self._repl.apply(a)
        except (FusionException, REPLException, TypeException):
            return illegal("DONE ILLEGAL[fusion]")

        a._index = thm.index()
        argument = self._run.build_argument(
            thm.concl(),
            thm.hyp(),
            thm.index(),
        )
        self._run.append(a, argument)

        step_reward = 0.0
        match_reward = 0.0
        final_reward = 0.0
        done = False
        info = {}

        if step_reward_prob > 0.0 and random.random() < step_reward_prob:
            step_reward = 1.0

        if self._ground.seen(a):
            self._match_count += 1
            if match_reward_prob > 0.0 and random.random() < match_reward_prob:
                match_reward = 1.0
                step_reward = 0.0

        if self._target.thm_string(True) == thm.thm_string(True):
            final_reward = 10.0
            done = True
            info['demo_length'] = min(
                self._run.action_len(),
                self._ground.action_len(),
            ) - self._gamma_len
            info['demo_delta'] = \
                self._run.action_len() - self._ground.action_len()
            log_state("DEMONSTRATED")

        if self._run.len() >= self._sequence_length:
            done = True
            log_state("DONE LENGTH ")

        if done:
            info['match_count'] = self._match_count
            info['run_length'] = run_length()

        return finish((step_reward, match_reward, final_reward), done, info)