Ejemplo n.º 1
0
    def sample(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        tries: int,
        conclusion: bool = False,
    ) -> Action:
        """Randomly draw an action that is new to `ptra` and accepted by `repl`.

        Each try picks a random token (from `CONCLUSION_TOKENS` when
        `conclusion` is set, from `NON_PREPARE_TOKENS` otherwise), samples the
        operands that rule needs, builds the corresponding `Action`, and
        returns it as soon as `ptra` has not seen it and `repl.valid` accepts
        it.

        Args:
            ptra: trace whose `arguments()` list the sampled operand indices
                point into, and whose `seen` filters out duplicates.
            repl: REPL used to check the candidate with `valid`.
            tries: maximum number of candidates to draw.
            conclusion: restrict the draw to `CONCLUSION_TOKENS`.

        Returns:
            The first acceptable action, or `None` when no try succeeded
            (note the annotation does not reflect the `None` case).
        """
        def theorem():
            return self.sample_theorem(ptra)

        # Operand samplers keyed by rule name. Each returns `(left, right)`,
        # sampling left before right. Rules that only sample one operand use
        # index 0 of `ptra.arguments()` as their right operand.
        # NOTE(review): the values returned by the `sample_*` helpers are used
        # as indices into `ptra.arguments()` -- confirm against the helpers.
        samplers = {
            'REFL': lambda: (self.sample_term(), 0),
            'TRANS': lambda: (theorem(), theorem()),
            'MK_COMB': lambda: (theorem(), theorem()),
            'ABS': lambda: (theorem(), self.sample_term()),
            'BETA': lambda: (self.sample_term(), 0),
            'ASSUME': lambda: (self.sample_term(), 0),
            'EQ_MP': lambda: (theorem(), theorem()),
            'DEDUCT_ANTISYM_RULE': lambda: (theorem(), theorem()),
            'INST': lambda: (theorem(), self.sample_subst()),
            'INST_TYPE': lambda: (theorem(), self.sample_subst_type()),
        }

        tokens = CONCLUSION_TOKENS if conclusion else NON_PREPARE_TOKENS
        choices = list(tokens.values())

        for _ in range(tries):
            action = random.choice(choices)
            rule = INV_PROOFTRACE_TOKENS[action]

            sampler = samplers.get(rule)
            if sampler is None:
                # No operand sampler for this rule: skip the try rather than
                # fail on unbound operands or reuse those of a previous try.
                continue
            left, right = sampler()

            a = Action.from_action(
                rule,
                ptra.arguments()[left],
                ptra.arguments()[right],
            )

            if ptra.seen(a) or not repl.valid(a):
                continue

            return a

        return None
Ejemplo n.º 2
0
Archivo: beam.py Proyecto: spolu/z3ta
    def apply(
        self,
        ptra: ProofTraceActions,
        repl: REPL,
        beta_width: int,
        head_width: int,
    ) -> typing.List[typing.Tuple[float, Action], ]:
        """Expand the stored predictions into the best scored valid actions.

        The top `beta_width` entries of the action, left and right predictions
        are combined (the action count is capped by the number of non-prepare
        tokens). A combination is kept when both operand indices fall inside
        `ptra`, the resulting action has not been seen by `ptra` and
        `repl.valid` accepts it.

        Args:
            ptra: trace providing the operand arguments and the `seen` check.
            repl: REPL used to validate each candidate.
            beta_width: number of top entries taken from each prediction.
            head_width: maximum number of candidates returned.

        Returns:
            Up to `head_width` `(score, action)` pairs sorted by decreasing
            score, where score is `self._value` times the three selected
            probabilities.
        """
        a_count = min(
            beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )

        # `torch.exp` is applied before ranking, so the stored predictions are
        # presumably log-probabilities -- TODO confirm.
        act_probs, act_ids = torch.exp(self._prd_actions.cpu()).topk(a_count)
        left_probs, left_ids = torch.exp(
            self._prd_lefts.cpu()).topk(beta_width)
        right_probs, right_ids = torch.exp(
            self._prd_rights.cpu()).topk(beta_width)

        # (index, probability) pairs as plain Python numbers, best first.
        ranked_actions = [
            (act_ids[k].item(), act_probs[k].item()) for k in range(a_count)
        ]
        ranked_lefts = [
            (left_ids[k].item(), left_probs[k].item())
            for k in range(beta_width)
        ]
        ranked_rights = [
            (right_ids[k].item(), right_probs[k].item())
            for k in range(beta_width)
        ]

        scored = []

        for action, action_prob in ranked_actions:
            for left, left_prob in ranked_lefts:
                for right, right_prob in ranked_rights:
                    # Operands must refer to entries already in the trace.
                    if max(left, right) >= ptra.len():
                        continue

                    arguments = ptra.arguments()
                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        arguments[left],
                        arguments[right],
                    )

                    if ptra.seen(a) or not repl.valid(a):
                        continue

                    score = self._value * action_prob * left_prob * right_prob
                    scored.append((score, a))

        scored.sort(key=lambda c: c[0], reverse=True)
        return scored[:head_width]
