Example 1
    def __init__(
        self,
        config: Config,
    ):
        self._config = config

        self._device = torch.device(config.get('device'))

        self._model = LModel(config)

        self._rollout_dir = os.path.join(
            os.path.expanduser(config.get('prooftrace_rollout_dir')),
            config.get('prooftrace_dataset_size'),
            'train_rollouts',
        )
        with gzip.open(
                os.path.join(
                    os.path.expanduser(config.get('prooftrace_dataset_dir')),
                    config.get('prooftrace_dataset_size'),
                    'traces.tokenizer',
                ), 'rb') as f:
            self._tokenizer = pickle.load(f)

        self._wrk = IOTAWrk(
            config.get('prooftrace_lm_iota_sync_dir'),
            'rollout',
            self._model.modules(),
        )

        self._type = config.get('prooftrace_search_type')

        Log.out('WRK initialization', {})
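A note on the tokenizer load above: it is just a gzip-compressed pickle read back from disk. A minimal, self-contained sketch of that pattern (the path and the dumped object below are placeholders, not the project's real artifacts):

import gzip
import os
import pickle
import tempfile

# Hypothetical location; the real __init__ joins prooftrace_dataset_dir,
# the dataset size and 'traces.tokenizer' from the config.
path = os.path.join(tempfile.mkdtemp(), 'traces.tokenizer')

# Write a stand-in object the way the dataset pipeline presumably does.
with gzip.open(path, 'wb') as f:
    pickle.dump({'EMPTY': 0, 'TARGET': 1}, f, protocol=pickle.HIGHEST_PROTOCOL)

# Read it back exactly as the snippet above does.
with gzip.open(path, 'rb') as f:
    tokenizer = pickle.load(f)

print(tokenizer)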
Example 2
    def __init__(
        self,
        config: Config,
        train_dataset: ProofTraceLMDataset,
    ):
        self._config = config

        self._action_coeff = config.get('prooftrace_lm_action_coeff')
        self._grad_norm_max = config.get('prooftrace_lm_grad_norm_max')

        self._device = torch.device(config.get('device'))

        self._sequence_length = config.get('prooftrace_sequence_length')

        self._model = LModel(config)
        self._ack = IOTAAck(
            config.get('prooftrace_lm_iota_sync_dir'),
            self._model.modules(),
        )

        self._nll_loss = nn.NLLLoss()

        self._train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self._config.get('prooftrace_lm_batch_size'),
            shuffle=True,
            collate_fn=lm_collate,
        )

        Log.out('ACK initialization', {
            "batch_size": self._config.get('prooftrace_lm_batch_size'),
        })

        self._train_batch = 0
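The training loader above pairs a project-specific dataset with a custom collate function (lm_collate). A rough, self-contained illustration of the same DataLoader pattern, with a toy dataset and collate function standing in for ProofTraceLMDataset and lm_collate:

import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    # Stand-in for ProofTraceLMDataset: yields (act, arg, trh)-style tuples.
    def __init__(self, n=32):
        self._data = [([i], [i + 1], [i + 2]) for i in range(n)]

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]


def toy_collate(batch):
    # Keep each field as a list of per-sample values; a custom collate
    # function like this avoids stacking variable-length sequences.
    acts = [b[0] for b in batch]
    args = [b[1] for b in batch]
    trhs = [b[2] for b in batch]
    return acts, args, trhs


loader = DataLoader(
    ToyDataset(), batch_size=4, shuffle=True, collate_fn=toy_collate,
)
for act, arg, trh in loader:
    print(len(act), act[0], arg[0], trh[0])
    break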
Example 3
    def __init__(
        self,
        config: Config,
    ):
        self._config = config

        self._learning_rate = config.get('prooftrace_lm_learning_rate')
        self._min_update_count = \
            config.get('prooftrace_lm_iota_min_update_count')

        self._device = torch.device(config.get('device'))

        self._save_dir = config.get('prooftrace_save_dir')
        self._load_dir = config.get('prooftrace_load_dir')

        self._epoch = 0
        self._last_update = None

        self._tb_writer = None
        if self._config.get('tensorboard_log_dir'):
            self._tb_writer = SummaryWriter(
                self._config.get('tensorboard_log_dir'), )

        self._model = LModel(config)

        Log.out(
            "SYN Initializing",
            {
                'parameters_count_pE':
                self._model.modules()['pE'].parameters_count(),
                'parameters_count_pT':
                self._model.modules()['pT'].parameters_count(),
                'parameters_count_pH':
                self._model.modules()['pH'].parameters_count(),
            },
        )

        self._syn = IOTASyn(
            config.get('prooftrace_lm_iota_sync_dir'),
            self._model.modules(),
        )

        self._policy_optimizer = optim.Adam(
            [
                {
                    'params': self._model.modules()['pE'].parameters()
                },
                {
                    'params': self._model.modules()['pT'].parameters()
                },
                {
                    'params': self._model.modules()['pH'].parameters()
                },
            ],
            lr=self._learning_rate,
        )

        self._syn.broadcast({'config': self._config})
Example 4
class SYN:
    def __init__(
        self,
        config: Config,
    ):
        self._config = config

        self._learning_rate = config.get('prooftrace_lm_learning_rate')
        self._min_update_count = \
            config.get('prooftrace_lm_iota_min_update_count')

        self._device = torch.device(config.get('device'))

        self._save_dir = config.get('prooftrace_save_dir')
        self._load_dir = config.get('prooftrace_load_dir')

        self._epoch = 0
        self._last_update = None

        self._tb_writer = None
        if self._config.get('tensorboard_log_dir'):
            self._tb_writer = SummaryWriter(
                self._config.get('tensorboard_log_dir'), )

        self._model = LModel(config)

        Log.out(
            "SYN Initializing",
            {
                'parameters_count_pE':
                self._model.modules()['pE'].parameters_count(),
                'parameters_count_pT':
                self._model.modules()['pT'].parameters_count(),
                'parameters_count_pH':
                self._model.modules()['pH'].parameters_count(),
            },
        )

        self._syn = IOTASyn(
            config.get('prooftrace_lm_iota_sync_dir'),
            self._model.modules(),
        )

        self._policy_optimizer = optim.Adam(
            [
                {
                    'params': self._model.modules()['pE'].parameters()
                },
                {
                    'params': self._model.modules()['pT'].parameters()
                },
                {
                    'params': self._model.modules()['pH'].parameters()
                },
            ],
            lr=self._learning_rate,
        )

        self._syn.broadcast({'config': self._config})

    def load(
        self,
        training=True,
    ):

        if self._load_dir:
            Log.out("Loading prooftrace models", {
                'load_dir': self._load_dir,
            })

            self._model.load()

            if training and os.path.isfile(
                    self._load_dir + "/policy_optimizer.pt"):
                self._policy_optimizer.load_state_dict(
                    torch.load(
                        self._load_dir + "/policy_optimizer.pt",
                        map_location=self._device,
                    ), )

        self._syn.broadcast({'config': self._config})

        return self

    def save(self, ):
        if self._save_dir:
            Log.out("Saving prooftrace models", {
                'save_dir': self._save_dir,
            })

            self._model.save()

            torch.save(
                self._policy_optimizer.state_dict(),
                self._save_dir + "/policy_optimizer.pt",
            )

    def update(self, ) -> None:
        update = self._config.update()
        if update:
            if 'prooftrace_lm_learning_rate' in update:
                lr = self._config.get('prooftrace_lm_learning_rate')
                if lr != self._learning_rate:
                    self._learning_rate = lr
                    for group in self._policy_optimizer.param_groups:
                        group['lr'] = lr
                    Log.out("Updated", {
                        "prooftrace_lm_learning_rate": lr,
                    })
            if 'prooftrace_lm_iota_min_update_count' in update:
                cnt = \
                    self._config.get('prooftrace_lm_iota_min_update_count')
                if cnt != self._min_update_count:
                    self._min_update_count = cnt
                    Log.out("Updated", {
                        "prooftrace_lm_iota_min_update_count": cnt,
                    })

            if self._tb_writer is not None:
                for k in update:
                    if k in [
                            'prooftrace_lm_learning_rate',
                            'prooftrace_lm_iota_min_update_count',
                            'prooftrace_lm_action_coeff',
                    ]:
                        self._tb_writer.add_scalar(
                            "prooftrace_lm_train_run/{}".format(k),
                            update[k],
                            self._epoch,
                        )

    def run_once(self, ):
        for m in self._model.modules():
            self._model.modules()[m].train()

        run_start = time.time()

        self._policy_optimizer.zero_grad()

        infos = self._syn.reduce(self._device, self._min_update_count)

        if len(infos) == 0:
            time.sleep(1)
            return

        self._policy_optimizer.step()

        self._syn.broadcast({'config': self._config})

        if self._last_update is not None:
            update_delta = time.time() - self._last_update
        else:
            update_delta = 0.0
        self._last_update = time.time()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        test_act_loss_meter = Meter()
        test_lft_loss_meter = Meter()
        test_rgt_loss_meter = Meter()

        for info in infos:
            if 'act_loss' in info:
                act_loss_meter.update(info['act_loss'])
            if 'lft_loss' in info:
                lft_loss_meter.update(info['lft_loss'])
            if 'rgt_loss' in info:
                rgt_loss_meter.update(info['rgt_loss'])
            if 'test_act_loss' in info:
                test_act_loss_meter.update(info['test_act_loss'])
            if 'test_lft_loss' in info:
                test_lft_loss_meter.update(info['test_lft_loss'])
            if 'test_rgt_loss' in info:
                test_rgt_loss_meter.update(info['test_rgt_loss'])

        Log.out(
            "PROOFTRACE SYN RUN", {
                'epoch': self._epoch,
                'run_time': "{:.2f}".format(time.time() - run_start),
                'update_count': len(infos),
                'update_delta': "{:.2f}".format(update_delta),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'lft_loss': "{:.4f}".format(lft_loss_meter.avg or 0.0),
                'rgt_loss': "{:.4f}".format(rgt_loss_meter.avg or 0.0),
                'test_act_loss': "{:.4f}".format(test_act_loss_meter.avg
                                                 or 0.0),
                'test_lft_loss': "{:.4f}".format(test_lft_loss_meter.avg
                                                 or 0.0),
                'test_rgt_loss': "{:.4f}".format(test_rgt_loss_meter.avg
                                                 or 0.0),
            })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/update_delta",
                update_delta,
                self._epoch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/update_count",
                len(infos),
                self._epoch,
            )
            if act_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/act_loss",
                    act_loss_meter.avg,
                    self._epoch,
                )
            if lft_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/lft_loss",
                    lft_loss_meter.avg,
                    self._epoch,
                )
            if rgt_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/rgt_loss",
                    rgt_loss_meter.avg,
                    self._epoch,
                )

            if test_act_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/act_loss",
                    test_act_loss_meter.avg,
                    self._epoch,
                )
            if test_lft_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/lft_loss",
                    test_lft_loss_meter.avg,
                    self._epoch,
                )
            if test_rgt_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/rgt_loss",
                    test_rgt_loss_meter.avg,
                    self._epoch,
                )

        self._epoch += 1

        if self._epoch % 100 == 0:
            self.save()
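SYN keeps a single Adam optimizer over one parameter group per module (pE, pT, pH) and, in update(), rewrites group['lr'] in place when the config changes. A minimal sketch of that pattern with stand-in nn.Linear modules (the module names are mirrored only for illustration):

import torch.nn as nn
import torch.optim as optim

# Stand-ins for the project's pE / pT / pH modules.
modules = {
    'pE': nn.Linear(8, 8),
    'pT': nn.Linear(8, 8),
    'pH': nn.Linear(8, 8),
}

# One parameter group per module, as in SYN.__init__ above.
optimizer = optim.Adam(
    [{'params': m.parameters()} for m in modules.values()],
    lr=1e-4,
)

# Later, when the config reports a new learning rate, every group is
# updated in place, exactly like SYN.update() above.
new_lr = 5e-5
for group in optimizer.param_groups:
    group['lr'] = new_lr

print([g['lr'] for g in optimizer.param_groups])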
Example 5
class ACK:
    def __init__(
        self,
        config: Config,
        train_dataset: ProofTraceLMDataset,
    ):
        self._config = config

        self._action_coeff = config.get('prooftrace_lm_action_coeff')
        self._grad_norm_max = config.get('prooftrace_lm_grad_norm_max')

        self._device = torch.device(config.get('device'))

        self._sequence_length = config.get('prooftrace_sequence_length')

        self._model = LModel(config)
        self._ack = IOTAAck(
            config.get('prooftrace_lm_iota_sync_dir'),
            self._model.modules(),
        )

        self._nll_loss = nn.NLLLoss()

        self._train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self._config.get('prooftrace_lm_batch_size'),
            shuffle=True,
            collate_fn=lm_collate,
        )

        Log.out('ACK initialization', {
            "batch_size": self._config.get('prooftrace_lm_batch_size'),
        })

        self._train_batch = 0

    def update(
        self,
        config: Config,
    ) -> None:
        self._config = config

        coeff = self._config.get('prooftrace_lm_action_coeff')
        if coeff != self._action_coeff:
            self._action_coeff = coeff
            Log.out("Updated", {
                "prooftrace_lm_action_coeff": coeff,
            })

    def run_once(
        self,
        epoch,
    ):
        for it, (act, arg, trh) in enumerate(self._train_loader):
            info = self._ack.fetch(self._device)
            if info is not None:
                self.update(info['config'])
            self._model.train()

            trh_actions, trh_lefts, trh_rights = trh_extract(trh, arg)

            # The pointer network can't be run over the full sequence length
            # (memory constraints), so we sample a subset of indices to focus
            # the loss on.
            idx = random.sample(range(self._sequence_length), 64)

            actions = torch.index_select(
                torch.tensor(trh_actions, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)
            lefts = torch.index_select(
                torch.tensor(trh_lefts, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)
            rights = torch.index_select(
                torch.tensor(trh_rights, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)

            prd_actions, prd_lefts, prd_rights = \
                self._model.infer(idx, act, arg)

            act_loss = self._nll_loss(
                prd_actions.view(-1, prd_actions.size(-1)),
                actions.view(-1),
            )
            lft_loss = self._nll_loss(
                prd_lefts.view(-1, prd_lefts.size(-1)),
                lefts.view(-1),
            )
            rgt_loss = self._nll_loss(
                prd_rights.view(-1, prd_rights.size(-1)),
                rights.view(-1),
            )

            # Backward pass.
            for m in self._model.modules():
                self._model.modules()[m].zero_grad()

            (self._action_coeff * act_loss + lft_loss + rgt_loss).backward()

            if self._grad_norm_max > 0.0:
                for m in self._model.modules():
                    torch.nn.utils.clip_grad_norm_(
                        self._model.modules()[m].parameters(),
                        self._grad_norm_max,
                    )

            info = {
                'act_loss': act_loss.item(),
                'lft_loss': lft_loss.item(),
                'rgt_loss': rgt_loss.item(),
            }

            self._ack.push(info, None)

            Log.out(
                "PROOFTRACE LM ACK RUN", {
                    'epoch': epoch,
                    'train_batch': self._train_batch,
                    'act_loss_avg': "{:.4f}".format(act_loss.item()),
                    'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                    'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                })

            self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
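Because the pointer heads are too expensive to evaluate over the whole sequence, run_once() samples 64 positions and computes the NLL losses only there. A self-contained sketch of that subsampling with torch.index_select, using random tensors in place of the model's real inputs and outputs:

import random

import torch
import torch.nn as nn

batch, seq_len, vocab, k = 2, 128, 32, 8

# Fake per-position targets and fake log-probabilities from a model head
# that was only evaluated at k positions (as LModel.infer(idx, ...) does).
targets = torch.randint(0, vocab, (batch, seq_len))
log_probs = torch.log_softmax(torch.randn(batch, k, vocab), dim=-1)

# Sample k positions and keep only the matching targets, mirroring the
# torch.index_select calls along dimension 1 above.
idx = torch.tensor(random.sample(range(seq_len), k), dtype=torch.int64)
sub_targets = torch.index_select(targets, 1, idx)

loss = nn.NLLLoss()(log_probs.view(-1, vocab), sub_targets.view(-1))
print(loss.item())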
Example 6
class TST:
    def __init__(
        self,
        config: Config,
        test_dataset: ProofTraceLMDataset,
    ):
        self._config = config

        self._device = torch.device(config.get('device'))

        self._sequence_length = config.get('prooftrace_sequence_length')

        self._model = LModel(config)
        self._ack = IOTAAck(
            config.get('prooftrace_lm_iota_sync_dir'),
            self._model.modules(),
        )

        self._nll_loss = nn.NLLLoss()

        self._test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self._config.get('prooftrace_lm_batch_size'),
            shuffle=True,
            collate_fn=lm_collate,
        )

        Log.out('TST initialization', {
            "batch_size": self._config.get('prooftrace_lm_batch_size'),
        })

        self._train_batch = 0

    def run_once(
        self,
        epoch,
    ):
        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()

        with torch.no_grad():
            for it, (act, arg, trh) in enumerate(self._test_loader):
                self._ack.fetch(self._device, blocking=False)
                self._model.eval()

                trh_actions, trh_lefts, trh_rights = trh_extract(trh, arg)

                # The pointer network can't be run over the full sequence
                # length (memory constraints), so we sample a subset of
                # indices to focus the loss on.
                idx = random.sample(range(self._sequence_length), 64)

                actions = torch.index_select(
                    torch.tensor(trh_actions, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)
                lefts = torch.index_select(
                    torch.tensor(trh_lefts, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)
                rights = torch.index_select(
                    torch.tensor(trh_rights, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)

                prd_actions, prd_lefts, prd_rights = \
                    self._model.infer(idx, act, arg)

                act_loss = self._nll_loss(
                    prd_actions.view(-1, prd_actions.size(-1)),
                    actions.view(-1),
                )
                lft_loss = self._nll_loss(
                    prd_lefts.view(-1, prd_lefts.size(-1)),
                    lefts.view(-1),
                )
                rgt_loss = self._nll_loss(
                    prd_rights.view(-1, prd_rights.size(-1)),
                    rights.view(-1),
                )

                act_loss_meter.update(act_loss.item())
                lft_loss_meter.update(lft_loss.item())
                rgt_loss_meter.update(rgt_loss.item())

                info = {
                    'test_act_loss': act_loss_meter.avg,
                    'test_lft_loss': lft_loss_meter.avg,
                    'test_rgt_loss': rgt_loss_meter.avg,
                }

                self._ack.push(info, None, True)

                Log.out(
                    "PROOFTRACE LM TST RUN", {
                        'epoch': epoch,
                        'act_loss_avg': "{:.4f}".format(act_loss.item()),
                        'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                    })

                self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
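The test loop above runs under torch.no_grad() with the modules in eval mode and accumulates losses in Meter objects. Meter is project-internal; the sketch below assumes it is a simple running average (update() plus an .avg attribute) and shows the same evaluation pattern with a stand-in model:

import torch
import torch.nn as nn


class AvgMeter:
    # Minimal running average, in the spirit of the project's Meter
    # (whose real interface is assumed here: update() and .avg).
    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def update(self, value):
        self._sum += value
        self._count += 1

    @property
    def avg(self):
        return self._sum / self._count if self._count else None


model = nn.Linear(4, 2)
loss_fn = nn.NLLLoss()
meter = AvgMeter()

model.eval()
with torch.no_grad():
    for _ in range(5):
        x = torch.randn(8, 4)
        y = torch.randint(0, 2, (8,))
        loss = loss_fn(torch.log_softmax(model(x), dim=-1), y)
        meter.update(loss.item())

print("avg test loss: {:.4f}".format(meter.avg))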
Example 7
File: mcts.py Project: spolu/z3ta
    def expand(
        self,
        beta_width: int,
        sequence_length: int,
        offset: int,
        l_model: LModel,
        v_model: VModel,
        target: Thm,
        step: int,
    ) -> typing.Tuple[float, bool, ProofTraceActions]:
        actions = self._ptra.actions().copy()
        arguments = self._ptra.arguments().copy()

        index = len(actions)

        empty = Action.from_action('EMPTY', None, None)
        while len(actions) < sequence_length:
            actions.append(empty)
        while len(arguments) < sequence_length:
            arguments.append(empty)

        with torch.no_grad():
            prd_actions, prd_lefts, prd_rights = \
                l_model.infer([index], [actions], [arguments])
            prd_values = \
                v_model.infer([index], [actions], [arguments])

        a_count = min(
            beta_width,
            len(PROOFTRACE_TOKENS) - len(PREPARE_TOKENS),
        )
        top_actions = torch.exp(prd_actions[0].cpu()).topk(a_count)
        top_lefts = torch.exp(prd_lefts[0].cpu()).topk(beta_width)
        top_rights = torch.exp(prd_rights[0].cpu()).topk(beta_width)

        value = prd_values[0].item() / self._ptra.action_len()

        candidates = []

        Log.out(
            "EXPAND",
            {
                'step': step,
                'value': "{:.3f}".format(value),
                'length': self._ptra.len(),
                'summary': self._ptra.summary(offset),
                # 'theorem': self._theorem.thm_string(True),
            })

        for ia in range(a_count):
            for il in range(beta_width):
                for ir in range(beta_width):

                    action = top_actions[1][ia].item()
                    left = top_lefts[1][il].item()
                    right = top_rights[1][ir].item()

                    if left >= self._ptra.len() or right >= self._ptra.len():
                        continue

                    a = Action.from_action(
                        INV_PROOFTRACE_TOKENS[action + len(PREPARE_TOKENS)],
                        self._ptra.arguments()[left],
                        self._ptra.arguments()[right],
                    )

                    if self._ptra.seen(a):
                        continue

                    if not self._repl.valid(a):
                        continue

                    candidates.append(
                        (top_actions[0][ia].item() * top_lefts[0][il].item() *
                         top_rights[0][ir].item(), a))

        candidates = sorted(candidates, key=lambda c: c[0], reverse=True)[0:8]

        for p, action in candidates:
            repl = self._repl.copy()
            ptra = self._ptra.copy()

            thm = repl.apply(action)
            action._index = thm.index()

            argument = ptra.build_argument(
                thm.concl(),
                thm.hyp(),
                thm.index(),
            )
            ptra.append(action, argument)

            if target.thm_string(True) == thm.thm_string(True):
                Log.out(
                    "DEMONSTRATED", {
                        'theorem': thm.thm_string(True),
                        'summary': ptra.summary(offset),
                    })
                return value, True, ptra

            self._children.append(Node(
                self,
                p,
                repl,
                ptra,
                thm,
            ))

        self._expanded = True

        return value, False, self._ptra
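expand() converts the three log-probability heads back to probabilities with torch.exp, takes the top-k of each, scores every (action, left, right) combination by the product of the three probabilities, and keeps only the best eight candidates. A small self-contained sketch of that scoring over random distributions:

import torch

beta_width = 3

# Fake log-probabilities for the action / left-pointer / right-pointer heads.
prd_actions = torch.log_softmax(torch.randn(10), dim=-1)
prd_lefts = torch.log_softmax(torch.randn(20), dim=-1)
prd_rights = torch.log_softmax(torch.randn(20), dim=-1)

top_actions = torch.exp(prd_actions).topk(beta_width)
top_lefts = torch.exp(prd_lefts).topk(beta_width)
top_rights = torch.exp(prd_rights).topk(beta_width)

candidates = []
for ia in range(beta_width):
    for il in range(beta_width):
        for ir in range(beta_width):
            # Score is the product of the three head probabilities;
            # index [0] holds the probabilities, [1] the token indices.
            p = (top_actions[0][ia] * top_lefts[0][il] *
                 top_rights[0][ir]).item()
            candidates.append((
                p,
                top_actions[1][ia].item(),
                top_lefts[1][il].item(),
                top_rights[1][ir].item(),
            ))

# Keep only the best few, as expand() keeps candidates[0:8].
candidates = sorted(candidates, key=lambda c: c[0], reverse=True)[:8]
print(candidates[0])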
Example 8
class WRK():
    def __init__(
        self,
        config: Config,
    ):
        self._config = config

        self._device = torch.device(config.get('device'))

        self._model = LModel(config)

        self._rollout_dir = os.path.join(
            os.path.expanduser(config.get('prooftrace_rollout_dir')),
            config.get('prooftrace_dataset_size'),
            'train_rollouts',
        )
        with gzip.open(
                os.path.join(
                    os.path.expanduser(config.get('prooftrace_dataset_dir')),
                    config.get('prooftrace_dataset_size'),
                    'traces.tokenizer',
                ), 'rb') as f:
            self._tokenizer = pickle.load(f)

        self._wrk = IOTAWrk(
            config.get('prooftrace_lm_iota_sync_dir'),
            'rollout',
            self._model.modules(),
        )

        self._type = config.get('prooftrace_search_type')

        Log.out('WRK initialization', {})

    def update(
        self,
        config: Config,
    ) -> None:
        self._config = config

        t = self._config.get('prooftrace_search_type')
        if t != self._type:
            self._type = t
            Log.out("Updated", {
                "prooftrace_search_type": t,
            })

    def run_once(self, ):
        info = self._wrk.fetch(self._device, False)
        if info is not None:
            self.update(info['config'])

        for m in self._model.modules():
            self._model.modules()[m].eval()

        assert os.path.isdir(self._rollout_dir)

        rdirs = [
            os.path.join(self._rollout_dir, d)
            for d in os.listdir(self._rollout_dir)
            if os.path.isdir(os.path.join(self._rollout_dir, d))
        ]

        rdir = random.choice(rdirs)
        rfiles = sorted(
            [
                os.path.join(rdir, f)
                for f in os.listdir(rdir)
                if re.search(r"\.rollout$", f)
            ],
            reverse=True,
        )

        if len(rfiles) == 0:
            return

        path = rfiles[0]
        with gzip.open(path, 'rb') as f:
            base = pickle.load(f)

        ground = base.positive()
        name = base.name()

        ptra = ProofTraceActions(
            'ROLLOUT-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(self._tokenizer)
        target = repl.prepare(ptra)

        search = None
        if self._config.get('prooftrace_search_type') == 'beam':
            search = Beam(
                self._config,
                self._model,
                ptra,
                repl,
                target,
            )
        if self._config.get('prooftrace_search_type') == 'policy_sample':
            search = PolicySample(
                self._config,
                self._model,
                ptra,
                repl,
                target,
            )
        assert search is not None

        depth = self._config.get('prooftrace_sequence_length') - \
            ground.prepare_len()

        if 2 * ground.action_len() < depth:
            depth = 2 * ground.action_len()

        Log.out(
            "ROLLOUT START", {
                'name': name,
                'prepare_length': ground.prepare_len(),
                'action_length': ground.action_len(),
                'depth': depth,
            })

        rollout = None
        proved = False
        ptra = None

        for i in range(depth):
            step_start = time.time()
            done, ptra, proved = search.step()
            step_end = time.time()
            Log.out(
                'STEP', {
                    'i': i,
                    'done': done,
                    'proved': proved,
                    'time': "{:.2f}".format(step_end - step_start),
                })
            if done:
                break
            if (step_end - step_start) > 20:
                # Hard-coded step timeout, in place of
                # self._config.get('prooftrace_search_step_timeout').
                break

        if proved:
            rollout = Rollout(name, [ptra], [])
        else:
            rollout = Rollout(name, [], [ptra])

        demo_length = ptra.action_len()
        demo_delta = ptra.action_len() - ground.action_len()

        Log.out(
            "ROLLOUT END", {
                'name': name,
                'proved': proved,
                'demo_length': demo_length,
                'demo_delta': demo_delta
            })

        if proved:
            Log.out("PTRA", {
                'name': name,
                'summary': ptra.summary(),
            })

        if demo_length > 0:
            info = {
                'rll_cnt': 1,
                'pos_cnt': 1 if proved else 0,
                'neg_cnt': 0 if proved else 1,
            }
            if proved:
                info['demo_len'] = demo_length
                info['demo_dlt'] = demo_delta

            # Publish the statistics.
            self._wrk.publish(info)

            # Finally merge and store the new rollout
            base.merge(rollout)

            now = datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f")
            rnd = random.randint(0, int(10e9))

            tmp_path = os.path.join(rdir, "{}_{}.tmp".format(now, rnd))
            fnl_path = os.path.join(rdir, "{}_{}.rollout".format(now, rnd))

            with gzip.open(tmp_path, 'wb') as f:
                pickle.dump(base, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.rename(tmp_path, fnl_path)

            del base
            del rollout

            if len(rfiles) > 1:
                for p in rfiles[1:]:
                    try:
                        os.remove(p)
                    except FileNotFoundError:
                        pass

            Log.out("MERGE WRITE", {
                'name': name,
                'path': fnl_path,
            })
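The merged rollout is written to a .tmp file first and then renamed to its final .rollout name, so a crash mid-write never leaves a truncated rollout behind. A minimal stdlib-only sketch of the same write-then-rename pattern (paths and the pickled object are placeholders):

import gzip
import os
import pickle
import tempfile

rollout = {'name': 'example', 'positive': [], 'negative': []}  # stand-in

rdir = tempfile.mkdtemp()
tmp_path = os.path.join(rdir, "20200101_0000_00.000000_42.tmp")
fnl_path = os.path.join(rdir, "20200101_0000_00.000000_42.rollout")

# Write the compressed pickle to the temporary name first...
with gzip.open(tmp_path, 'wb') as f:
    pickle.dump(rollout, f, protocol=pickle.HIGHEST_PROTOCOL)

# ...then rename it into place (atomic on the same filesystem), so readers
# only ever see complete .rollout files.
os.rename(tmp_path, fnl_path)

print(os.path.exists(fnl_path))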
Example 9
File: run.py Project: spolu/z3ta
def search():
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        'config_path',
        type=str,
        help="path to the config file",
    )
    parser.add_argument(
        '--dataset_size',
        type=str,
        help="config override",
    )
    parser.add_argument(
        '--load_dir',
        type=str,
        help="config override",
    )

    parser.add_argument(
        '--device',
        type=str,
        help="config override",
    )
    parser.add_argument(
        '--train',
        type=str2bool,
        help="search training set",
    )

    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.device is not None:
        config.override('device', args.device)

    if args.dataset_size is not None:
        config.override(
            'prooftrace_dataset_size',
            args.dataset_size,
        )
    if args.load_dir is not None:
        config.override(
            'prooftrace_load_dir',
            os.path.expanduser(args.load_dir),
        )

    train = False
    if args.train is not None:
        train = args.train

    if train:
        dataset_dir = os.path.join(
            os.path.expanduser(config.get('prooftrace_dataset_dir')),
            config.get('prooftrace_dataset_size'), 'train_traces')
    else:
        dataset_dir = os.path.join(
            os.path.expanduser(config.get('prooftrace_dataset_dir')),
            config.get('prooftrace_dataset_size'), 'test_traces')

    assert os.path.isdir(dataset_dir)
    files = [
        os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir)
        if os.path.isfile(os.path.join(dataset_dir, f))
    ]
    cases = []

    with gzip.open(
            os.path.join(
                os.path.expanduser(config.get('prooftrace_dataset_dir')),
                config.get('prooftrace_dataset_size'),
                'traces.tokenizer',
            ), 'rb') as f:
        tokenizer = pickle.load(f)

    for p in files:
        match = re.search("_(\\d+)_(\\d+)\\.actions$", p)
        if match is None:
            continue
        ptra_len = int(match.group(1))
        cases.append((p, ptra_len))

    Log.out("Loaded ProofTraceActions", {
        'cases': len(cases),
    })

    l_model = LModel(config).load()
    # v_model = VModel(config).load()

    cases = sorted(cases, key=lambda c: c[1])

    for i in range(len(cases)):
        c = cases[i][0]
        with gzip.open(c, 'rb') as f:
            ground = pickle.load(f)

        ptra = ProofTraceActions(
            'SEARCH-{}-{}'.format(
                datetime.datetime.now().strftime("%Y%m%d_%H%M_%S.%f"),
                random.randint(0, 9999),
            ),
            [
                ground.actions()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
            [
                ground.arguments()[i] for i in range(ground.len())
                if ground.actions()[i].value in INV_PREPARE_TOKENS
            ],
        )
        repl = REPL(tokenizer)
        target = repl.prepare(ptra)

        offset = 0
        fixed_gamma = config.get('prooftrace_search_fixed_gamma')
        if fixed_gamma > 0:
            gamma_len = max(ground.action_len() - fixed_gamma, 0)
            offset = ground.prepare_len() + gamma_len

            for i in range(gamma_len):
                assert ground.prepare_len() + i < ground.len() - 1
                pos = ground.prepare_len() + i

                action = ground.actions()[pos]
                argument = ground.arguments()[pos]

                thm = repl.apply(action)

                action._index = thm.index()
                argument._index = thm.index()

                ptra.append(action, argument)

        Log.out(
            "TARGET", {
                'name': ground.name(),
                'prepare_length': ground.prepare_len(),
                'action_length': ground.action_len(),
                'summary': ground.summary(offset),
                'theorem': target.thm_string(False, True),
            })

        search = None
        if config.get('prooftrace_search_type') == 'beam':
            search = Beam(config, l_model, ptra, repl, target)
        # if config.get('prooftrace_search_type') == 'mcts':
        #     search = MCTS(config, l_model, v_model, ptra, repl, target)
        # if config.get('prooftrace_search_type') == 'particle_filter':
        #     search = ParticleFilter(
        #         config, l_model, v_model, ptra, repl, target,
        #     )
        if config.get('prooftrace_search_type') == 'policy_sample':
            search = PolicySample(config, l_model, ptra, repl, target)
        assert search is not None

        depth = config.get('prooftrace_sequence_length') - \
            ground.prepare_len()

        if fixed_gamma != 0:
            if 2 * fixed_gamma < depth:
                depth = fixed_gamma * 2
        else:
            if 2 * ground.action_len() < depth:
                depth = 2 * ground.action_len()

        for i in range(depth):
            if fixed_gamma != 0:
                conclusion = (i >= fixed_gamma * 2)
            else:
                conclusion = (i >= ground.action_len())

            step_start = time.time()
            done, ptra, proved = search.step(offset, conclusion)
            step_end = time.time()

            Log.out(
                'STEP', {
                    'i': i,
                    'done': done,
                    'proved': proved,
                    'time': "{:.2f}".format(step_end - step_start),
                    'summary': ptra.summary(offset),
                })
            if done:
                if proved:
                    Log.out("DEMONSTRATED", {
                        'theorem': target.thm_string(False, True),
                    })
                break

            # if (step_end - step_start) > \
            #         config.get('prooftrace_search_step_timeout'):
            #     break

        Log.out("FINISH", {
            'summary': ptra.summary(offset),
        })
        if config.get('prooftrace_search_type') == 'random' \
                and search.last_thm() is not None:
            Log.out("GENERATED",
                    {'theorem': search.last_thm().thm_string(False, True)})
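The --train flag above is parsed with a str2bool helper imported elsewhere in the project. A common implementation of that idiom (this is an assumption about the helper, not the project's actual code), shown with a minimal parser:

import argparse


def str2bool(v):
    # Typical str2bool helper of the kind run.py appears to assume.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected')


parser = argparse.ArgumentParser(description="")
parser.add_argument('--train', type=str2bool, help="search training set")

print(parser.parse_args(['--train', 'true']).train)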