def _eval(self, model, data_loader):
    model.eval()
    meter = Meter(self.tasks)
    for i, batch_data in enumerate(data_loader):
        smiless, inputs, labels, masks = self._prepare_batch_data(
            batch_data)
        _, predictions = model(inputs)
        if self.norm:
            predictions = predictions * self.data_std + self.data_mean
        meter.update(predictions, labels, masks)
    eval_results_dict = meter.compute_metric(self.metrics)
    return eval_results_dict
def batch_test(self):
    assert self._test_loader is not None

    self._model.eval()

    loss_meter = Meter()
    hit = 0
    total = 0
    pos = 0
    neg = 0

    with torch.no_grad():
        for it, (cnj, thr, pre) in enumerate(self._test_loader):
            res = self._model(
                cnj.to(self._device),
                thr.to(self._device),
            )

            loss = F.binary_cross_entropy(res, pre.to(self._device))
            loss_meter.update(loss.item())

            limit = 0.50
            for i in range(res.size(0)):
                if res[i].item() >= limit and pre[i].item() >= limit:
                    hit += 1
                if res[i].item() < limit and pre[i].item() < limit:
                    hit += 1
                if res[i].item() >= limit:
                    pos += 1
                if res[i].item() < limit:
                    neg += 1
                total += 1

    Log.out("TH2VEC TEST", {
        'batch_count': self._train_batch,
        'loss_avg': loss_meter.avg,
        'hit_rate': "{:.3f}".format(hit / total),
        'pos_rate': "{:.3f}".format(pos / total),
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "test/th2vec/direct_premiser/loss",
            loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/direct_premiser/hit_rate",
            hit / total, self._train_batch,
        )
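# NOTE: The trainers below rely on a no-argument scalar `Meter` (update/avg/
# sum/max) and on a `Log.out(tag, dict)` helper, neither of which is defined
# in this section. The following is a minimal sketch of the assumed
# semantics, not the repository's actual implementation; the task-metric
# `Meter(self.tasks)` used in `_eval` above is a different class and is not
# covered by this sketch.
class Meter:
    # Running statistics over scalar updates; avg/sum/max stay None until
    # the first update(), which is why callers guard with `avg or 0.0` or
    # `if meter.avg is not None`.
    def __init__(self):
        self.count = 0
        self.sum = None
        self.avg = None
        self.max = None

    def update(self, value):
        self.count += 1
        self.sum = value if self.sum is None else self.sum + value
        self.avg = self.sum / self.count
        self.max = value if self.max is None else max(self.max, value)


class Log:
    # Minimal stand-in for the logging helper: prints a tag followed by
    # key=value pairs on a single line.
    @staticmethod
    def out(tag, info):
        print(tag + " " + " ".join(
            "{}={}".format(k, v) for k, v in info.items()
        ))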
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model.train()
    loss_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)
    # self._scheduler.step()

    for it, (cnj, thr, pre) in enumerate(self._train_loader):
        res = self._model(
            cnj.to(self._device),
            thr.to(self._device),
        )

        loss = F.binary_cross_entropy(res, pre.to(self._device))
        loss.backward()

        if it % self._accumulation_step_count == 0:
            self._optimizer.step()
            self._optimizer.zero_grad()

        loss_meter.update(loss.item())
        self._train_batch += 1

        if self._train_batch % 10 == 0:
            Log.out("TH2VEC TRAIN", {
                'train_batch': self._train_batch,
                'loss_avg': loss_meter.avg,
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "train/th2vec/direct_premiser/loss",
                    loss_meter.avg, self._train_batch,
                )

            loss_meter = Meter()

    Log.out("EPOCH DONE", {
        'epoch': epoch,
        # 'learning_rate': self._scheduler.get_lr(),
    })
def run_once(self):
    run_start = time.time()

    infos = self._ctl.aggregate()

    if len(infos) == 0:
        time.sleep(10)
        return

    rll_cnt_meter = Meter()
    pos_cnt_meter = Meter()
    neg_cnt_meter = Meter()
    demo_len_meter = Meter()

    for info in infos:
        rll_cnt_meter.update(info['rll_cnt'])
        pos_cnt_meter.update(info['pos_cnt'])
        neg_cnt_meter.update(info['neg_cnt'])
        if 'demo_len' in info:
            demo_len_meter.update(info['demo_len'])

    Log.out("PROOFTRACE BEAM ROLLOUT CTL RUN", {
        'epoch': self._epoch,
        'run_time': "{:.2f}".format(time.time() - run_start),
        'update_count': len(infos),
        'rll_cnt': "{:.4f}".format(rll_cnt_meter.sum or 0.0),
        'pos_cnt': "{:.4f}".format(pos_cnt_meter.avg or 0.0),
        'neg_cnt': "{:.4f}".format(neg_cnt_meter.avg or 0.0),
        'demo_len': "{:.4f}".format(demo_len_meter.avg or 0.0),
    })

    if self._tb_writer is not None:
        if rll_cnt_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_search_rollout/rll_cnt",
                rll_cnt_meter.sum, self._epoch,
            )
        if pos_cnt_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_search_rollout/pos_cnt",
                pos_cnt_meter.avg, self._epoch,
            )
        if neg_cnt_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_search_rollout/neg_cnt",
                neg_cnt_meter.avg, self._epoch,
            )
        if demo_len_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_search_rollout/demo_len",
                demo_len_meter.avg, self._epoch,
            )

    self._epoch += 1
def run_once(self):
    run_start = time.time()

    infos = self._ctl.aggregate()

    if len(infos) == 0:
        time.sleep(10)
        return

    demo_len_meter = Meter()
    gamma_meters = {}
    for gamma in GAMMAS:
        gamma_meters['gamma_{}'.format(gamma)] = Meter()

    for info in infos:
        demo_len_meter.update(info['demo_len'])
        for gamma in GAMMAS:
            key = 'gamma_{}'.format(gamma)
            gamma_meters[key].update(info[key])

    out = {
        'epoch': self._epoch,
        'run_time': "{:.2f}".format(time.time() - run_start),
        'update_count': len(infos),
        'demo_len': "{:.4f}".format(demo_len_meter.avg or 0.0),
    }
    for gamma in GAMMAS:
        key = 'gamma_{}'.format(gamma)
        out[key] = "{:.4f}".format(gamma_meters[key].avg or 0.0)

    Log.out("PROOFTRACE BEAM CTL RUN", out)

    if self._tb_writer is not None:
        if demo_len_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_search_ctl/demo_len",
                demo_len_meter.avg, self._epoch,
            )
        for gamma in GAMMAS:
            key = 'gamma_{}'.format(gamma)
            if gamma_meters[key].avg is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_search_ctl/{}".format(key),
                    gamma_meters[key].avg, self._epoch,
                )

    self._epoch += 1
def batch_test(self):
    assert self._test_loader is not None

    self._model.eval()

    loss_meter = Meter()
    hit = 0
    total = 0

    with torch.no_grad():
        for it, (cl_pos, cl_neg, sats) in enumerate(self._test_loader):
            generated = self._model(
                cl_pos.to(self._device),
                cl_neg.to(self._device),
            )

            loss = F.binary_cross_entropy(generated, sats.to(self._device))
            loss_meter.update(loss.item())

            for i in range(generated.size(0)):
                if generated[i].item() >= 0.5 and sats[i].item() >= 0.5:
                    hit += 1
                if generated[i].item() < 0.5 and sats[i].item() < 0.5:
                    hit += 1
                total += 1

    Log.out("SAT TEST", {
        'batch_count': self._train_batch,
        'loss_avg': loss_meter.avg,
        'hit_rate': "{:.2f}".format(hit / total),
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "test/sat/loss",
            loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/sat/hit_rate",
            hit / total, self._train_batch,
        )

    return loss_meter.avg
def batch_test(self):
    assert self._test_loader is not None

    self._model.eval()
    rec_loss_meter = Meter()

    with torch.no_grad():
        for it, trm in enumerate(self._test_loader):
            mu, _ = self._inner_model.encode(trm.to(self._device))
            trm_rec = self._inner_model.decode(mu)

            rec_loss = self._loss(
                trm_rec.view(-1, trm_rec.size(2)),
                trm.to(self._device).view(-1),
            )
            rec_loss_meter.update(rec_loss.item())

            if it == 0:
                trm_smp = self._inner_model.sample(trm_rec)
                Log.out("<<<", {
                    'term': self._kernel.detokenize(trm[0].data.numpy()),
                })
                Log.out(">>>", {
                    'term': self._kernel.detokenize(
                        trm_smp[0].cpu().data.numpy(),
                    ),
                })

    Log.out("TH2VEC TEST", {
        'batch_count': self._train_batch,
        'loss_avg': rec_loss_meter.avg,
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "test/th2vec/autoencoder_embedder/rec_loss",
            rec_loss_meter.avg, self._train_batch,
        )

    rec_loss_meter = Meter()
def batch_test(
        self,
):
    assert self._test_loader is not None

    self._model.eval()

    loss_meter = Meter()
    hit = 0
    total = 0

    with torch.no_grad():
        for it, (cnj, thr, pre) in enumerate(self._test_loader):
            with torch.no_grad():
                cnj_emd = self._embedder.encode(cnj.to(self._device))
                thr_emd = self._embedder.encode(thr.to(self._device))

            res = self._model(cnj_emd.detach(), thr_emd.detach())

            loss = F.binary_cross_entropy(res, pre.to(self._device))
            loss_meter.update(loss.item())

            for i in range(res.size(0)):
                if res[i].item() >= 0.5 and pre[i].item() >= 0.5:
                    hit += 1
                if res[i].item() < 0.5 and pre[i].item() < 0.5:
                    hit += 1
                total += 1

    Log.out("TH2VEC TEST", {
        'batch_count': self._train_batch,
        'loss_avg': loss_meter.avg,
        'hit_rate': "{:.3f}".format(hit / total),
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "test/th2vec/premiser/loss",
            loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/premiser/hit_rate",
            hit / total, self._train_batch,
        )
def run_once(
        self,
        epoch,
):
    val_loss_meter = Meter()

    with torch.no_grad():
        for it, (idx, act, arg, trh, val) in enumerate(self._test_loader):
            self._ack.fetch(self._device, blocking=False)
            self._model.eval()

            prd_values = self._model.infer(idx, act, arg)

            values = torch.tensor(val, dtype=torch.float).unsqueeze(-1).to(
                self._device)

            val_loss = self._mse_loss(prd_values, values)
            val_loss_meter.update(val_loss.item())

            info = {
                'test_val_loss': val_loss_meter.avg,
            }
            self._ack.push(info, None, True)

            Log.out("PROOFTRACE V TST RUN", {
                'epoch': epoch,
                'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
            })

            self._train_batch += 1

    Log.out("EPOCH DONE", {
        'epoch': epoch,
    })
def test(self):
    assert self._test_loader is not None

    self._model_E.eval()
    self._model_H.eval()
    self._model_PH.eval()
    self._model_VH.eval()

    act_loss_meter = Meter()
    lft_loss_meter = Meter()
    rgt_loss_meter = Meter()
    # val_loss_meter = Meter()

    test_batch = 0

    with torch.no_grad():
        for it, (idx, act, arg, trh, val) in enumerate(self._test_loader):
            action_embeds = self._model_E(act)
            argument_embeds = self._model_E(arg)

            hiddens = self._model_H(action_embeds, argument_embeds)

            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat([
                action_embeds[i][0].unsqueeze(0) for i in range(len(idx))
            ], dim=0)

            actions = torch.tensor([
                trh[i].value - len(PREPARE_TOKENS) for i in range(len(trh))
            ], dtype=torch.int64).to(self._device)
            lefts = torch.tensor(
                [arg[i].index(trh[i].left) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            rights = torch.tensor(
                [arg[i].index(trh[i].right) for i in range(len(trh))],
                dtype=torch.int64).to(self._device)
            # values = torch.tensor(val).unsqueeze(1).to(self._device)

            prd_actions, prd_lefts, prd_rights = \
                self._model_PH(heads, hiddens, targets)
            # prd_values = self._model_VH(heads, targets)

            act_loss = self._nll_loss(prd_actions, actions)
            lft_loss = self._nll_loss(prd_lefts, lefts)
            rgt_loss = self._nll_loss(prd_rights, rights)
            # val_loss = self._mse_loss(prd_values, values)

            act_loss_meter.update(act_loss.item())
            lft_loss_meter.update(lft_loss.item())
            rgt_loss_meter.update(rgt_loss.item())
            # val_loss_meter.update(val_loss.item())

            Log.out("TEST BATCH", {
                'train_batch': self._train_batch,
                'test_batch': test_batch,
                'act_loss_avg': "{:.4f}".format(act_loss.item()),
                'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
                'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
                # 'val_loss_avg': "{:.4f}".format(val_loss.item()),
            })

            test_batch += 1

    Log.out("PROOFTRACE TEST", {
        'train_batch': self._train_batch,
        'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
        'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
        'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
        # 'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "prooftrace_lm_test/act_loss",
            act_loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "prooftrace_lm_test/lft_loss",
            lft_loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "prooftrace_lm_test/rgt_loss",
            rgt_loss_meter.avg, self._train_batch,
        )
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model_E.train()
    self._model_H.train()
    self._model_PH.train()
    self._model_VH.train()

    act_loss_meter = Meter()
    lft_loss_meter = Meter()
    rgt_loss_meter = Meter()
    # val_loss_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)

    for it, (idx, act, arg, trh, val) in enumerate(self._train_loader):
        action_embeds = self._model_E(act)
        argument_embeds = self._model_E(arg)

        # action_embeds = \
        #     torch.zeros(action_embeds.size()).to(self._device)
        # argument_embeds = \
        #     torch.zeros(argument_embeds.size()).to(self._device)

        hiddens = self._model_H(action_embeds, argument_embeds)

        heads = torch.cat(
            [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
            dim=0)
        targets = torch.cat(
            [action_embeds[i][0].unsqueeze(0) for i in range(len(idx))],
            dim=0)

        actions = torch.tensor(
            [trh[i].value - len(PREPARE_TOKENS) for i in range(len(trh))],
            dtype=torch.int64).to(self._device)
        lefts = torch.tensor(
            [arg[i].index(trh[i].left) for i in range(len(trh))],
            dtype=torch.int64).to(self._device)
        rights = torch.tensor(
            [arg[i].index(trh[i].right) for i in range(len(trh))],
            dtype=torch.int64).to(self._device)
        # values = torch.tensor(val).unsqueeze(1).to(self._device)

        prd_actions, prd_lefts, prd_rights = \
            self._model_PH(heads, hiddens, targets)
        # prd_values = self._model_VH(heads, targets)

        act_loss = self._nll_loss(prd_actions, actions)
        lft_loss = self._nll_loss(prd_lefts, lefts)
        rgt_loss = self._nll_loss(prd_rights, rights)
        # val_loss = self._mse_loss(prd_values, values)

        # (act_loss + lft_loss + rgt_loss +
        #  self._value_coeff * val_loss).backward()
        (act_loss + lft_loss + rgt_loss).backward()

        if it % self._accumulation_step_count == 0:
            self._optimizer.step()
            self._optimizer.zero_grad()

        act_loss_meter.update(act_loss.item())
        lft_loss_meter.update(lft_loss.item())
        rgt_loss_meter.update(rgt_loss.item())
        # val_loss_meter.update(val_loss.item())

        Log.out("TRAIN BATCH", {
            'train_batch': self._train_batch,
            'act_loss_avg': "{:.4f}".format(act_loss.item()),
            'lft_loss_avg': "{:.4f}".format(lft_loss.item()),
            'rgt_loss_avg': "{:.4f}".format(rgt_loss.item()),
            # 'val_loss_avg': "{:.4f}".format(val_loss.item()),
        })

        if self._train_batch % 10 == 0 and self._train_batch != 0:
            Log.out("PROOFTRACE TRAIN", {
                'epoch': epoch,
                'train_batch': self._train_batch,
                'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
                # 'val_loss_avg': "{:.4f}".format(val_loss_meter.avg),
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/act_loss",
                    act_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/lft_loss",
                    lft_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "prooftrace_lm_train/rgt_loss",
                    rgt_loss_meter.avg, self._train_batch,
                )
                # self._tb_writer.add_scalar(
                #     "prooftrace_lm_train/val_loss",
                #     val_loss_meter.avg, self._train_batch,
                # )

            act_loss_meter = Meter()
            lft_loss_meter = Meter()
            rgt_loss_meter = Meter()
            # val_loss_meter = Meter()

        if self._train_batch % 1000 == 0:
            self.save()
            self.test()
            self._model_E.train()
            self._model_H.train()
            self._model_PH.train()
            self._model_VH.train()

        self.update()
        self._train_batch += 1

    Log.out("EPOCH DONE", {
        'epoch': epoch,
    })
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model_G.train()
    self._model_D.train()

    dis_loss_meter = Meter()
    gen_loss_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)

    for it, trm in enumerate(self._train_loader):
        nse = torch.randn(
            trm.size(0),
            self._config.get('th2vec_transformer_hidden_size'),
        ).to(self._device)

        trm_rel = trm.to(self._device)
        trm_gen = self._model_G(nse)

        m = Categorical(torch.exp(trm_gen))
        trm_smp = m.sample()

        # onh_rel = self._inner_model_D.one_hot(trm_rel)
        # onh_gen = self._inner_model_D.one_hot(trm_smp)
        # onh_gen.requires_grad = True

        dis_rel = self._model_D(trm_rel)
        dis_gen = self._model_D(trm_smp)

        # Label smoothing
        ones = torch.ones(*dis_rel.size()) - \
            torch.rand(*dis_rel.size()) / 10
        zeros = torch.zeros(*dis_rel.size()) + \
            torch.rand(*dis_rel.size()) / 10

        dis_loss_rel = \
            F.binary_cross_entropy(dis_rel, ones.to(self._device))
        dis_loss_gen = \
            F.binary_cross_entropy(dis_gen, zeros.to(self._device))
        dis_loss = dis_loss_rel + dis_loss_gen

        self._optimizer_D.zero_grad()
        dis_loss.backward()
        self._optimizer_D.step()

        # REINFORCE
        # reward = onh_gen.grad.detach()
        # reward -= \
        #     torch.mean(reward, 2).unsqueeze(-1).expand(*reward.size())
        # reward /= \
        #     torch.std(reward, 2).unsqueeze(-1).expand(*reward.size())
        # gen_loss = trm_gen * reward  # (-logp * -grad_dis)
        # gen_loss = gen_loss.mean()

        reward = dis_gen.squeeze(1).detach()
        reward -= torch.mean(reward)
        reward /= torch.std(reward)
        gen_loss = -m.log_prob(trm_smp).mean(1) * reward
        gen_loss = gen_loss.mean()

        self._optimizer_G.zero_grad()
        gen_loss.backward()
        self._optimizer_G.step()

        dis_loss_meter.update(dis_loss.item())
        gen_loss_meter.update(gen_loss.item())

        self._train_batch += 1

        if self._train_batch % 10 == 0:
            Log.out("TH2VEC GENERATOR TRAIN", {
                'train_batch': self._train_batch,
                'dis_loss_avg': dis_loss_meter.avg,
                'gen_loss_avg': gen_loss_meter.avg,
            })
            Log.out("<<<", {
                'term': self._kernel.detokenize(
                    trm_smp[0].cpu().data.numpy(),
                ),
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "train/th2vec/generator/dis_loss",
                    dis_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/generator/gen_loss",
                    gen_loss_meter.avg, self._train_batch,
                )

            dis_loss_meter = Meter()
            gen_loss_meter = Meter()

    Log.out("EPOCH DONE", {
        'epoch': epoch,
    })
def run_once(
        self,
        epoch: int,
):
    assert self._optimizer is not None

    for m in self._modules:
        self._modules[m].train()

    reward_meter = Meter()
    act_loss_meter = Meter()
    val_loss_meter = Meter()
    entropy_meter = Meter()

    for step in range(self._rollout_size):
        with torch.no_grad():
            obs = self._rollouts._observations[step]

            # hiddens = self._modules['CNN'](obs)
            hiddens = self._modules['MLP'](obs)
            prd_actions = self._modules['PH'](hiddens)
            values = self._modules['VH'](hiddens)

        m = Categorical(torch.exp(prd_actions))
        actions = m.sample().view(-1, 1)

        observations, rewards, dones, infos = self._envs.step(
            actions.squeeze(1).cpu().numpy(),
        )

        # observations = torch.from_numpy(
        #     observations,
        # ).float().transpose(3, 1).to(self._device) / 255.0
        observations = torch.from_numpy(
            observations,
        ).float().to(self._device)

        log_probs = prd_actions.gather(1, actions)

        for i, r in enumerate(rewards):
            self._episode_rewards[i] += r / \
                self._config.get("energy_reward_scaling")
            if dones[i]:
                reward_meter.update(self._episode_rewards[i])
                self._episode_rewards[i] = 0.0

        self._rollouts.insert(
            step,
            observations,
            actions.detach(),
            log_probs.detach(),
            values.detach(),
            torch.tensor(
                [
                    r / self._config.get("energy_reward_scaling")
                    for r in rewards
                ],
                dtype=torch.float,
            ).unsqueeze(1).to(self._device),
            torch.tensor(
                [[0.0] if d else [1.0] for d in dones],
            ).to(self._device),
        )

    with torch.no_grad():
        obs = self._rollouts._observations[-1]

        # hiddens = self._modules['CNN'](obs)
        hiddens = self._modules['MLP'](obs)
        values = self._modules['VH'](hiddens)

    self._rollouts.compute_returns(values.detach())

    advantages = \
        self._rollouts._returns[:-1] - self._rollouts._values[:-1]
    advantages = \
        (advantages - advantages.mean()) / (advantages.std() + 1e-5)

    for e in range(self._epoch_count):
        generator = self._rollouts.generator(advantages, self._batch_size)

        for batch in generator:
            rollout_observations, \
                rollout_actions, \
                rollout_values, \
                rollout_returns, \
                rollout_masks, \
                rollout_log_probs, \
                rollout_advantages = batch

            # hiddens = self._modules['CNN'](rollout_observations)
            hiddens = self._modules['MLP'](rollout_observations)
            prd_actions = self._modules['PH'](hiddens)
            values = self._modules['VH'](hiddens)

            log_probs = prd_actions.gather(1, rollout_actions)
            entropy = -(prd_actions * torch.exp(prd_actions)).mean()

            # Clipped action loss.
            ratio = torch.exp(log_probs - rollout_log_probs)
            action_loss = -torch.min(
                ratio * rollout_advantages,
                torch.clamp(ratio, 1.0 - self._clip, 1.0 + self._clip) *
                rollout_advantages,
            ).mean()

            # value_loss = F.mse_loss(values, rollout_returns)
            value_loss = (rollout_returns.detach() - values).pow(2).mean()

            # Backward pass.
            self._optimizer.zero_grad()

            (
                action_loss +
                self._value_coeff * value_loss -
                self._entropy_coeff * entropy
            ).backward()

            if self._grad_norm_max > 0.0:
                torch.nn.utils.clip_grad_norm_(
                    # self._modules['CNN'].parameters(),
                    self._modules['MLP'].parameters(),
                    self._grad_norm_max,
                )
                torch.nn.utils.clip_grad_norm_(
                    self._modules['PH'].parameters(),
                    self._grad_norm_max,
                )
                torch.nn.utils.clip_grad_norm_(
                    self._modules['VH'].parameters(),
                    self._grad_norm_max,
                )

            self._optimizer.step()

            act_loss_meter.update(action_loss.item())
            val_loss_meter.update(value_loss.item())
            entropy_meter.update(entropy.item())

    self._rollouts.after_update()

    if reward_meter.avg:
        if self._reward_tracker is None:
            self._reward_tracker = reward_meter.avg
        self._reward_tracker = \
            0.99 * self._reward_tracker + 0.01 * reward_meter.avg

    Log.out("ENERGY PPO RUN", {
        'epoch': epoch,
        'reward': "{:.4f}".format(reward_meter.avg or 0.0),
        'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
        'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
        'entropy': "{:.4f}".format(entropy_meter.avg or 0.0),
        'tracker': "{:.4f}".format(self._reward_tracker or 0.0),
    })
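# NOTE: Both PPO loops in this section (the energy trainer above and the
# prooftrace ACK worker below) use the standard clipped surrogate objective.
# The helper below is an illustrative, self-contained sketch with dummy
# tensors (names are not from the repository); it mirrors the `ratio` /
# `torch.clamp` / `torch.min` computation used above.
import torch


def ppo_clipped_loss(log_probs, old_log_probs, advantages, clip=0.2):
    # Probability ratio between the current policy and the rollout policy.
    ratio = torch.exp(log_probs - old_log_probs)
    # Pessimistic minimum of the unclipped and clipped surrogate terms.
    return -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages,
    ).mean()


if __name__ == "__main__":
    # Dummy usage: a batch of 8 scalar log-probabilities and advantages.
    lp = torch.randn(8, 1)
    old_lp = lp.detach() + 0.05 * torch.randn(8, 1)
    adv = torch.randn(8, 1)
    print(ppo_clipped_loss(lp, old_lp, adv).item())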
def run_once(
        self,
        epoch,
):
    act_loss_meter = Meter()
    lft_loss_meter = Meter()
    rgt_loss_meter = Meter()

    with torch.no_grad():
        for it, (act, arg, trh) in enumerate(self._test_loader):
            self._ack.fetch(self._device, blocking=False)
            self._model.eval()

            trh_actions, trh_lefts, trh_rights = trh_extract(trh, arg)

            # Because we can't run a pointer network on the full length
            # (memory), we extract indices to focus loss on.
            idx = random.sample(range(self._sequence_length), 64)

            actions = torch.index_select(
                torch.tensor(trh_actions, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)
            lefts = torch.index_select(
                torch.tensor(trh_lefts, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)
            rights = torch.index_select(
                torch.tensor(trh_rights, dtype=torch.int64),
                1,
                torch.tensor(idx, dtype=torch.int64),
            ).to(self._device)

            prd_actions, prd_lefts, prd_rights = \
                self._model.infer(idx, act, arg)

            act_loss = self._nll_loss(
                prd_actions.view(-1, prd_actions.size(-1)),
                actions.view(-1),
            )
            lft_loss = self._nll_loss(
                prd_lefts.view(-1, prd_lefts.size(-1)),
                lefts.view(-1),
            )
            rgt_loss = self._nll_loss(
                prd_rights.view(-1, prd_rights.size(-1)),
                rights.view(-1),
            )

            act_loss_meter.update(act_loss.item())
            lft_loss_meter.update(lft_loss.item())
            rgt_loss_meter.update(rgt_loss.item())

            info = {
                'test_act_loss': act_loss_meter.avg,
                'test_lft_loss': lft_loss_meter.avg,
                'test_rgt_loss': rgt_loss_meter.avg,
            }
            self._ack.push(info, None, True)

            Log.out("PROOFTRACE LM TST RUN", {
                'epoch': epoch,
                'act_loss_avg': "{:.4f}".format(act_loss_meter.avg),
                'lft_loss_avg': "{:.4f}".format(lft_loss_meter.avg),
                'rgt_loss_avg': "{:.4f}".format(rgt_loss_meter.avg),
            })

            self._train_batch += 1

    Log.out("EPOCH DONE", {
        'epoch': epoch,
    })
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model.train()

    rec_loss_meter = Meter()
    kld_loss_meter = Meter()
    all_loss_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)
    # self._scheduler.step()

    for it, trm in enumerate(self._train_loader):
        trm_rec, mu, logvar = self._model(trm.to(self._device))

        rec_loss = self._loss(
            trm_rec.view(-1, trm_rec.size(2)),
            trm.to(self._device).view(-1),
        )
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        all_loss = rec_loss + 0.001 * kld_loss

        self._optimizer.zero_grad()
        all_loss.backward()
        self._optimizer.step()

        rec_loss_meter.update(rec_loss.item())
        kld_loss_meter.update(kld_loss.item())
        all_loss_meter.update(all_loss.item())

        self._train_batch += 1

        if self._train_batch % 10 == 0:
            Log.out("TH2VEC AUTOENCODER_EMBEDDER TRAIN", {
                'train_batch': self._train_batch,
                'rec_loss_avg': rec_loss_meter.avg,
                'kld_loss_avg': kld_loss_meter.avg,
                'loss_avg': all_loss_meter.avg,
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "train/th2vec/autoencoder_embedder/rec_loss",
                    rec_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/autoencoder_embedder/kld_loss",
                    kld_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/autoencoder_embedder/all_loss",
                    all_loss_meter.avg, self._train_batch,
                )

            rec_loss_meter = Meter()
            kld_loss_meter = Meter()
            all_loss_meter = Meter()

    Log.out("EPOCH DONE", {
        'epoch': epoch,
        # 'learning_rate': self._scheduler.get_lr(),
    })
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model.train()
    loss_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)
    # self._scheduler.step()

    for it, (cl_pos, cl_neg, sats) in enumerate(self._train_loader):
        generated = self._model(
            cl_pos.to(self._device),
            cl_neg.to(self._device),
        )

        loss = F.binary_cross_entropy(generated, sats.to(self._device))

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        loss_meter.update(loss.item())
        self._train_batch += 1

        if self._train_batch % 10 == 0:
            hit = 0
            total = 0
            for i in range(generated.size(0)):
                if generated[i].item() >= 0.5 and sats[i].item() >= 0.5:
                    hit += 1
                if generated[i].item() < 0.5 and sats[i].item() < 0.5:
                    hit += 1
                total += 1

            Log.out("SAT TRAIN", {
                'train_batch': self._train_batch,
                'loss_avg': loss_meter.avg,
                'hit_rate': "{:.2f}".format(hit / total),
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "train/sat/loss",
                    loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/sat/hit_rate",
                    hit / total, self._train_batch,
                )

            loss_meter = Meter()

    Log.out("EPOCH DONE", {
        'epoch': epoch,
        # 'learning_rate': self._scheduler.get_lr(),
    })
def batch_test(self):
    assert self._test_loader is not None

    self._model.eval()

    all_loss_meter = Meter()
    thr_loss_meter = Meter()
    unr_loss_meter = Meter()
    thr_simi_meter = Meter()
    unr_simi_meter = Meter()

    with torch.no_grad():
        for it, (cnj, thr, unr) in enumerate(self._test_loader):
            cnj_emd = self._model(cnj.to(self._device))
            thr_emd = self._model(thr.to(self._device))
            unr_emd = self._model(unr.to(self._device))

            thr_loss = F.mse_loss(cnj_emd, thr_emd)
            unr_loss = F.mse_loss(cnj_emd, unr_emd)
            all_loss = thr_loss - unr_loss

            thr_simi = F.cosine_similarity(cnj_emd, thr_emd).mean()
            unr_simi = F.cosine_similarity(cnj_emd, unr_emd).mean()

            all_loss_meter.update(all_loss.item())
            thr_loss_meter.update(thr_loss.item())
            unr_loss_meter.update(unr_loss.item())
            thr_simi_meter.update(thr_simi.item())
            unr_simi_meter.update(unr_simi.item())

    Log.out("TH2VEC TEST", {
        'batch_count': self._train_batch,
        'loss_avg': all_loss_meter.avg,
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "test/th2vec/premise_embedder/all_loss",
            all_loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/premise_embedder/thr_loss",
            thr_loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/premise_embedder/unr_loss",
            unr_loss_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/premise_embedder/thr_simi",
            thr_simi_meter.avg, self._train_batch,
        )
        self._tb_writer.add_scalar(
            "test/th2vec/premise_embedder/unr_simi",
            unr_simi_meter.avg, self._train_batch,
        )

    all_loss_meter = Meter()
    thr_loss_meter = Meter()
    unr_loss_meter = Meter()
    thr_simi_meter = Meter()
    unr_simi_meter = Meter()
def batch_train(
        self,
        epoch,
):
    assert self._train_loader is not None

    self._model.train()

    all_loss_meter = Meter()
    thr_loss_meter = Meter()
    unr_loss_meter = Meter()
    thr_simi_meter = Meter()
    unr_simi_meter = Meter()

    if self._config.get('distributed_training'):
        self._train_sampler.set_epoch(epoch)
    self._scheduler.step()

    for it, (cnj, thr, unr) in enumerate(self._train_loader):
        cnj_emd = self._model(cnj.to(self._device))
        thr_emd = self._model(thr.to(self._device))
        unr_emd = self._model(unr.to(self._device))

        thr_loss = F.mse_loss(cnj_emd, thr_emd)
        unr_loss = F.mse_loss(cnj_emd, unr_emd)
        all_loss = thr_loss - unr_loss

        thr_simi = F.cosine_similarity(cnj_emd, thr_emd).mean()
        unr_simi = F.cosine_similarity(cnj_emd, unr_emd).mean()

        self._optimizer.zero_grad()
        all_loss.backward()
        self._optimizer.step()

        all_loss_meter.update(all_loss.item())
        thr_loss_meter.update(thr_loss.item())
        unr_loss_meter.update(unr_loss.item())
        thr_simi_meter.update(thr_simi.item())
        unr_simi_meter.update(unr_simi.item())

        self._train_batch += 1

        if self._train_batch % 10 == 0:
            Log.out("TH2VEC PREMISE_EMBEDDER TRAIN", {
                'train_batch': self._train_batch,
                'loss_avg': all_loss_meter.avg,
            })

            if self._tb_writer is not None:
                self._tb_writer.add_scalar(
                    "train/th2vec/premise_embedder/all_loss",
                    all_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/premise_embedder/thr_loss",
                    thr_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/premise_embedder/unr_loss",
                    unr_loss_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/premise_embedder/thr_simi",
                    thr_simi_meter.avg, self._train_batch,
                )
                self._tb_writer.add_scalar(
                    "train/th2vec/premise_embedder/unr_simi",
                    unr_simi_meter.avg, self._train_batch,
                )

            all_loss_meter = Meter()
            thr_loss_meter = Meter()
            unr_loss_meter = Meter()
            thr_simi_meter = Meter()
            unr_simi_meter = Meter()

    Log.out("EPOCH DONE", {
        'epoch': epoch,
        'learning_rate': self._scheduler.get_lr(),
    })
def run_once(self):
    for m in self._modules:
        self._modules[m].train()

    run_start = time.time()

    self._optimizer.zero_grad()
    infos = self._syn.reduce(self._device, self._min_update_count)

    if len(infos) == 0:
        time.sleep(1)
        return

    self._optimizer.step()
    self._syn.broadcast({'config': self._config})

    act_loss_meter = Meter()
    lft_loss_meter = Meter()
    rgt_loss_meter = Meter()
    val_loss_meter = Meter()

    for info in infos:
        act_loss_meter.update(info['act_loss'])
        lft_loss_meter.update(info['lft_loss'])
        rgt_loss_meter.update(info['rgt_loss'])
        val_loss_meter.update(info['val_loss'])

    Log.out("PROOFTRACE LM SYN RUN", {
        'epoch': self._epoch,
        'run_time': "{:.2f}".format(time.time() - run_start),
        'update_count': len(infos),
        'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
        'lft_loss': "{:.4f}".format(lft_loss_meter.avg or 0.0),
        'rgt_loss': "{:.4f}".format(rgt_loss_meter.avg or 0.0),
        'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
    })

    if self._tb_writer is not None:
        if act_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/act_loss",
                act_loss_meter.avg, self._epoch,
            )
        if lft_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/lft_loss",
                lft_loss_meter.avg, self._epoch,
            )
        if rgt_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/rgt_loss",
                rgt_loss_meter.avg, self._epoch,
            )
        if val_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/val_loss",
                val_loss_meter.avg, self._epoch,
            )
        self._tb_writer.add_scalar(
            "prooftrace_lm_train/update_count",
            len(infos), self._epoch,
        )

    self._epoch += 1

    if self._epoch % 100 == 0:
        self.save()
def run_once(self):
    for m in self._modules:
        self._modules[m].train()

    run_start = time.time()

    if self._epoch == 0:
        self._syn.broadcast({'config': self._config})

    self._optimizer.zero_grad()
    infos = self._syn.reduce(self._device, self._min_update_count)

    if len(infos) == 0:
        if self._epoch == 0:
            self._epoch += 1
        time.sleep(1)
        return

    self._optimizer.step()
    self._syn.broadcast({'config': self._config})

    if self._last_update is not None:
        update_delta = time.time() - self._last_update
    else:
        update_delta = 0.0
    self._last_update = time.time()

    frame_count_meter = Meter()
    match_count_meter = Meter()
    run_length_meter = Meter()
    demo_length_avg_meter = Meter()
    demo_length_max_meter = Meter()
    demo_delta_meter = Meter()
    stp_reward_meter = Meter()
    mtc_reward_meter = Meter()
    fnl_reward_meter = Meter()
    tot_reward_meter = Meter()
    act_loss_meter = Meter()
    val_loss_meter = Meter()
    act_entropy_meter = Meter()
    ptr_entropy_meter = Meter()

    for info in infos:
        frame_count_meter.update(info['frame_count'])
        if 'match_count' in info:
            match_count_meter.update(info['match_count'])
        if 'run_length' in info:
            run_length_meter.update(info['run_length'])
        if 'demo_length_avg' in info:
            demo_length_avg_meter.update(info['demo_length_avg'])
        if 'demo_length_max' in info:
            demo_length_max_meter.update(info['demo_length_max'])
        if 'demo_delta' in info:
            demo_delta_meter.update(info['demo_delta'])

        tot_reward = 0.0
        has_reward = False
        if 'stp_reward' in info:
            stp_reward_meter.update(info['stp_reward'])
            tot_reward += info['stp_reward']
            has_reward = True
        if 'mtc_reward' in info:
            mtc_reward_meter.update(info['mtc_reward'])
            tot_reward += info['mtc_reward']
            has_reward = True
        if 'fnl_reward' in info:
            fnl_reward_meter.update(info['fnl_reward'])
            tot_reward += info['fnl_reward']
            has_reward = True
        if has_reward:
            tot_reward_meter.update(tot_reward)

        act_loss_meter.update(info['act_loss'])
        val_loss_meter.update(info['val_loss'])
        act_entropy_meter.update(info['act_entropy'])
        ptr_entropy_meter.update(info['ptr_entropy'])

    Log.out("PROOFTRACE PPO SYN RUN", {
        'epoch': self._epoch,
        'run_time': "{:.2f}".format(time.time() - run_start),
        'update_count': len(infos),
        'frame_count': frame_count_meter.sum,
        'update_delta': "{:.2f}".format(update_delta),
        'match_count': "{:.2f}".format(match_count_meter.avg or 0.0),
        'run_length': "{:.2f}".format(run_length_meter.avg or 0.0),
        'demo_length': "{:.2f}/{:.0f}".format(
            demo_length_avg_meter.avg or 0.0,
            demo_length_max_meter.max or 0.0,
        ),
        'demo_delta': "{:.4f}".format(demo_delta_meter.avg or 0.0),
        'stp_reward': "{:.4f}".format(stp_reward_meter.avg or 0.0),
        'mtc_reward': "{:.4f}".format(mtc_reward_meter.avg or 0.0),
        'fnl_reward': "{:.4f}".format(fnl_reward_meter.avg or 0.0),
        'tot_reward': "{:.4f}".format(tot_reward_meter.avg or 0.0),
        'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
        'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
        'act_entropy': "{:.4f}".format(act_entropy_meter.avg or 0.0),
        'ptr_entropy': "{:.4f}".format(ptr_entropy_meter.avg or 0.0),
    })

    if self._tb_writer is not None:
        if len(infos) > 0:
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/update_delta",
                update_delta, self._epoch,
            )
            if match_count_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/match_count",
                    match_count_meter.avg, self._epoch,
                )
            if run_length_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/run_length",
                    run_length_meter.avg, self._epoch,
                )
            if demo_length_avg_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/demo_length_avg",
                    demo_length_avg_meter.avg, self._epoch,
                )
            if demo_length_max_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/demo_length_max",
                    demo_length_max_meter.max, self._epoch,
                )
            if demo_delta_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/demo_delta",
                    demo_delta_meter.avg, self._epoch,
                )
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/act_loss",
                act_loss_meter.avg, self._epoch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/val_loss",
                val_loss_meter.avg, self._epoch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/act_entropy",
                act_entropy_meter.avg, self._epoch,
            )
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/ptr_entropy",
                ptr_entropy_meter.avg, self._epoch,
            )
            if stp_reward_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/stp_reward",
                    stp_reward_meter.avg, self._epoch,
                )
            if mtc_reward_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/mtc_reward",
                    mtc_reward_meter.avg, self._epoch,
                )
            if fnl_reward_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/fnl_reward",
                    fnl_reward_meter.avg, self._epoch,
                )
            if tot_reward_meter.avg:
                self._tb_writer.add_scalar(
                    "prooftrace_ppo_train/tot_reward",
                    tot_reward_meter.avg, self._epoch,
                )
            self._tb_writer.add_scalar(
                "prooftrace_ppo_train/frame_count",
                frame_count_meter.sum, self._epoch,
            )

    self._epoch += 1

    if self._epoch % 10 == 0:
        self.save()
def run_once(self):
    for m in self._model.modules():
        self._model.modules()[m].train()

    run_start = time.time()

    self._policy_optimizer.zero_grad()
    infos = self._syn.reduce(self._device, self._min_update_count)

    if len(infos) == 0:
        time.sleep(1)
        return

    self._policy_optimizer.step()
    self._syn.broadcast({'config': self._config})

    if self._last_update is not None:
        update_delta = time.time() - self._last_update
    else:
        update_delta = 0.0
    self._last_update = time.time()

    act_loss_meter = Meter()
    lft_loss_meter = Meter()
    rgt_loss_meter = Meter()
    test_act_loss_meter = Meter()
    test_lft_loss_meter = Meter()
    test_rgt_loss_meter = Meter()

    for info in infos:
        if 'act_loss' in info:
            act_loss_meter.update(info['act_loss'])
        if 'lft_loss' in info:
            lft_loss_meter.update(info['lft_loss'])
        if 'rgt_loss' in info:
            rgt_loss_meter.update(info['rgt_loss'])
        if 'test_act_loss' in info:
            test_act_loss_meter.update(info['test_act_loss'])
        if 'test_lft_loss' in info:
            test_lft_loss_meter.update(info['test_lft_loss'])
        if 'test_rgt_loss' in info:
            test_rgt_loss_meter.update(info['test_rgt_loss'])

    Log.out("PROOFTRACE SYN RUN", {
        'epoch': self._epoch,
        'run_time': "{:.2f}".format(time.time() - run_start),
        'update_count': len(infos),
        'update_delta': "{:.2f}".format(update_delta),
        'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
        'lft_loss': "{:.4f}".format(lft_loss_meter.avg or 0.0),
        'rgt_loss': "{:.4f}".format(rgt_loss_meter.avg or 0.0),
        'test_act_loss': "{:.4f}".format(test_act_loss_meter.avg or 0.0),
        'test_lft_loss': "{:.4f}".format(test_lft_loss_meter.avg or 0.0),
        'test_rgt_loss': "{:.4f}".format(test_rgt_loss_meter.avg or 0.0),
    })

    if self._tb_writer is not None:
        self._tb_writer.add_scalar(
            "prooftrace_lm_train/update_delta",
            update_delta, self._epoch,
        )
        self._tb_writer.add_scalar(
            "prooftrace_lm_train/update_count",
            len(infos), self._epoch,
        )
        if act_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/act_loss",
                act_loss_meter.avg, self._epoch,
            )
        if lft_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/lft_loss",
                lft_loss_meter.avg, self._epoch,
            )
        if rgt_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_train/rgt_loss",
                rgt_loss_meter.avg, self._epoch,
            )
        if test_act_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/act_loss",
                test_act_loss_meter.avg, self._epoch,
            )
        if test_lft_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/lft_loss",
                test_lft_loss_meter.avg, self._epoch,
            )
        if test_rgt_loss_meter.avg is not None:
            self._tb_writer.add_scalar(
                "prooftrace_lm_test/rgt_loss",
                test_rgt_loss_meter.avg, self._epoch,
            )

    self._epoch += 1

    if self._epoch % 100 == 0:
        self.save()
def run_once(
        self,
        epoch,
):
    for m in self._modules:
        self._modules[m].train()

    info = self._ack.fetch(self._device, epoch == 0)
    if info is not None:
        self.update(info['config'])

    stp_reward_meter = Meter()
    mtc_reward_meter = Meter()
    fnl_reward_meter = Meter()
    act_loss_meter = Meter()
    val_loss_meter = Meter()
    act_entropy_meter = Meter()
    ptr_entropy_meter = Meter()
    match_count_meter = Meter()
    run_length_meter = Meter()
    demo_length_meter = Meter()
    demo_delta_meter = Meter()

    frame_count = 0

    for step in range(self._rollout_size):
        with torch.no_grad():
            (idx, act, arg) = self._rollouts.observations[step]

            action_embeds = self._modules['E'](act).detach()
            argument_embeds = self._modules['E'](arg).detach()

            hiddens = self._modules['T'](action_embeds, argument_embeds)

            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat([
                action_embeds[i][0].unsqueeze(0) for i in range(len(idx))
            ], dim=0)

            prd_actions, prd_lefts, prd_rights = \
                self._modules['PH'](heads, hiddens, targets)
            values = \
                self._modules['VH'](heads, targets)

        actions, count = self._pool.explore(
            prd_actions,
            prd_lefts,
            prd_rights,
            self._explore_alpha,
            self._explore_beta,
            self._explore_beta_width,
        )
        frame_count += count

        observations, rewards, dones, infos = self._pool.step(
            [tuple(a) for a in actions.detach().cpu().numpy()],
            self._step_reward_prob,
            self._match_reward_prob,
            self._reset_gamma,
            self._fixed_gamma,
        )
        frame_count += actions.size(0)

        for i, info in enumerate(infos):
            if 'match_count' in info:
                assert dones[i]
                match_count_meter.update(info['match_count'])
            if 'run_length' in info:
                assert dones[i]
                run_length_meter.update(info['run_length'])
            if 'demo_length' in info:
                assert dones[i]
                demo_length_meter.update(info['demo_length'])
            if 'demo_delta' in info:
                assert dones[i]
                demo_delta_meter.update(info['demo_delta'])

        log_probs = torch.cat((
            prd_actions.gather(1, actions[:, 0].unsqueeze(1)),
            prd_lefts.gather(1, actions[:, 1].unsqueeze(1)),
            prd_rights.gather(1, actions[:, 2].unsqueeze(1)),
        ), dim=1)

        for i, r in enumerate(rewards):
            self._episode_stp_reward[i] += r[0]
            self._episode_mtc_reward[i] += r[1]
            self._episode_fnl_reward[i] += r[2]
            if dones[i]:
                stp_reward_meter.update(self._episode_stp_reward[i])
                mtc_reward_meter.update(self._episode_mtc_reward[i])
                fnl_reward_meter.update(self._episode_fnl_reward[i])
                self._episode_stp_reward[i] = 0.0
                self._episode_mtc_reward[i] = 0.0
                self._episode_fnl_reward[i] = 0.0

        self._rollouts.insert(
            step,
            observations,
            actions.detach(),
            log_probs.detach(),
            values.detach(),
            torch.tensor(
                [(r[0] + r[1] + r[2]) for r in rewards],
                dtype=torch.int64,
            ).unsqueeze(1).to(self._device),
            torch.tensor(
                [[0.0] if d else [1.0] for d in dones],
            ).to(self._device),
        )

    with torch.no_grad():
        (idx, act, arg) = self._rollouts.observations[-1]

        action_embeds = self._modules['E'](act)
        argument_embeds = self._modules['E'](arg)

        hiddens = self._modules['T'](action_embeds, argument_embeds)

        heads = torch.cat(
            [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
            dim=0)
        targets = torch.cat(
            [action_embeds[i][0].unsqueeze(0) for i in range(len(idx))],
            dim=0)

        values = \
            self._modules['VH'](heads, targets)

    self._rollouts.compute_returns(values.detach())

    advantages = \
        self._rollouts.returns[:-1] - self._rollouts.values[:-1]
    advantages = \
        (advantages - advantages.mean()) / (advantages.std() + 1e-5)

    ignored = False

    for e in range(self._epoch_count):
        if ignored:
            continue

        generator = self._rollouts.generator(advantages)

        for batch in generator:
            if ignored:
                continue

            rollout_observations, \
                rollout_actions, \
                rollout_values, \
                rollout_returns, \
                rollout_masks, \
                rollout_log_probs, \
                rollout_advantages = batch

            (idx, act, arg) = rollout_observations

            action_embeds = self._modules['E'](act)
            argument_embeds = self._modules['E'](arg)

            hiddens = self._modules['T'](action_embeds, argument_embeds)

            heads = torch.cat(
                [hiddens[i][idx[i]].unsqueeze(0) for i in range(len(idx))],
                dim=0)
            targets = torch.cat([
                action_embeds[i][0].unsqueeze(0) for i in range(len(idx))
            ], dim=0)

            prd_actions, prd_lefts, prd_rights = \
                self._modules['PH'](heads, hiddens, targets)
            values = \
                self._modules['VH'](heads, targets)

            log_probs = torch.cat((
                prd_actions.gather(1, rollout_actions[:, 0].unsqueeze(1)),
                prd_lefts.gather(1, rollout_actions[:, 1].unsqueeze(1)),
                prd_rights.gather(1, rollout_actions[:, 2].unsqueeze(1)),
            ), dim=1)

            act_entropy = -((prd_actions * torch.exp(prd_actions)).mean())
            ptr_entropy = -((prd_lefts * torch.exp(prd_lefts)).mean() +
                            (prd_rights * torch.exp(prd_rights)).mean())

            # Clipped action loss.
            ratio = torch.exp(log_probs - rollout_log_probs)
            action_loss = -torch.min(
                ratio * rollout_advantages,
                torch.clamp(ratio, 1.0 - self._clip, 1.0 + self._clip) *
                rollout_advantages,
            ).mean()

            # Clipped value loss.
            # clipped_values = rollout_values + \
            #     (values - rollout_values).clamp(-self._clip, self._clip)
            # value_loss = torch.max(
            #     F.mse_loss(values, rollout_returns),
            #     F.mse_loss(clipped_values, rollout_returns),
            # )
            value_loss = F.mse_loss(values, rollout_returns)

            # Log.out("RATIO/ADV/LOSS", {
            #     'clipped_ratio': torch.clamp(
            #         ratio, 1.0 - self._clip, 1.0 + self._clip
            #     ).mean().item(),
            #     'ratio': ratio.mean().item(),
            #     'advantages': rollout_advantages.mean().item(),
            #     'action_loss': action_loss.item(),
            # })

            if (abs(action_loss.item()) > 10e2 or
                    abs(value_loss.item()) > 10e5 or
                    math.isnan(value_loss.item()) or
                    math.isnan(act_entropy.item()) or
                    math.isnan(ptr_entropy.item())):
                Log.out("IGNORING", {
                    'epoch': epoch,
                    'act_loss': "{:.4f}".format(action_loss.item()),
                    'val_loss': "{:.4f}".format(value_loss.item()),
                    'act_entropy': "{:.4f}".format(act_entropy.item()),
                    'ptr_entropy': "{:.4f}".format(ptr_entropy.item()),
                })
                ignored = True
            else:
                # Backward pass.
                for m in self._modules:
                    self._modules[m].zero_grad()

                (action_loss +
                 self._value_coeff * value_loss -
                 (self._act_entropy_coeff * act_entropy +
                  self._ptr_entropy_coeff * ptr_entropy)).backward()

                if self._grad_norm_max > 0.0:
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['E'].parameters(),
                        self._grad_norm_max,
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['T'].parameters(),
                        self._grad_norm_max,
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['VH'].parameters(),
                        self._grad_norm_max,
                    )
                    torch.nn.utils.clip_grad_norm_(
                        self._modules['PH'].parameters(),
                        self._grad_norm_max,
                    )

                act_loss_meter.update(action_loss.item())
                val_loss_meter.update(value_loss.item())
                act_entropy_meter.update(act_entropy.item())
                ptr_entropy_meter.update(ptr_entropy.item())

                info = {
                    'frame_count': frame_count,
                    'act_loss': act_loss_meter.avg,
                    'val_loss': val_loss_meter.avg,
                    'act_entropy': act_entropy_meter.avg,
                    'ptr_entropy': ptr_entropy_meter.avg,
                }
                if match_count_meter.avg:
                    info['match_count'] = match_count_meter.avg
                if run_length_meter.avg:
                    info['run_length'] = run_length_meter.avg
                if demo_length_meter.avg:
                    info['demo_length_avg'] = demo_length_meter.avg
                if demo_length_meter.max:
                    info['demo_length_max'] = demo_length_meter.max
                if demo_delta_meter.avg:
                    info['demo_delta'] = demo_delta_meter.avg
                if stp_reward_meter.avg:
                    info['stp_reward'] = stp_reward_meter.avg
                if mtc_reward_meter.avg:
                    info['mtc_reward'] = mtc_reward_meter.avg
                if fnl_reward_meter.avg:
                    info['fnl_reward'] = fnl_reward_meter.avg

                self._ack.push(info, None)
                if frame_count > 0:
                    frame_count = 0

                info = self._ack.fetch(self._device)
                if info is not None:
                    self.update(info['config'])

    self._rollouts.after_update()

    Log.out("PROOFTRACE PPO ACK RUN", {
        'epoch': epoch,
        'ignored': ignored,
        'match_count': "{:.2f}".format(match_count_meter.avg or 0.0),
        'run_length': "{:.2f}".format(run_length_meter.avg or 0.0),
        'demo_length': "{:.2f}/{:.0f}".format(
            demo_length_meter.avg or 0.0,
            demo_length_meter.max or 0.0,
        ),
        'demo_delta': "{:.4f}".format(demo_delta_meter.avg or 0.0),
        'stp_reward': "{:.4f}".format(stp_reward_meter.avg or 0.0),
        'mtc_reward': "{:.4f}".format(mtc_reward_meter.avg or 0.0),
        'fnl_reward': "{:.4f}".format(fnl_reward_meter.avg or 0.0),
        'act_loss': "{:.4f}".format(act_loss_meter.avg or 0.0),
        'val_loss': "{:.4f}".format(val_loss_meter.avg or 0.0),
        'act_entropy': "{:.4f}".format(act_entropy_meter.avg or 0.0),
        'ptr_entropy': "{:.4f}".format(ptr_entropy_meter.avg or 0.0),
    })