Ejemplo n.º 3
0
class Env:
    """Episode-based environment over prooftrace actions.

    Each episode picks a ground-truth trace from disk (`_ground`), seeds a
    fresh run (`_run`) with the ground trace's prepare actions, prepares a
    `REPL` with it to obtain the target theorem (`_target`), and then lets the
    caller extend the run one action at a time through `step`.
    """

    def __init__(
        self,
        config: Config,
        test: bool,
    ) -> None:
        """List the available trace files and load the tokenizer.

        Args:
            config: provides `prooftrace_sequence_length`, `device`,
                `prooftrace_dataset_dir` and `prooftrace_dataset_size`.
            test: read traces from `test_traces` when true, from
                `train_traces` otherwise.
        """
        self._sequence_length = config.get('prooftrace_sequence_length')

        self._device = torch.device(config.get('device'))

        dataset_root = os.path.join(
            os.path.expanduser(config.get('prooftrace_dataset_dir')),
            config.get('prooftrace_dataset_size'),
        )
        dataset_dir = os.path.join(
            dataset_root,
            'test_traces' if test else 'train_traces',
        )
        assert os.path.isdir(dataset_dir)

        self._trace_files = [
            os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir)
            if (os.path.isfile(os.path.join(dataset_dir, f))
                and re.search(r"\.actions$", f) is not None)
        ]

        # NOTE: `pickle.load` can execute arbitrary code; the dataset
        # directory must only contain trusted files.
        with gzip.open(
                os.path.join(dataset_root, 'traces.tokenizer'), 'rb') as f:
            self._tokenizer = pickle.load(f)

        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        # Also reset at the start of every episode; initialised here so the
        # attributes exist on a freshly constructed environment.
        self._gamma_len = 0
        self._match_count = 0

    def _run_length(self) -> int:
        """Return the number of run actions beyond the gamma prefix."""
        return self._run.action_len() - self._gamma_len

    def _log_episode(self, message: str) -> None:
        """Log `message` together with the current episode's lengths."""
        Log.out(
            message, {
                'ground_length': self._ground.action_len(),
                'gamma_length': self._gamma_len,
                'run_length': self._run_length(),
                'name': self._ground.name(),
            })

    def reset(
        self,
        gamma: float,
        fixed_gamma: int,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action]]:
        """Start a new episode and return its first observation.

        Args:
            gamma: probability of pre-applying a prefix of the ground trace
                to the run before handing control to the caller.
            fixed_gamma: when positive, the prefix leaves between 1 and
                `fixed_gamma` ground actions unapplied; otherwise the prefix
                length is drawn uniformly below the ground action count.

        Returns:
            The observation of the fresh run, as built by `observation`.
        """
        self._ground = None
        self._run = None
        self._repl = None
        self._target = None
        self._alpha = 0
        self._gamma_len = 0

        self._match_count = 0

        # Draw files until one is short enough for the sequence length. The
        # trace length is read from the file name, which is assumed to end in
        # `_<length>_<number>.actions`.
        while self._ground is None:
            path = random.choice(self._trace_files)

            match = re.search(r"_(\d+)_(\d+)\.actions$", path)
            ptra_len = int(match.group(1))

            if ptra_len <= self._sequence_length:
                # NOTE: `pickle.load` on trace files; trusted data only.
                with gzip.open(path, 'rb') as f:
                    self._ground = pickle.load(f)

        # The run starts with exactly the ground trace's prepare entries.
        prepare_positions = [
            i for i in range(self._ground.len())
            if self._ground.actions()[i].value in INV_PREPARE_TOKENS
        ]
        self._run = ProofTraceActions(
            'REPL-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [self._ground.actions()[i] for i in prepare_positions],
            [self._ground.arguments()[i] for i in prepare_positions],
        )

        self._repl = REPL(self._tokenizer)
        self._target = self._repl.prepare(self._run)

        # GAMMA Initialization.
        if gamma > 0.0 and random.random() < gamma:
            if fixed_gamma > 0:
                self._gamma_len = self._ground.action_len() - \
                    random.randrange(
                        1, min(fixed_gamma, self._ground.action_len()) + 1
                    )
            else:
                self._gamma_len = random.randrange(0,
                                                   self._ground.action_len())

            for i in range(self._gamma_len):
                pos = self._ground.prepare_len() + i
                # The last ground entry is never part of the prefix.
                assert pos < self._ground.len() - 1
                action = self._ground.actions()[pos]
                argument = self._ground.arguments()[pos]

                thm = self._repl.apply(action)

                # Re-index the ground entries with the theorem index produced
                # by this episode's REPL before reusing them in the run.
                action._index = thm.index()
                argument._index = thm.index()

                self._run.append(action, argument)

        return self.observation()

    def observation(
        self,
    ) -> typing.Tuple[int, typing.List[Action], typing.List[Action], ]:
        """Return `(run length, actions, arguments)` padded to a fixed size.

        Both lists are copies of the run's lists, padded with a shared
        `EMPTY` action up to the configured sequence length.
        """
        actions = self._run.actions().copy()
        arguments = self._run.arguments().copy()

        # If the len match this is a final observation, so no extract will be
        # appended and that's fine because this observation won't make it to
        # the agent.
        if len(actions) < self._sequence_length:
            actions.append(Action.from_action('EXTRACT', None, None))

        # Finally we always return actions with the same length.
        empty = Action.from_action('EMPTY', None, None)
        while len(actions) < self._sequence_length:
            actions.append(empty)
        while len(arguments) < self._sequence_length:
            arguments.append(empty)

        return (self._run.len(), actions, arguments)

    def alpha_oracle(
        self,
    ) -> typing.Tuple[typing.Optional[torch.Tensor], int]:
        """Return the first ground action that can be applied to the run.

        A ground action qualifies when the run has not seen it yet but has
        seen both of its operands. Every call increments `_alpha`.

        Returns:
            `(actions, 0)` where `actions` is a `[1, 3]` int64 tensor holding
            the action value offset by the number of prepare tokens and the
            run's hash lookups of both operands, or `(None, 0)` when no
            ground action qualifies.
        """
        self._alpha += 1
        for i in range(self._ground.prepare_len(), self._ground.len()):
            a = self._ground.actions()[i]
            if (not self._run.seen(a)) and \
                    self._run.seen(a.left) and \
                    self._run.seen(a.right):
                assert 0 <= a.value - len(PREPARE_TOKENS)
                assert a.value < len(PROOFTRACE_TOKENS)
                actions = torch.tensor(
                    [[
                        a.value - len(PREPARE_TOKENS),
                        self._run.hashes()[a.left.hash()],
                        self._run.hashes()[a.right.hash()],
                    ]],
                    dtype=torch.int64,
                ).to(self._device)
                return actions, 0

        # We may reach this point as final actions are sometime repeated at the
        # end of prooftraces.
        return None, 0

    def beta_oracle(
        self,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        beta_width: int,
        beta_size: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Rank top predicted candidates, valid ones first.

        The top `beta_width` entries of each prediction are combined. Every
        combination is scored with the product of its three probabilities;
        combinations that are in range, unseen by the run and accepted by
        the REPL get `1.0` added so they sort ahead of the others.

        Args:
            prd_actions: action predictions, exponentiated before ranking.
            prd_lefts: left operand predictions, same convention.
            prd_rights: right operand predictions, same convention.
            beta_width: number of top entries taken per prediction; clamped
                to each prediction's size.
            beta_size: maximum number of candidates returned.

        Returns:
            `(actions, frame_count)` where `actions` is an int64 tensor with
            one `[action, left, right]` row per returned candidate (fewer
            than `beta_size` rows when fewer candidates exist) and
            `frame_count` is the number of candidates submitted to
            `REPL.valid`.
        """
        # `topk` fails when asked for more entries than are available.
        action_width = min(beta_width, prd_actions.size(-1))
        left_width = min(beta_width, prd_lefts.size(-1))
        right_width = min(beta_width, prd_rights.size(-1))

        top_actions = torch.exp(prd_actions).topk(action_width)
        top_lefts = torch.exp(prd_lefts).topk(left_width)
        top_rights = torch.exp(prd_rights).topk(right_width)

        out = []
        frame_count = 0

        for ia in range(action_width):
            action = top_actions[1][ia].item()
            assert action >= 0
            assert action < len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS)

            for il in range(left_width):
                left = top_lefts[1][il].item()

                for ir in range(right_width):
                    right = top_rights[1][ir].item()
                    prob = top_actions[0][ia].item() * \
                        top_lefts[0][il].item() * \
                        top_rights[0][ir].item()

                    if left >= self._run.len() or right >= self._run.len():
                        out.append(([action, left, right], prob))
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._run.arguments()[left],
                        self._run.arguments()[right],
                    )

                    if self._run.seen(a):
                        out.append(([action, left, right], prob))
                        continue

                    frame_count += 1
                    if not self._repl.valid(a):
                        out.append(([action, left, right], prob))
                        continue

                    out.append(([action, left, right], prob + 1.0))

        out.sort(key=lambda o: o[1], reverse=True)

        # Slicing tolerates fewer than `beta_size` candidates.
        actions = [candidate for candidate, _ in out[:beta_size]]

        return \
            torch.tensor(actions, dtype=torch.int64).reshape(
                len(actions), 3).to(self._device), \
            frame_count

    def explore(
        self,
        prd_actions: torch.Tensor,
        prd_lefts: torch.Tensor,
        prd_rights: torch.Tensor,
        alpha: float,
        beta: float,
        beta_width: int,
    ) -> typing.Tuple[torch.Tensor, int]:
        """Choose the next action from an oracle or by sampling.

        Args:
            prd_actions: action predictions, exponentiated before use.
            prd_lefts: left operand predictions, same convention.
            prd_rights: right operand predictions, same convention.
            alpha: probability of consulting the alpha oracle; it is only
                consulted while `_alpha` is still 0, i.e. until its first
                call in the episode.
            beta: probability of consulting the beta oracle.
            beta_width: width forwarded to the beta oracle.

        Returns:
            `(actions, frame_count)` with `actions` a `[1, 3]` tensor of
            `[action, left, right]`; `frame_count` is 0 unless the beta
            oracle was used.
        """
        # ALPHA Oracle.
        if alpha > 0.0 and random.random() < alpha and self._alpha == 0:
            actions, frame_count = self.alpha_oracle()
            if actions is not None:
                return actions, frame_count

        # BETA Oracle.
        if beta > 0.0 and random.random() < beta:
            return self.beta_oracle(
                prd_actions,
                prd_lefts,
                prd_rights,
                beta_width,
                1,
            )

        # Sampling: one independent draw per prediction, action first.
        sampled = [
            Categorical(torch.exp(prd)).sample().unsqueeze(0).unsqueeze(1)
            for prd in (prd_actions, prd_lefts, prd_rights)
        ]
        actions = torch.cat(sampled, dim=1)

        return actions, 0

    def step(
        self,
        action: typing.Tuple[int, int, int],
        step_reward_prob: float,
        match_reward_prob: float,
        gamma: float,
        fixed_gamma: int,
    ) -> typing.Tuple[
            typing.Tuple[int, typing.List[Action], typing.List[Action]],
            typing.Tuple[float, float, float],
            bool,
            typing.Dict[str, int],
    ]:
        """Apply `action` to the run and report the outcome.

        The episode ends without reward when an operand index is outside the
        run, when the action was already seen by the run, or when the REPL
        rejects it. It also ends when the produced theorem matches the
        target (final reward 10.0) or when the run reaches the sequence
        length. Whenever the episode ends the environment is reset and the
        returned observation belongs to the new episode.

        Args:
            action: `(action, left, right)` where `action` is offset by the
                number of prepare tokens and the operands index the run.
            step_reward_prob: probability of a 1.0 step reward for an
                applied action.
            match_reward_prob: probability of a 1.0 match reward (replacing
                the step reward) when the ground trace has seen the action.
            gamma: forwarded to `reset` when the episode ends.
            fixed_gamma: forwarded to `reset` when the episode ends.

        Returns:
            `(observation, (step_reward, match_reward, final_reward), done,
            info)`.
        """
        assert self._ground is not None
        assert self._run is not None

        def finish(rewards, done, info):
            if done:
                observation = self.reset(gamma, fixed_gamma)
            else:
                observation = self.observation()
            return observation, rewards, done, info

        def illegal(reason):
            # Log and snapshot the counters before `finish` resets them.
            self._log_episode("DONE ILLEGAL[{}]".format(reason))
            return finish(
                (0.0, 0.0, 0.0), True, {
                    'match_count': self._match_count,
                    'run_length': self._run_length(),
                })

        if action[1] >= self._run.len() or action[2] >= self._run.len():
            return illegal('overflow')

        applied = Action.from_action(
            INV_PROOFTRACE_TOKENS[action[0] + len(PREPARE_TOKENS)],
            self._run.arguments()[action[1]],
            self._run.arguments()[action[2]],
        )

        if self._run.seen(applied):
            return illegal('seen')

        try:
            thm = self._repl.apply(applied)
        except (FusionException, REPLException, TypeException):
            return illegal('fusion')

        applied._index = thm.index()
        argument = self._run.build_argument(
            thm.concl(),
            thm.hyp(),
            thm.index(),
        )
        self._run.append(applied, argument)

        step_reward = 0.0
        match_reward = 0.0
        final_reward = 0.0
        done = False
        info = {}

        if step_reward_prob > 0.0 and random.random() < step_reward_prob:
            step_reward = 1.0

        if self._ground.seen(applied):
            self._match_count += 1
            if match_reward_prob > 0.0 and random.random() < match_reward_prob:
                # A match reward replaces the step reward.
                match_reward = 1.0
                step_reward = 0.0

        if self._target.thm_string(True) == thm.thm_string(True):
            final_reward = 10.0
            done = True
            info['demo_length'] = min(
                self._run.action_len(),
                self._ground.action_len(),
            ) - self._gamma_len
            info['demo_delta'] = \
                self._run.action_len() - self._ground.action_len()
            self._log_episode("DEMONSTRATED")

        if self._run.len() >= self._sequence_length:
            done = True
            self._log_episode("DONE LENGTH ")

        if done:
            info['match_count'] = self._match_count
            info['run_length'] = self._run_length()

        return finish((step_reward, match_reward, final_reward), done, info)