Example #1
    def _eval(self, model, data_loader):
        model.eval()
        meter = Meter(self.tasks)
        for i, batch_data in enumerate(data_loader):
            smiless, inputs, labels, masks = self._prepare_batch_data(
                batch_data)
            _, predictions = model(inputs)
            if self.norm:
                predictions = predictions * self.data_std + self.data_mean
            meter.update(predictions, labels, masks)
        eval_results_dict = meter.compute_metric(self.metrics)
        return eval_results_dict
Example #2
    def batch_test(self, ):
        assert self._test_loader is not None

        self._model.eval()
        loss_meter = Meter()

        hit = 0
        total = 0
        pos = 0
        neg = 0

        with torch.no_grad():
            for it, (cnj, thr, pre) in enumerate(self._test_loader):
                res = self._model(
                    cnj.to(self._device),
                    thr.to(self._device),
                )

                loss = F.binary_cross_entropy(res, pre.to(self._device))

                loss_meter.update(loss.item())

                limit = 0.50
                for i in range(res.size(0)):
                    if res[i].item() >= limit and pre[i].item() >= limit:
                        hit += 1
                    if res[i].item() < limit and pre[i].item() < limit:
                        hit += 1
                    if res[i].item() >= limit:
                        pos += 1
                    if res[i].item() < limit:
                        neg += 1
                    total += 1

        Log.out(
            "TH2VEC TEST", {
                'batch_count': self._train_batch,
                'loss_avg': loss_meter.avg,
                'hit_rate': "{:.3f}".format(hit / total),
                'pos_rate': "{:.3f}".format(pos / total),
            })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "test/th2vec/direct_premiser/loss",
                loss_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/direct_premiser/hit_rate",
                hit / total,
                self._train_batch,
            )
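
Every example below tracks running statistics with a `Meter` class whose definition is not included. Based solely on how it is called here (`update(value)`, plus `avg`, `sum`, and `max` attributes that callers guard with `or 0.0` / `is not None`), a minimal sketch of a compatible scalar meter might look like the following; this is an assumption, not the project's actual implementation, and it does not cover the multi-task `Meter(self.tasks)` variant seen in Example #1.

class Meter:
    # Minimal running-statistics tracker (assumed interface).
    def __init__(self):
        self.sum = 0.0
        self.avg = None
        self.max = None
        self._count = 0

    def update(self, value):
        self._count += 1
        self.sum += value
        self.max = value if self.max is None else max(self.max, value)
        self.avg = self.sum / self._count

Usage matches the calls above: `m = Meter(); m.update(0.3); m.update(0.5)` leaves `m.avg == 0.4`, `m.sum == 0.8`, and `m.max == 0.5`.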
Example #3
    def batch_train(
        self,
        epoch,
    ):
        assert self._train_loader is not None

        self._model.train()
        loss_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)
        # self._scheduler.step()

        for it, (cnj, thr, pre) in enumerate(self._train_loader):
            res = self._model(
                cnj.to(self._device),
                thr.to(self._device),
            )

            loss = F.binary_cross_entropy(res, pre.to(self._device))

            loss.backward()

            if it % self._accumulation_step_count == 0:
                self._optimizer.step()
                self._optimizer.zero_grad()

            loss_meter.update(loss.item())

            self._train_batch += 1

            if self._train_batch % 10 == 0:
                Log.out("TH2VEC TRAIN", {
                    'train_batch': self._train_batch,
                    'loss_avg': loss_meter.avg,
                })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "train/th2vec/direct_premiser/loss",
                        loss_meter.avg,
                        self._train_batch,
                    )

                loss_meter = Meter()

        Log.out(
            "EPOCH DONE",
            {
                'epoch': epoch,
                # 'learning_rate': self._scheduler.get_lr(),
            })
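
Examples #3 and #11 accumulate gradients over `self._accumulation_step_count` batches before calling `optimizer.step()`. A self-contained sketch of that pattern on a toy model (all names below are hypothetical, not from the examples), including the common variant that divides each loss by the number of accumulated steps so the summed gradient approximates a single large-batch update:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1)           # toy stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4                  # hypothetical value

optimizer.zero_grad()
for it in range(16):
    inputs = torch.randn(32, 8)
    targets = torch.rand(32, 1)
    loss = F.binary_cross_entropy(torch.sigmoid(model(inputs)), targets)
    # Scale so gradients summed over the accumulation window
    # approximate one update on the concatenated batch.
    (loss / accumulation_steps).backward()
    if (it + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()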
Example #4
    def run_once(self, ):
        run_start = time.time()

        infos = self._ctl.aggregate()

        if len(infos) == 0:
            time.sleep(10)
            return

        rll_cnt_meter = Meter()
        pos_cnt_meter = Meter()
        neg_cnt_meter = Meter()
        demo_len_meter = Meter()

        for info in infos:
            rll_cnt_meter.update(info['rll_cnt'])
            pos_cnt_meter.update(info['pos_cnt'])
            neg_cnt_meter.update(info['neg_cnt'])
            if 'demo_len' in info:
                demo_len_meter.update(info['demo_len'])

        Log.out(
            "PROOFTRACE BEAM ROLLOUT CTL RUN", {
                'epoch': self._epoch,
                'run_time': "{:.2f}".format(time.time() - run_start),
                'update_count': len(infos),
                'rll_cnt': "{:.4f}".format(rll_cnt_meter.sum or 0.0),
                'pos_cnt': "{:.4f}".format(pos_cnt_meter.avg or 0.0),
                'neg_cnt': "{:.4f}".format(neg_cnt_meter.avg or 0.0),
                'demo_len': "{:.4f}".format(demo_len_meter.avg or 0.0),
            })

        if self._tb_writer is not None:
            if rll_cnt_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_rollout/rll_cnt",
                    rll_cnt_meter.sum,
                    self._epoch,
                )
            if pos_cnt_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_rollout/pos_cnt",
                    pos_cnt_meter.avg,
                    self._epoch,
                )
            if neg_cnt_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_rollout/neg_cnt",
                    neg_cnt_meter.avg,
                    self._epoch,
                )
            if demo_len_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_rollout/demo_len",
                    demo_len_meter.avg,
                    self._epoch,
                )

        self._epoch += 1
Example #5
    def run_once(self, ):
        run_start = time.time()

        infos = self._ctl.aggregate()

        if len(infos) == 0:
            time.sleep(10)
            return

        demo_len_meter = Meter()
        gamma_meters = {}
        for gamma in GAMMAS:
            gamma_meters['gamma_{}'.format(gamma)] = Meter()

        for info in infos:
            demo_len_meter.update(info['demo_len'])
            for gamma in GAMMAS:
                key = 'gamma_{}'.format(gamma)
                gamma_meters[key].update(info[key])

        out = {
            'epoch': self._epoch,
            'run_time': "{:.2f}".format(time.time() - run_start),
            'update_count': len(infos),
            'demo_len': "{:.4f}".format(demo_len_meter.avg or 0.0),
        }
        for gamma in GAMMAS:
            key = 'gamma_{}'.format(gamma)
            out[key] = "{:.4f}".format(gamma_meters[key].avg or 0.0)
        Log.out("PROOFTRACE BEAM CTL RUN", out)

        if self._tb_writer is not None:
            if demo_len_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_ctl/demo_len",
                    demo_len_meter.avg,
                    self._epoch,
                )
            for gamma in GAMMAS:
                key = 'gamma_{}'.format(gamma)
                if gamma_meters[key].avg is not None:
                    self._tb_writer.add_scalar(
                        "prooftrace_search_ctl/{}".format(key),
                        gamma_meters[key].avg,
                        self._epoch,
                    )

        self._epoch += 1
Example #6
    def batch_test(self, ):
        assert self._test_loader is not None

        self._model.eval()
        loss_meter = Meter()

        hit = 0
        total = 0

        with torch.no_grad():
            for it, (cl_pos, cl_neg, sats) in enumerate(self._test_loader):
                generated = self._model(
                    cl_pos.to(self._device),
                    cl_neg.to(self._device),
                )

                loss = F.binary_cross_entropy(generated, sats.to(self._device))

                loss_meter.update(loss.item())

                for i in range(generated.size(0)):
                    if generated[i].item() >= 0.5 and sats[i].item() >= 0.5:
                        hit += 1
                    if generated[i].item() < 0.5 and sats[i].item() < 0.5:
                        hit += 1
                    total += 1

        Log.out(
            "SAT TEST", {
                'batch_count': self._train_batch,
                'loss_avg': loss_meter.avg,
                'hit_rate': "{:.2f}".format(hit / total),
            })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "test/sat/loss",
                loss_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/sat/hit_rate",
                hit / total,
                self._train_batch,
            )

        return loss_meter.avg
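
The per-element Python loop above (and the similar ones in Examples #2, #8, and #16) counts how often the thresholded prediction agrees with the label. As a sketch, the same hit rate can be computed with vectorized tensor comparisons; the tensors below are synthetic stand-ins, not the examples' `generated`/`sats` batches:

import torch

generated = torch.rand(128)              # synthetic predictions in [0, 1)
sats = (torch.rand(128) >= 0.5).float()  # synthetic binary labels

pred_pos = generated >= 0.5
true_pos = sats >= 0.5
hit = (pred_pos == true_pos).sum().item()  # agreements at the 0.5 threshold
total = generated.size(0)
hit_rate = hit / total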
Example #7
    def batch_test(self, ):
        assert self._test_loader is not None

        self._model.eval()

        rec_loss_meter = Meter()

        with torch.no_grad():
            for it, trm in enumerate(self._test_loader):
                mu, _ = self._inner_model.encode(trm.to(self._device))
                trm_rec = self._inner_model.decode(mu)

                rec_loss = self._loss(
                    trm_rec.view(-1, trm_rec.size(2)),
                    trm.to(self._device).view(-1),
                )

                rec_loss_meter.update(rec_loss.item())

                if it == 0:
                    trm_smp = self._inner_model.sample(trm_rec)

                    Log.out("<<<", {
                        'term':
                        self._kernel.detokenize(trm[0].data.numpy(), ),
                    })
                    Log.out(
                        ">>>", {
                            'term':
                            self._kernel.detokenize(
                                trm_smp[0].cpu().data.numpy(), ),
                        })

        Log.out("TH2VEC TEST", {
            'batch_count': self._train_batch,
            'loss_avg': rec_loss_meter.avg,
        })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "test/th2vec/autoencoder_embedder/rec_loss",
                rec_loss_meter.avg,
                self._train_batch,
            )

        rec_loss_meter = Meter()
Example #8
    def batch_test(
            self,
    ):
        assert self._test_loader is not None

        self._model.eval()
        loss_meter = Meter()

        hit = 0
        total = 0

        with torch.no_grad():
            for it, (cnj, thr, pre) in enumerate(self._test_loader):
                with torch.no_grad():
                    cnj_emd = self._embedder.encode(cnj.to(self._device))
                    thr_emd = self._embedder.encode(thr.to(self._device))

                res = self._model(cnj_emd.detach(), thr_emd.detach())

                loss = F.binary_cross_entropy(res, pre.to(self._device))

                loss_meter.update(loss.item())

                for i in range(res.size(0)):
                    if res[i].item() >= 0.5 and pre[i].item() >= 0.5:
                        hit += 1
                    if res[i].item() < 0.5 and pre[i].item() < 0.5:
                        hit += 1
                    total += 1

        Log.out("TH2VEC TEST", {
            'batch_count': self._train_batch,
            'loss_avg': loss_meter.avg,
            'hit_rate': "{:.3f}".format(hit / total),
        })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "test/th2vec/premiser/loss",
                loss_meter.avg, self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/premiser/hit_rate",
                hit / total, self._train_batch,
            )
Example #9
    def run_once(
        self,
        epoch,
    ):
        val_loss_meter = Meter()

        with torch.no_grad():
            for it, (idx, act, arg, trh, val) in enumerate(self._test_loader):
                self._ack.fetch(self._device, blocking=False)
                self._model.eval()

                prd_values = self._model.infer(idx, act, arg)
                values = torch.tensor(val, dtype=torch.float).unsqueeze(-1).to(
                    self._device)

                val_loss = self._mse_loss(prd_values, values)

                val_loss_meter.update(val_loss.item())

                info = {
                    'test_val_loss': val_loss_meter.avg,
                }

                self._ack.push(info, None, True)

                Log.out(
                    "PROOFTRACE V TST RUN", {
                        'epoch': epoch,
                        'val_loss_avg': "{:.4f}".format(val_loss.item()),
                    })

                self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
Example #10
    def test(self, ):
        assert self._test_loader is not None

        self._model_E.eval()
        self._model_H.eval()
        self._model_PH.eval()
        self._model_VH.eval()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        # val_loss_meter = Meter()

        test_batch = 0

        with torch.no_grad():
            for it, (idx, act, arg, trh, val) in enumerate(self._test_loader):
                action_embeds = self._model_E(act)
                argument_embeds = self._model_E(arg)

                hiddens = self._model_H(action_embeds, argument_embeds)

                heads = torch.cat(
                    [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                    dim=0)
                targets = torch.cat([
                    action_embeds[i][0].unsqueeze(0) for i in range(len(idx))
                ],
                                    dim=0)

                actions = torch.tensor([
                    trh[i].value - len(PREPARE_TOKENS) for i in range(len(trh))
                ],
                                       dtype=torch.int64).to(self._device)
                lefts = torch.tensor(
                    [arg[i].index(trh[i].left) for i in range(len(trh))],
                    dtype=torch.int64).to(self._device)
                rights = torch.tensor(
                    [arg[i].index(trh[i].right) for i in range(len(trh))],
                    dtype=torch.int64).to(self._device)
                # values = torch.tensor(val).unsqueeze(1).to(self._device)

                prd_actions, prd_lefts, prd_rights = \
                    self._model_PH(heads, hiddens, targets)
                # prd_values = self._model_VH(heads, targets)

                act_loss = self._nll_loss(prd_actions, actions)
                lft_loss = self._nll_loss(prd_lefts, lefts)
                rgt_loss = self._nll_loss(prd_rights, rights)
                # val_loss = self._mse_loss(prd_values, values)

                act_loss_meter.update(act_loss.item())
                lft_loss_meter.update(lft_loss.item())
                rgt_loss_meter.update(rgt_loss.item())
                # val_loss_meter.update(val_loss.item())

                Log.out(
                    "TEST BATCH",
                    {
                        'train_batch': self._train_batch,
                        'test_batch': test_batch,
                        'act_loss_avg': "{:.4f}".format(act_loss.item()),
                        'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                        # 'val_loss_avg': "{:.4f}".format(val_loss.item()),
                    })

                test_batch += 1

            Log.out(
                "PROOFTRACE TEST",
                {
                    'train_batch': self._train_batch,
                    'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                    'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                    'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
                    # 'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
                })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/act_loss",
                    act_loss_meter.avg,
                    self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/lft_loss",
                    lft_loss_meter.avg,
                    self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/rgt_loss",
                    rgt_loss_meter.avg,
                    self._train_batch,
                )
Example #11
    def batch_train(
        self,
        epoch,
    ):
        assert self._train_loader is not None

        self._model_E.train()
        self._model_H.train()
        self._model_PH.train()
        self._model_VH.train()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        # val_loss_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)

        for it, (idx, act, arg, trh, val) in enumerate(self._train_loader):
            action_embeds = self._model_E(act)
            argument_embeds = self._model_E(arg)

            # action_embeds = \
            #     torch.zeros(action_embeds.size()).to(self._device)
            # argument_embeds = \
            #     torch.zeros(argument_embeds.size()).to(self._device)

            hiddens = self._model_H(action_embeds, argument_embeds)

            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat(
                [action_embeds[i][0].unsqueeze(0) for i in range(len(idx))],
                dim=0)

            actions = torch.tensor(
                [trh[i].value - len(PREPARE_TOKENS) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            lefts = torch.tensor(
                [arg[i].index(trh[i].left) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            rights = torch.tensor(
                [arg[i].index(trh[i].right) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            # values = torch.tensor(val).unsqueeze(1).to(self._device)

            prd_actions, prd_lefts, prd_rights = \
                self._model_PH(heads, hiddens, targets)
            # prd_values = self._model_VH(heads, targets)

            act_loss = self._nll_loss(prd_actions, actions)
            lft_loss = self._nll_loss(prd_lefts, lefts)
            rgt_loss = self._nll_loss(prd_rights, rights)
            # val_loss = self._mse_loss(prd_values, values)

            # (act_loss + lft_loss + rgt_loss +
            #  self._value_coeff * val_loss).backward()
            (act_loss + lft_loss + rgt_loss).backward()

            if it % self._accumulation_step_count == 0:
                self._optimizer.step()
                self._optimizer.zero_grad()

            act_loss_meter.update(act_loss.item())
            lft_loss_meter.update(lft_loss.item())
            rgt_loss_meter.update(rgt_loss.item())
            # val_loss_meter.update(val_loss.item())

            Log.out(
                "TRAIN BATCH",
                {
                    'train_batch': self._train_batch,
                    'act_loss_avg': "{:.4f}".format(act_loss.item()),
                    'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                    'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                    # 'val_loss_avg': "{:.4f}".format(val_loss.item()),
                })

            if self._train_batch % 10 == 0 and self._train_batch != 0:
                Log.out(
                    "PROOFTRACE TRAIN",
                    {
                        'epoch': epoch,
                        'train_batch': self._train_batch,
                        'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                        'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
                        # 'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
                    })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/act_loss",
                        act_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/lft_loss",
                        lft_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "prooftrace_lm_train/rgt_loss",
                        rgt_loss_meter.avg,
                        self._train_batch,
                    )
                    # self._tb_writer.add_scalar(
                    #     "prooftrace_lm_train/val_loss",
                    #     val_loss_meter.avg, self._train_batch,
                    # )

                act_loss_meter = Meter()
                lft_loss_meter = Meter()
                rgt_loss_meter = Meter()
                # val_loss_meter = Meter()

            if self._train_batch % 1000 == 0:
                self.save()

                self.test()
                self._model_E.train()
                self._model_H.train()
                self._model_PH.train()
                self._model_VH.train()

                self.update()

            self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
Example #12
    def batch_train(
            self,
            epoch,
    ):
        assert self._train_loader is not None

        self._model_G.train()
        self._model_D.train()

        dis_loss_meter = Meter()
        gen_loss_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)

        for it, trm in enumerate(self._train_loader):
            nse = torch.randn(
                trm.size(0),
                self._config.get('th2vec_transformer_hidden_size'),
            ).to(self._device)

            trm_rel = trm.to(self._device)
            trm_gen = self._model_G(nse)

            m = Categorical(torch.exp(trm_gen))
            trm_smp = m.sample()

            # onh_rel = self._inner_model_D.one_hot(trm_rel)
            # onh_gen = self._inner_model_D.one_hot(trm_smp)
            # onh_gen.requires_grad = True

            dis_rel = self._model_D(trm_rel)
            dis_gen = self._model_D(trm_smp)

            # Label smoothing
            ones = torch.ones(*dis_rel.size()) - \
                torch.rand(*dis_rel.size()) / 10
            zeros = torch.zeros(*dis_rel.size()) + \
                torch.rand(*dis_rel.size()) / 10

            dis_loss_rel = \
                F.binary_cross_entropy(dis_rel, ones.to(self._device))
            dis_loss_gen = \
                F.binary_cross_entropy(dis_gen, zeros.to(self._device))
            dis_loss = dis_loss_rel + dis_loss_gen

            self._optimizer_D.zero_grad()
            dis_loss.backward()
            self._optimizer_D.step()

            # REINFORCE

            # reward = onh_gen.grad.detach()
            # reward -= \
            #     torch.mean(reward, 2).unsqueeze(-1).expand(*reward.size())
            # reward /= \
            #     torch.std(reward, 2).unsqueeze(-1).expand(*reward.size())

            # gen_loss = trm_gen * reward  # (-logp * -grad_dis)
            # gen_loss = gen_loss.mean()

            reward = dis_gen.squeeze(1).detach()
            reward -= torch.mean(reward)
            reward /= torch.std(reward)

            gen_loss = -m.log_prob(trm_smp).mean(1) * reward
            gen_loss = gen_loss.mean()

            self._optimizer_G.zero_grad()
            gen_loss.backward()
            self._optimizer_G.step()

            dis_loss_meter.update(dis_loss.item())
            gen_loss_meter.update(gen_loss.item())

            self._train_batch += 1

            if self._train_batch % 10 == 0:
                Log.out("TH2VEC GENERATOR TRAIN", {
                    'train_batch': self._train_batch,
                    'dis_loss_avg': dis_loss_meter.avg,
                    'gen_loss_avg': gen_loss_meter.avg,
                })

                Log.out("<<<", {
                    'term': self._kernel.detokenize(
                        trm_smp[0].cpu().data.numpy(),
                    ),
                })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "train/th2vec/generator/dis_loss",
                        dis_loss_meter.avg, self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/generator/gen_loss",
                        gen_loss_meter.avg, self._train_batch,
                    )

                dis_loss_meter = Meter()
                gen_loss_meter = Meter()

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
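
The generator update in Example #12 is a REINFORCE-style score-function estimate: the discriminator's standardized output on each sampled term acts as the reward, and the generator loss is the sample's negative log-probability weighted by that reward. This follows the standard estimator

    \nabla_\theta \, \mathbb{E}_{a \sim \pi_\theta}[R(a)] = \mathbb{E}_{a \sim \pi_\theta}\big[ R(a)\, \nabla_\theta \log \pi_\theta(a) \big]

so minimizing `-m.log_prob(trm_smp) * reward` moves the generator along that gradient in expectation; subtracting the mean reward is an unbiased baseline, and dividing by the standard deviation rescales the update for stability.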
Example #13
    def run_once(
            self,
            epoch: int,
    ):
        assert self._optimizer is not None

        for m in self._modules:
            self._modules[m].train()

        reward_meter = Meter()
        act_loss_meter = Meter()
        val_loss_meter = Meter()
        entropy_meter = Meter()

        for step in range(self._rollout_size):
            with torch.no_grad():
                obs = self._rollouts._observations[step]

                # hiddens = self._modules['CNN'](obs)
                hiddens = self._modules['MLP'](obs)
                prd_actions = self._modules['PH'](hiddens)
                values = self._modules['VH'](hiddens)

                m = Categorical(torch.exp(prd_actions))
                actions = m.sample().view(-1, 1)

                observations, rewards, dones, infos = self._envs.step(
                    actions.squeeze(1).cpu().numpy(),
                )

                # observations = torch.from_numpy(
                #     observations,
                # ).float().transpose(3, 1).to(self._device) / 255.0
                observations = torch.from_numpy(
                    observations,
                ).float().to(self._device)

                log_probs = prd_actions.gather(1, actions)

                for i, r in enumerate(rewards):
                    self._episode_rewards[i] += r / \
                        self._config.get("energy_reward_scaling")
                    if dones[i]:
                        reward_meter.update(self._episode_rewards[i])
                        self._episode_rewards[i] = 0.0

                self._rollouts.insert(
                    step,
                    observations,
                    actions.detach(),
                    log_probs.detach(),
                    values.detach(),
                    torch.tensor(
                        [
                            r / self._config.get("energy_reward_scaling")
                            for r in rewards
                        ], dtype=torch.float,
                    ).unsqueeze(1).to(self._device),
                    torch.tensor(
                        [[0.0] if d else [1.0] for d in dones],
                    ).to(self._device),
                )

        with torch.no_grad():
            obs = self._rollouts._observations[-1]

            # hiddens = self._modules['CNN'](obs)
            hiddens = self._modules['MLP'](obs)
            values = self._modules['VH'](hiddens)

            self._rollouts.compute_returns(values.detach())

            advantages = \
                self._rollouts._returns[:-1] - self._rollouts._values[:-1]
            advantages = \
                (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        for e in range(self._epoch_count):
            generator = self._rollouts.generator(advantages, self._batch_size)

            for batch in generator:
                rollout_observations, \
                    rollout_actions, \
                    rollout_values, \
                    rollout_returns, \
                    rollout_masks, \
                    rollout_log_probs, \
                    rollout_advantages = batch

                # hiddens = self._modules['CNN'](rollout_observations)
                hiddens = self._modules['MLP'](rollout_observations)
                prd_actions = self._modules['PH'](hiddens)
                values = self._modules['VH'](hiddens)

                log_probs = prd_actions.gather(1, rollout_actions)
                entropy = -(prd_actions * torch.exp(prd_actions)).mean()

                # Clipped action loss.
                ratio = torch.exp(log_probs - rollout_log_probs)
                action_loss = -torch.min(
                    ratio * rollout_advantages,
                    torch.clamp(ratio, 1.0 - self._clip, 1.0 + self._clip) *
                    rollout_advantages,
                ).mean()

                # value_loss = F.mse_loss(values, rollout_returns)
                value_loss = (rollout_returns.detach() - values).pow(2).mean()

                # Backward pass.
                self._optimizer.zero_grad()

                (
                    action_loss +
                    self._value_coeff * value_loss -
                    self._entropy_coeff * entropy
                ).backward()

                if self._grad_norm_max > 0.0:
                    torch.nn.utils.clip_grad_norm_(
                        # self._modules['CNN'].parameters(),
                        self._modules['MLP'].parameters(),
                        self._grad_norm_max,
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['PH'].parameters(),
                        self._grad_norm_max,
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['VH'].parameters(),
                        self._grad_norm_max,
                    )

                self._optimizer.step()

                act_loss_meter.update(action_loss.item())
                val_loss_meter.update(value_loss.item())
                entropy_meter.update(entropy.item())

        self._rollouts.after_update()

        if reward_meter.avg:
            if self._reward_tracker is None:
                self._reward_tracker = reward_meter.avg
            self._reward_tracker = \
                0.99 * self._reward_tracker + 0.01 * reward_meter.avg

            Log.out("ENERGY PPO RUN", {
                'epoch': epoch,
                'reward': "{:.4f}".format(reward_meter.avg or 0.0),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
                'entropy': "{:.4f}".format(entropy_meter.avg or 0.0),
                'tracker': "{:.4f}".format(self._reward_tracker),
            })
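
The action loss in Example #13 is the PPO clipped surrogate objective. With the probability ratio r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\mathrm{old}}(a_t \mid s_t) (the `ratio = torch.exp(log_probs - rollout_log_probs)` above) and normalized advantages \hat{A}_t, the objective is

    L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\big[ \min\big( r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t \big) \big]

which the code maximizes by minimizing its negative, adding `self._value_coeff` times the value loss and subtracting `self._entropy_coeff` times the entropy bonus.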
Example #14
    def run_once(
        self,
        epoch,
    ):
        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()

        with torch.no_grad():
            for it, (act, arg, trh) in enumerate(self._test_loader):
                self._ack.fetch(self._device, blocking=False)
                self._model.eval()

                trh_actions, trh_lefts, trh_rights = trh_extract(trh, arg)

                # Because we can't run a pointer network on the full length
                # (memory), we extract indices to focus loss on.
                idx = random.sample(range(self._sequence_length), 64)

                actions = torch.index_select(
                    torch.tensor(trh_actions, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)
                lefts = torch.index_select(
                    torch.tensor(trh_lefts, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)
                rights = torch.index_select(
                    torch.tensor(trh_rights, dtype=torch.int64),
                    1,
                    torch.tensor(idx, dtype=torch.int64),
                ).to(self._device)

                prd_actions, prd_lefts, prd_rights = \
                    self._model.infer(idx, act, arg)

                act_loss = self._nll_loss(
                    prd_actions.view(-1, prd_actions.size(-1)),
                    actions.view(-1),
                )
                lft_loss = self._nll_loss(
                    prd_lefts.view(-1, prd_lefts.size(-1)),
                    lefts.view(-1),
                )
                rgt_loss = self._nll_loss(
                    prd_rights.view(-1, prd_rights.size(-1)),
                    rights.view(-1),
                )

                act_loss_meter.update(act_loss.item())
                lft_loss_meter.update(lft_loss.item())
                rgt_loss_meter.update(rgt_loss.item())

                info = {
                    'test_act_loss': act_loss_meter.avg,
                    'test_lft_loss': lft_loss_meter.avg,
                    'test_rgt_loss': rgt_loss_meter.avg,
                }

                self._ack.push(info, None, True)

                Log.out(
                    "PROOFTRACE LM TST RUN", {
                        'epoch': epoch,
                        'act_loss_avg': "{:.4f}".format(act_loss.item()),
                        'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                        'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                    })

                self._train_batch += 1

        Log.out("EPOCH DONE", {
            'epoch': epoch,
        })
Example #15
    def batch_train(
        self,
        epoch,
    ):
        assert self._train_loader is not None

        self._model.train()

        rec_loss_meter = Meter()
        kld_loss_meter = Meter()
        all_loss_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)
        # self._scheduler.step()

        for it, trm in enumerate(self._train_loader):
            trm_rec, mu, logvar = self._model(trm.to(self._device))

            rec_loss = self._loss(
                trm_rec.view(-1, trm_rec.size(2)),
                trm.to(self._device).view(-1),
            )
            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            all_loss = rec_loss + 0.001 * kld_loss

            self._optimizer.zero_grad()
            all_loss.backward()
            self._optimizer.step()

            rec_loss_meter.update(rec_loss.item())
            kld_loss_meter.update(kld_loss.item())
            all_loss_meter.update(all_loss.item())

            self._train_batch += 1

            if self._train_batch % 10 == 0:
                Log.out(
                    "TH2VEC AUTOENCODER_EMBEDDER TRAIN", {
                        'train_batch': self._train_batch,
                        'rec_loss_avg': rec_loss_meter.avg,
                        'kld_loss_avg': kld_loss_meter.avg,
                        'loss_avg': all_loss_meter.avg,
                    })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "train/th2vec/autoencoder_embedder/rec_loss",
                        rec_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/autoencoder_embedder/kld_loss",
                        kld_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/autoencoder_embedder/all_loss",
                        all_loss_meter.avg,
                        self._train_batch,
                    )

                rec_loss_meter = Meter()
                kld_loss_meter = Meter()
                all_loss_meter = Meter()

        Log.out(
            "EPOCH DONE",
            {
                'epoch': epoch,
                # 'learning_rate': self._scheduler.get_lr(),
            })
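
The `kld_loss` in Example #15 is the closed-form KL divergence between the encoder's diagonal Gaussian posterior \mathcal{N}(\mu, \sigma^2) (with `logvar` standing for \log\sigma^2) and a standard normal prior:

    D_{\mathrm{KL}}\big( \mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I) \big) = -\tfrac{1}{2} \sum_j \big( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \big)

The 0.001 factor weights this regularizer against the reconstruction loss.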
Example #16
    def batch_train(
        self,
        epoch,
    ):
        assert self._train_loader is not None

        self._model.train()
        loss_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)
        # self._scheduler.step()

        for it, (cl_pos, cl_neg, sats) in enumerate(self._train_loader):
            generated = self._model(
                cl_pos.to(self._device),
                cl_neg.to(self._device),
            )

            loss = F.binary_cross_entropy(generated, sats.to(self._device))

            self._optimizer.zero_grad()
            loss.backward()
            self._optimizer.step()

            loss_meter.update(loss.item())

            self._train_batch += 1

            if self._train_batch % 10 == 0:

                hit = 0
                total = 0

                for i in range(generated.size(0)):
                    if generated[i].item() >= 0.5 and sats[i].item() >= 0.5:
                        hit += 1
                    if generated[i].item() < 0.5 and sats[i].item() < 0.5:
                        hit += 1
                    total += 1

                Log.out(
                    "SAT TRAIN", {
                        'train_batch': self._train_batch,
                        'loss_avg': loss_meter.avg,
                        'hit_rate': "{:.2f}".format(hit / total),
                    })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "train/sat/loss",
                        loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/sat/hit_rate",
                        hit / total,
                        self._train_batch,
                    )

                loss_meter = Meter()

        Log.out(
            "EPOCH DONE",
            {
                'epoch': epoch,
                # 'learning_rate': self._scheduler.get_lr(),
            })
Example #17
    def batch_test(self, ):
        assert self._test_loader is not None

        self._model.eval()

        all_loss_meter = Meter()
        thr_loss_meter = Meter()
        unr_loss_meter = Meter()
        thr_simi_meter = Meter()
        unr_simi_meter = Meter()

        with torch.no_grad():
            for it, (cnj, thr, unr) in enumerate(self._test_loader):
                cnj_emd = self._model(cnj.to(self._device))
                thr_emd = self._model(thr.to(self._device))
                unr_emd = self._model(unr.to(self._device))

                thr_loss = F.mse_loss(cnj_emd, thr_emd)
                unr_loss = F.mse_loss(cnj_emd, unr_emd)

                all_loss = thr_loss - unr_loss

                thr_simi = F.cosine_similarity(cnj_emd, thr_emd).mean()
                unr_simi = F.cosine_similarity(cnj_emd, unr_emd).mean()

                all_loss_meter.update(all_loss.item())
                thr_loss_meter.update(thr_loss.item())
                unr_loss_meter.update(unr_loss.item())
                thr_simi_meter.update(thr_simi.item())
                unr_simi_meter.update(unr_simi.item())

        Log.out("TH2VEC TEST", {
            'batch_count': self._train_batch,
            'loss_avg': all_loss_meter.avg,
        })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "test/th2vec/premise_embedder/all_loss",
                all_loss_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/premise_embedder/thr_loss",
                thr_loss_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/premise_embedder/unr_loss",
                unr_loss_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/premise_embedder/thr_simi",
                thr_simi_meter.avg,
                self._train_batch,
            )
            self._tb_writer.add_scalar(
                "test/th2vec/premise_embedder/unr_simi",
                unr_simi_meter.avg,
                self._train_batch,
            )

        all_loss_meter = Meter()
        thr_loss_meter = Meter()
        unr_loss_meter = Meter()
        thr_simi_meter = Meter()
        unr_simi_meter = Meter()
Example #18
    def batch_train(
        self,
        epoch,
    ):
        assert self._train_loader is not None

        self._model.train()

        all_loss_meter = Meter()
        thr_loss_meter = Meter()
        unr_loss_meter = Meter()
        thr_simi_meter = Meter()
        unr_simi_meter = Meter()

        if self._config.get('distributed_training'):
            self._train_sampler.set_epoch(epoch)
        self._scheduler.step()

        for it, (cnj, thr, unr) in enumerate(self._train_loader):
            cnj_emd = self._model(cnj.to(self._device))
            thr_emd = self._model(thr.to(self._device))
            unr_emd = self._model(unr.to(self._device))

            thr_loss = F.mse_loss(cnj_emd, thr_emd)
            unr_loss = F.mse_loss(cnj_emd, unr_emd)

            all_loss = thr_loss - unr_loss

            thr_simi = F.cosine_similarity(cnj_emd, thr_emd).mean()
            unr_simi = F.cosine_similarity(cnj_emd, unr_emd).mean()

            self._optimizer.zero_grad()
            all_loss.backward()
            self._optimizer.step()

            all_loss_meter.update(all_loss.item())
            thr_loss_meter.update(thr_loss.item())
            unr_loss_meter.update(unr_loss.item())
            thr_simi_meter.update(thr_simi.item())
            unr_simi_meter.update(unr_simi.item())

            self._train_batch += 1

            if self._train_batch % 10 == 0:
                Log.out(
                    "TH2VEC PREMISE_EMBEDDER TRAIN", {
                        'train_batch': self._train_batch,
                        'loss_avg': all_loss_meter.avg,
                    })

                if self._tb_writer is not None:
                    self._tb_writer.add_scalar(
                        "train/th2vec/premise_embedder/all_loss",
                        all_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/premise_embedder/thr_loss",
                        thr_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/premise_embedder/unr_loss",
                        unr_loss_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/premise_embedder/thr_simi",
                        thr_simi_meter.avg,
                        self._train_batch,
                    )
                    self._tb_writer.add_scalar(
                        "train/th2vec/premise_embedder/unr_simi",
                        unr_simi_meter.avg,
                        self._train_batch,
                    )

                all_loss_meter = Meter()
                thr_loss_meter = Meter()
                unr_loss_meter = Meter()
                thr_simi_meter = Meter()
                unr_simi_meter = Meter()

        Log.out("EPOCH DONE", {
            'epoch': epoch,
            'learning_rate': self._scheduler.get_lr(),
        })
Example #19
    def run_once(self, ):
        for m in self._modules:
            self._modules[m].train()

        run_start = time.time()

        self._optimizer.zero_grad()
        infos = self._syn.reduce(self._device, self._min_update_count)

        if len(infos) == 0:
            time.sleep(1)
            return

        self._optimizer.step()
        self._syn.broadcast({'config': self._config})

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        val_loss_meter = Meter()

        for info in infos:
            act_loss_meter.update(info['act_loss'])
            lft_loss_meter.update(info['lft_loss'])
            rgt_loss_meter.update(info['rgt_loss'])
            val_loss_meter.update(info['val_loss'])

        Log.out(
            "PROOFTRACE LM SYN RUN", {
                'epoch': self._epoch,
                'run_time': "{:.2f}".format(time.time() - run_start),
                'update_count': len(infos),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'lft_loss': "{:.4f}".format(lft_loss_meter.avg or 0.0),
                'rgt_loss': "{:.4f}".format(rgt_loss_meter.avg or 0.0),
                'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
            })

        if self._tb_writer is not None:
            if act_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/act_loss",
                    act_loss_meter.avg,
                    self._epoch,
                )
            if lft_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/lft_loss",
                    lft_loss_meter.avg,
                    self._epoch,
                )
            if rgt_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/rgt_loss",
                    rgt_loss_meter.avg,
                    self._epoch,
                )
            if val_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/val_loss",
                    val_loss_meter.avg,
                    self._epoch,
                )
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/update_count",
                len(infos),
                self._epoch,
            )

        self._epoch += 1

        if self._epoch % 100 == 0:
            self.save()
Example #20
    def run_once(self, ):
        for m in self._modules:
            self._modules[m].train()

        run_start = time.time()

        if self._epoch == 0:
            self._syn.broadcast({'config': self._config})

        self._optimizer.zero_grad()
        infos = self._syn.reduce(self._device, self._min_update_count)

        if len(infos) == 0:
            if self._epoch == 0:
                self._epoch += 1
            time.sleep(1)
            return

        self._optimizer.step()
        self._syn.broadcast({'config': self._config})

        if self._last_update is not None:
            update_delta = time.time() - self._last_update
        else:
            update_delta = 0.0
        self._last_update = time.time()

        frame_count_meter = Meter()
        match_count_meter = Meter()
        run_length_meter = Meter()
        demo_length_avg_meter = Meter()
        demo_length_max_meter = Meter()
        demo_delta_meter = Meter()
        stp_reward_meter = Meter()
        mtc_reward_meter = Meter()
        fnl_reward_meter = Meter()
        tot_reward_meter = Meter()
        act_loss_meter = Meter()
        val_loss_meter = Meter()
        act_entropy_meter = Meter()
        ptr_entropy_meter = Meter()

        for info in infos:
            frame_count_meter.update(info['frame_count'])
            if 'match_count' in info:
                match_count_meter.update(info['match_count'])
            if 'run_length' in info:
                run_length_meter.update(info['run_length'])
            if 'demo_length_avg' in info:
                demo_length_avg_meter.update(info['demo_length_avg'])
            if 'demo_length_max' in info:
                demo_length_max_meter.update(info['demo_length_max'])
            if 'demo_delta' in info:
                demo_delta_meter.update(info['demo_delta'])
            tot_reward = 0.0
            has_reward = False
            if 'stp_reward' in info:
                stp_reward_meter.update(info['stp_reward'])
                tot_reward += info['stp_reward']
                has_reward = True
            if 'mtc_reward' in info:
                mtc_reward_meter.update(info['mtc_reward'])
                tot_reward += info['mtc_reward']
                has_reward = True
            if 'fnl_reward' in info:
                fnl_reward_meter.update(info['fnl_reward'])
                tot_reward += info['fnl_reward']
                has_reward = True
            if has_reward:
                tot_reward_meter.update(tot_reward)
            act_loss_meter.update(info['act_loss'])
            val_loss_meter.update(info['val_loss'])
            act_entropy_meter.update(info['act_entropy'])
            ptr_entropy_meter.update(info['ptr_entropy'])

        Log.out(
            "PROOFTRACE PPO SYN RUN", {
                'epoch': self._epoch,
                'run_time': "{:.2f}".format(time.time() - run_start),
                'update_count': len(infos),
                'frame_count': frame_count_meter.sum,
                'update_delta': "{:.2f}".format(update_delta),
                'match_count': "{:.2f}".format(match_count_meter.avg or 0.0),
                'run_length': "{:.2f}".format(run_length_meter.avg or 0.0),
                'demo_length': "{:.2f}/{:.0f}".format(
                    demo_length_avg_meter.avg or 0.0,
                    demo_length_max_meter.max or 0.0,
                ),
                'demo_delta': "{:.4f}".format(demo_delta_meter.avg or 0.0),
                'stp_reward': "{:.4f}".format(stp_reward_meter.avg or 0.0),
                'mtc_reward': "{:.4f}".format(mtc_reward_meter.avg or 0.0),
                'fnl_reward': "{:.4f}".format(fnl_reward_meter.avg or 0.0),
                'tot_reward': "{:.4f}".format(tot_reward_meter.avg or 0.0),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
                'act_entropy': "{:.4f}".format(act_entropy_meter.avg or 0.0),
                'ptr_entropy': "{:.4f}".format(ptr_entropy_meter.avg or 0.0),
            })

        if self._tb_writer is not None:
            if len(infos) > 0:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/update_delta",
                    update_delta,
                    self._epoch,
                )
                if match_count_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/match_count",
                        match_count_meter.avg,
                        self._epoch,
                    )
                if run_length_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/run_length",
                        run_length_meter.avg,
                        self._epoch,
                    )
                if demo_length_avg_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/demo_length_avg",
                        demo_length_avg_meter.avg,
                        self._epoch,
                    )
                if demo_length_max_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/demo_length_max",
                        demo_length_max_meter.max,
                        self._epoch,
                    )
                if demo_delta_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/demo_delta",
                        demo_delta_meter.avg,
                        self._epoch,
                    )
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/act_loss",
                    act_loss_meter.avg,
                    self._epoch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/val_loss",
                    val_loss_meter.avg,
                    self._epoch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/act_entropy",
                    act_entropy_meter.avg,
                    self._epoch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/ptr_entropy",
                    ptr_entropy_meter.avg,
                    self._epoch,
                )
                if stp_reward_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/stp_reward",
                        stp_reward_meter.avg,
                        self._epoch,
                    )
                if mtc_reward_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/mtc_reward",
                        mtc_reward_meter.avg,
                        self._epoch,
                    )
                if fnl_reward_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/fnl_reward",
                        fnl_reward_meter.avg,
                        self._epoch,
                    )
                if tot_reward_meter.avg:
                    self._tb_writer.add_scalar(
                        "prooftrace_ppo_train/tot_reward",
                        tot_reward_meter.avg,
                        self._epoch,
                    )
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/frame_count",
                    frame_count_meter.sum,
                    self._epoch,
                )

        self._epoch += 1

        if self._epoch % 10 == 0:
            self.save()
Example #21
    def run_once(self, ):
        for m in self._model.modules():
            self._model.modules()[m].train()

        run_start = time.time()

        self._policy_optimizer.zero_grad()

        infos = self._syn.reduce(self._device, self._min_update_count)

        if len(infos) == 0:
            time.sleep(1)
            return

        self._policy_optimizer.step()

        self._syn.broadcast({'config': self._config})

        if self._last_update is not None:
            update_delta = time.time() - self._last_update
        else:
            update_delta = 0.0
        self._last_update = time.time()

        act_loss_meter = Meter()
        lft_loss_meter = Meter()
        rgt_loss_meter = Meter()
        test_act_loss_meter = Meter()
        test_lft_loss_meter = Meter()
        test_rgt_loss_meter = Meter()

        for info in infos:
            if 'act_loss' in info:
                act_loss_meter.update(info['act_loss'])
            if 'lft_loss' in info:
                lft_loss_meter.update(info['lft_loss'])
            if 'rgt_loss' in info:
                rgt_loss_meter.update(info['rgt_loss'])
            if 'test_act_loss' in info:
                test_act_loss_meter.update(info['test_act_loss'])
            if 'test_lft_loss' in info:
                test_lft_loss_meter.update(info['test_lft_loss'])
            if 'test_rgt_loss' in info:
                test_rgt_loss_meter.update(info['test_rgt_loss'])

        Log.out(
            "PROOFTRACE SYN RUN", {
                'epoch': self._epoch,
                'run_time': "{:.2f}".format(time.time() - run_start),
                'update_count': len(infos),
                'update_delta': "{:.2f}".format(update_delta),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'lft_loss': "{:.4f}".format(lft_loss_meter.avg or 0.0),
                'rgt_loss': "{:.4f}".format(rgt_loss_meter.avg or 0.0),
                'test_act_loss': "{:.4f}".format(test_act_loss_meter.avg
                                                 or 0.0),
                'test_lft_loss': "{:.4f}".format(test_lft_loss_meter.avg
                                                 or 0.0),
                'test_rgt_loss': "{:.4f}".format(test_rgt_loss_meter.avg
                                                 or 0.0),
            })

        if self._tb_writer is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/update_delta",
                update_delta,
                self._epoch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/update_count",
                len(infos),
                self._epoch,
            )
            if act_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/act_loss",
                    act_loss_meter.avg,
                    self._epoch,
                )
            if lft_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/lft_loss",
                    lft_loss_meter.avg,
                    self._epoch,
                )
            if rgt_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/rgt_loss",
                    rgt_loss_meter.avg,
                    self._epoch,
                )

            if test_act_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/act_loss",
                    test_act_loss_meter.avg,
                    self._epoch,
                )
            if test_lft_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/lft_loss",
                    test_lft_loss_meter.avg,
                    self._epoch,
                )
            if test_rgt_loss_meter.avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_test/rgt_loss",
                    test_rgt_loss_meter.avg,
                    self._epoch,
                )

        self._epoch += 1

        if self._epoch % 100 == 0:
            self.save()
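The per-key "if 'x' in info:" blocks above can also be collapsed into a generic aggregation over the optional keys reported by workers. Below is a sketch under the assumption that each worker info is a plain dict of floats; the aggregate helper and KEYS tuple are illustrative names, not repository code, and the averages mirror what the individual Meters compute.

# Sketch (not repository code): aggregate optional per-worker metrics.
from collections import defaultdict

KEYS = (
    'act_loss', 'lft_loss', 'rgt_loss',
    'test_act_loss', 'test_lft_loss', 'test_rgt_loss',
)

def aggregate(infos):
    # Collect every reported value per key, then average each key.
    values = defaultdict(list)
    for info in infos:
        for k in KEYS:
            if k in info:
                values[k].append(info[k])
    return {k: sum(v) / len(v) for k, v in values.items()}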
Example #22
0
    def run_once(
        self,
        epoch,
    ):
        for m in self._modules:
            self._modules[m].train()

        info = self._ack.fetch(self._device, epoch == 0)
        if info is not None:
            self.update(info['config'])

        stp_reward_meter = Meter()
        mtc_reward_meter = Meter()
        fnl_reward_meter = Meter()
        act_loss_meter = Meter()
        val_loss_meter = Meter()
        act_entropy_meter = Meter()
        ptr_entropy_meter = Meter()
        match_count_meter = Meter()
        run_length_meter = Meter()
        demo_length_meter = Meter()
        demo_delta_meter = Meter()

        frame_count = 0

        for step in range(self._rollout_size):
            with torch.no_grad():
                (idx, act, arg) = self._rollouts.observations[step]

                action_embeds = self._modules['E'](act).detach()
                argument_embeds = self._modules['E'](arg).detach()

                hiddens = self._modules['T'](action_embeds, argument_embeds)

                heads = torch.cat(
                    [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                    dim=0)
                targets = torch.cat(
                    [action_embeds[i][0].unsqueeze(0)
                     for i in range(len(idx))],
                    dim=0)

                prd_actions, prd_lefts, prd_rights = \
                    self._modules['PH'](heads, hiddens, targets)

                values = \
                    self._modules['VH'](heads, targets)

                actions, count = self._pool.explore(
                    prd_actions,
                    prd_lefts,
                    prd_rights,
                    self._explore_alpha,
                    self._explore_beta,
                    self._explore_beta_width,
                )
                frame_count += count

                observations, rewards, dones, infos = self._pool.step(
                    [tuple(a) for a in actions.detach().cpu().numpy()],
                    self._step_reward_prob,
                    self._match_reward_prob,
                    self._reset_gamma,
                    self._fixed_gamma,
                )
                frame_count += actions.size(0)
                for i, info in enumerate(infos):
                    if 'match_count' in info:
                        assert dones[i]
                        match_count_meter.update(info['match_count'])
                    if 'run_length' in info:
                        assert dones[i]
                        run_length_meter.update(info['run_length'])
                    if 'demo_length' in info:
                        assert dones[i]
                        demo_length_meter.update(info['demo_length'])
                    if 'demo_delta' in info:
                        assert dones[i]
                        demo_delta_meter.update(info['demo_delta'])

                log_probs = torch.cat((
                    prd_actions.gather(1, actions[:, 0].unsqueeze(1)),
                    prd_lefts.gather(1, actions[:, 1].unsqueeze(1)),
                    prd_rights.gather(1, actions[:, 2].unsqueeze(1)),
                ), dim=1)

            for i, r in enumerate(rewards):
                self._episode_stp_reward[i] += r[0]
                self._episode_mtc_reward[i] += r[1]
                self._episode_fnl_reward[i] += r[2]
                if dones[i]:
                    stp_reward_meter.update(self._episode_stp_reward[i])
                    mtc_reward_meter.update(self._episode_mtc_reward[i])
                    fnl_reward_meter.update(self._episode_fnl_reward[i])
                    self._episode_stp_reward[i] = 0.0
                    self._episode_mtc_reward[i] = 0.0
                    self._episode_fnl_reward[i] = 0.0

            self._rollouts.insert(
                step,
                observations,
                actions.detach(),
                log_probs.detach(),
                values.detach(),
                torch.tensor(
                    [(r[0] + r[1] + r[2]) for r in rewards],
                    dtype=torch.int64,
                ).unsqueeze(1).to(self._device),
                torch.tensor(
                    [[0.0] if d else [1.0] for d in dones],
                ).to(self._device),
            )

        with torch.no_grad():
            (idx, act, arg) = self._rollouts.observations[-1]

            action_embeds = self._modules['E'](act)
            argument_embeds = self._modules['E'](arg)

            hiddens = self._modules['T'](action_embeds, argument_embeds)

            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat(
                [action_embeds[i][0].unsqueeze(0) for i in range(len(idx))],
                dim=0)

            values = \
                self._modules['VH'](heads, targets)

            self._rollouts.compute_returns(values.detach())

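            # Advantage = return - value estimate, normalized to zero mean
            # and unit variance for stable updates.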
            advantages = \
                self._rollouts.returns[:-1] - self._rollouts.values[:-1]
            advantages = \
                (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        ignored = False
        for e in range(self._epoch_count):
            if ignored:
                continue

            generator = self._rollouts.generator(advantages)

            for batch in generator:
                if ignored:
                    continue

                rollout_observations, \
                    rollout_actions, \
                    rollout_values, \
                    rollout_returns, \
                    rollout_masks, \
                    rollout_log_probs, \
                    rollout_advantages = batch

                (idx, act, arg) = rollout_observations

                action_embeds = self._modules['E'](act)
                argument_embeds = self._modules['E'](arg)

                hiddens = self._modules['T'](action_embeds, argument_embeds)

                heads = torch.cat(
                    [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                    dim=0)
                targets = torch.cat(
                    [action_embeds[i][0].unsqueeze(0)
                     for i in range(len(idx))],
                    dim=0)

                prd_actions, prd_lefts, prd_rights = \
                    self._modules['PH'](heads, hiddens, targets)

                values = \
                    self._modules['VH'](heads, targets)

                log_probs = torch.cat((
                    prd_actions.gather(1, rollout_actions[:, 0].unsqueeze(1)),
                    prd_lefts.gather(1, rollout_actions[:, 1].unsqueeze(1)),
                    prd_rights.gather(1, rollout_actions[:, 2].unsqueeze(1)),
                ), dim=1)

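                # Entropy bonuses: prd_* hold log-probabilities, so p * log(p)
                # is computed here as exp(log_p) * log_p.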
                act_entropy = -((prd_actions * torch.exp(prd_actions)).mean())
                ptr_entropy = -((prd_lefts * torch.exp(prd_lefts)).mean() +
                                (prd_rights * torch.exp(prd_rights)).mean())

                # Clipped action loss.
                ratio = torch.exp(log_probs - rollout_log_probs)
                action_loss = -torch.min(
                    ratio * rollout_advantages,
                    torch.clamp(ratio, 1.0 - self._clip, 1.0 + self._clip) *
                    rollout_advantages,
                ).mean()

                # Clipped value loss.
                # clipped_values = rollout_values + \
                #     (values - rollout_values).clamp(-self._clip, self._clip)
                # value_loss = torch.max(
                #     F.mse_loss(values, rollout_returns),
                #     F.mse_loss(clipped_values, rollout_returns),
                # )
                value_loss = F.mse_loss(values, rollout_returns)

                # Log.out("RATIO/ADV/LOSS", {
                #     'clipped_ratio': torch.clamp(
                #         ratio, 1.0 - self._clip, 1.0 + self._clip
                #     ).mean().item(),
                #     'ratio': ratio.mean().item(),
                #     'advantages': rollout_advantages.mean().item(),
                #     'action_loss': action_loss.item(),
                # })

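                # Stability guard: if a loss explodes or turns NaN, skip the
                # remaining PPO updates for this rollout.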
                if (abs(action_loss.item()) > 10e2
                        or abs(value_loss.item()) > 10e5
                        or math.isnan(value_loss.item())
                        or math.isnan(act_entropy.item())
                        or math.isnan(ptr_entropy.item())):
                    Log.out(
                        "IGNORING", {
                            'epoch': epoch,
                            'act_loss': "{:.4f}".format(action_loss.item()),
                            'val_loss': "{:.4f}".format(value_loss.item()),
                            'act_entropy': "{:.4f}".format(act_entropy.item()),
                            'ptr_entropy': "{:.4f}".format(ptr_entropy.item()),
                        })
                    ignored = True
                else:
                    # Backward pass.
                    for m in self._modules:
                        self._modules[m].zero_grad()

                    (action_loss + self._value_coeff * value_loss -
                     (self._act_entropy_coeff * act_entropy +
                      self._ptr_entropy_coeff * ptr_entropy)).backward()

                    if self._grad_norm_max > 0.0:
                        torch.nn.utils.clip_grad_norm_(
                            self._modules['E'].parameters(),
                            self._grad_norm_max,
                        )
                        torch.nn.utils.clip_grad_norm_(
                            self._modules['T'].parameters(),
                            self._grad_norm_max,
                        )
                        torch.nn.utils.clip_grad_norm_(
                            self._modules['VH'].parameters(),
                            self._grad_norm_max,
                        )
                        torch.nn.utils.clip_grad_norm_(
                            self._modules['PH'].parameters(),
                            self._grad_norm_max,
                        )

                    act_loss_meter.update(action_loss.item())
                    val_loss_meter.update(value_loss.item())
                    act_entropy_meter.update(act_entropy.item())
                    ptr_entropy_meter.update(ptr_entropy.item())

                    info = {
                        'frame_count': frame_count,
                        'act_loss': act_loss_meter.avg,
                        'val_loss': val_loss_meter.avg,
                        'act_entropy': act_entropy_meter.avg,
                        'ptr_entropy': ptr_entropy_meter.avg,
                    }
                    if match_count_meter.avg:
                        info['match_count'] = match_count_meter.avg
                    if run_length_meter.avg:
                        info['run_length'] = run_length_meter.avg
                    if demo_length_meter.avg:
                        info['demo_length_avg'] = demo_length_meter.avg
                    if demo_length_meter.max:
                        info['demo_length_max'] = demo_length_meter.max
                    if demo_delta_meter.avg:
                        info['demo_delta'] = demo_delta_meter.avg
                    if stp_reward_meter.avg:
                        info['stp_reward'] = stp_reward_meter.avg
                    if mtc_reward_meter.avg:
                        info['mtc_reward'] = mtc_reward_meter.avg
                    if fnl_reward_meter.avg:
                        info['fnl_reward'] = fnl_reward_meter.avg

                    self._ack.push(info, None)
                    if frame_count > 0:
                        frame_count = 0

                info = self._ack.fetch(self._device)
                if info is not None:
                    self.update(info['config'])

        self._rollouts.after_update()

        Log.out(
            "PROOFTRACE PPO ACK RUN", {
                'epoch': epoch,
                'ignored': ignored,
                'match_count': "{:.2f}".format(match_count_meter.avg or 0.0),
                'run_length': "{:.2f}".format(run_length_meter.avg or 0.0),
                'demo_length': "{:.2f}/{:.0f}".format(
                    demo_length_meter.avg or 0.0,
                    demo_length_meter.max or 0.0,
                ),
                'demo_delta': "{:.4f}".format(demo_delta_meter.avg or 0.0),
                'stp_reward': "{:.4f}".format(stp_reward_meter.avg or 0.0),
                'mtc_reward': "{:.4f}".format(mtc_reward_meter.avg or 0.0),
                'fnl_reward': "{:.4f}".format(fnl_reward_meter.avg or 0.0),
                'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
                'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
                'act_entropy': "{:.4f}".format(act_entropy_meter.avg or 0.0),
                'ptr_entropy': "{:.4f}".format(ptr_entropy_meter.avg or 0.0),
            })
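For reference, the clipped objective assembled in Example #22 (probability ratio against the rollout log-probabilities, clipped surrogate action loss, MSE value loss, entropy bonus) can be reproduced standalone. The sketch below uses synthetic tensors; the names and coefficient values are illustrative and are not taken from the trainer's configuration.

# Standalone sketch of the PPO clipped objective, with synthetic tensors.
import torch
import torch.nn.functional as F

clip = 0.2
value_coeff = 0.5
entropy_coeff = 0.01

log_probs = torch.randn(8, 3).log_softmax(dim=1)           # new log-probabilities
old_log_probs = log_probs.detach() + 0.05 * torch.randn(8, 3)
advantages = torch.randn(8, 1)                              # normalized advantages
values = torch.randn(8, 1)
returns = torch.randn(8, 1)

# Clipped surrogate action loss.
ratio = torch.exp(log_probs - old_log_probs)
action_loss = -torch.min(
    ratio * advantages,
    torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages,
).mean()

# Value regression loss and entropy bonus.
value_loss = F.mse_loss(values, returns)
entropy = -(log_probs * torch.exp(log_probs)).mean()

loss = action_loss + value_coeff * value_loss - entropy_coeff * entropy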