def _synthesize(self, input: Environment, n_required_output: Optional[int] = None) \
        -> Generator[Result[Output], None, None]:
    """Fine-tune ``self.model`` online with a REINFORCE-style loop while
    yielding every sampled rollout to the caller.

    Repeats: sample ``self.n_rollout`` candidate outputs for *input*, score
    each with ``self.reward``, subtract a momentum-smoothed baseline, and take
    one optimizer step on the reward-weighted batch.  The generator yields
    each raw rollout (finished or not) as it is produced.  Model and
    optimizer state are snapshotted on entry and restored on exit, so the
    synthesis run leaves ``self.model``/``self.optimizer`` unchanged.

    Args:
        input: the task environment to synthesize outputs for; it is cloned
            for each rollout, never mutated.
        n_required_output: unsupported here — must be ``None`` (asserted).

    Yields:
        Result[Output]: every rollout produced by ``self.synthesizer``.
    """
    with logger.block("_synthesize"):
        assert n_required_output is None
        baseline = 0.0  # exponential moving average of per-iteration reward
        # TODO handle multi-process evaluation
        # Backup state_dict
        # NOTE(review): tensors are cloned so later optimizer steps cannot
        # mutate the backup in place.
        orig_model_state = {
            key: value.clone()
            for key, value in self.model.state_dict().items()
        }
        try:
            idx = 0     # training-iteration counter (for periodic logging)
            n_try = 0   # consecutive iterations that produced no rollout
            to_rollout = input.clone_without_supervision()
            to_rollout.to(self.device)
            while True:
                rollouts = []   # supervised training examples for this step
                reward = 0.0    # summed reward over this step's rollouts
                with logger.block("rollout"):
                    # Sampling only — no gradients needed.
                    with torch.no_grad():
                        self.model.eval()
                        for rollout in logger.iterable_block(
                                "sample",
                                self.synthesizer(
                                    to_rollout,
                                    n_required_output=self.n_rollout)):
                            # Surface every rollout to the caller, even
                            # unfinished ones.
                            yield rollout
                            if not rollout.is_finished:
                                continue
                            # A rollout may represent multiple identical
                            # samples (rollout.num); weight it by duplication.
                            for _ in range(rollout.num):
                                output = input.clone()
                                output["ground_truth"] = rollout.output
                                output.mark_as_supervision("ground_truth")
                                r = self.reward(input.clone(), rollout.output)
                                reward += r
                                # Advantage = reward minus current baseline.
                                output["reward"] = torch.tensor(r - baseline)
                                rollouts.append(output)
                if len(rollouts) == 0:
                    # Nothing finished this round; give up after
                    # max_try_num consecutive empty rounds.
                    logger.warning("No rollout")
                    n_try += 1
                    if n_try >= self.max_try_num:
                        return
                    continue
                if len(rollouts) != self.n_rollout:
                    logger.warning(
                        "#rollout is unexpected: "
                        f"expected={self.n_rollout} actual={len(rollouts)}"
                    )
                with logger.block("calculate_baseline"):
                    # Mean reward of this iteration, folded into the EMA
                    # baseline with momentum m.
                    reward = reward / len(rollouts)
                    m = self.baseline_momentum
                    baseline = (1 - m) * reward + m * baseline
                with logger.block("train"):
                    self.model.train()
                    with logger.block("collate"):
                        batch = self.collate(rollouts)
                    with logger.block("to"):
                        batch.to(self.device)
                    with logger.block("forward"):
                        # NOTE(review): model.train() was already called just
                        # above — this second call is redundant but harmless.
                        self.model.train()
                        output = self.model(batch)
                        loss = self.loss_fn(output)
                    with logger.block("backward"):
                        self.model.zero_grad()
                        loss.backward()
                    with logger.block("optimizer.step"):
                        self.optimizer.step()
                if idx % 10 == 0:
                    logger.info(
                        f"idx={idx} reward(avg)={baseline} reward={reward} "
                        f"loss={loss.item()}")
                idx += 1
        finally:
            # Restore state_dict
            # Roll back model and optimizer so synthesis has no lasting
            # effect on training state.
            self.model.load_state_dict(orig_model_state)
            self.optimizer.load_state_dict(self.optimizer_state_dict)
def test_clone_without_supervision(self) -> None:
    """An entry marked as supervision must not survive clone_without_supervision."""
    env = Environment()
    env["key"] = 0
    env.mark_as_supervision("key")
    cloned = env.clone_without_supervision()
    assert cloned.to_dict() == {}