def forward(self, entry: Environment) -> Environment:
    """Return a clone of *entry* that is guaranteed to carry ``test_cases``.

    When the clone already has a ``test_cases`` field it is returned
    untouched; otherwise a single test case ``(code, None)`` is synthesized
    from the entry's ``code`` field (no expected output).
    """
    cloned = cast(Environment, entry.clone())
    if "test_cases" not in cloned:
        # Fall back to a single input-only test case built from the code.
        cloned["test_cases"] = [(cloned["code"], None)]
    return cloned
def forward(self, entry: Environment) -> Environment:
    """Copy the first interpreter-state context item into ``entry["code"]``.

    Works on a clone so the caller's entry is left unmodified.
    """
    cloned = cast(Environment, entry.clone())
    # The interpreter state's context holds the inputs; the first one is
    # treated as the code of this entry.
    cloned["code"] = cloned["interpreter_state"].context[0]
    return cloned
def __call__(self, entry: Environment) -> Environment:
    """(Re)initialize ``entry[self.key]`` when training or when it is unset.

    The "training" flag is resolved in priority order: an explicit
    ``action_sequence`` supervision mark wins over an explicit ``train``
    field, which wins over the default of ``True``.
    """
    entry = cast(Environment, entry.clone())
    train = entry["train"] if "train" in entry else True
    if "action_sequence" in entry:
        # Supervision status of the action sequence overrides the flag.
        train = entry.is_supervision("action_sequence")
    if train or self.key not in entry:
        entry[self.key] = self.initial
    return entry
def forward(self, entry: Environment) -> List[Environment]:
    """Expand ``ground_truth`` into one episode entry per code fragment.

    Each produced entry snapshots the interpreter state *before* its
    fragment is executed, and carries empty ``reference``/``variables``
    lists. Fragments are executed in order so later entries see the
    effects of earlier ones.
    """
    test_inputs = [case_input for case_input, _ in entry["test_cases"]]
    state = self.interpreter.create_state(test_inputs)
    episodes: List[Environment] = []
    for fragment in self.expander.expand(entry["ground_truth"]):
        episode = cast(Environment, entry.clone())
        episode["ground_truth"] = fragment
        episode["reference"] = []
        episode["variables"] = []
        # State *before* executing this fragment.
        episode["interpreter_state"] = state
        state = self.interpreter.execute(fragment, state)
        episodes.append(episode)
    return episodes
def forward(self, entry: Environment) -> List[Environment]:
    """Expand ``ground_truth`` into supervised per-step episode entries.

    Unlike the variant without references, each entry exposes the
    variables defined so far: ``reference`` holds one ``Token`` per name
    in the interpreter environment (typed from ``type_environment``) and
    ``variables`` holds the corresponding values. The per-step
    ``ground_truth`` is marked as supervision.
    """
    test_cases = entry["test_cases"]
    case_inputs = [case_input for case_input, _ in test_cases]
    state = self.interpreter.create_state(case_inputs)
    episodes: List[Environment] = []
    for fragment in self.expander.expand(entry["ground_truth"]):
        episode = cast(Environment, entry.clone())
        # One token per variable currently defined in the interpreter.
        episode["reference"] = [
            Token(state.type_environment[name], name, name)
            for name in state.environment.keys()
        ]
        # Values aligned with the reference tokens above.
        episode["variables"] = [
            state.environment[tok.value] for tok in episode["reference"]
        ]
        episode["ground_truth"] = fragment
        episode.mark_as_supervision("ground_truth")
        state = self.interpreter.execute(fragment, state)
        episodes.append(episode)
    return episodes
def _synthesize(self, input: Environment,
                n_required_output: Optional[int] = None) \
        -> Generator[Result[Output], None, None]:
    """Yield synthesis results while fine-tuning the model via rollouts.

    REINFORCE-style loop: sample rollouts from ``self.synthesizer``
    (yielding every sampled result to the caller), turn finished rollouts
    into training examples weighted by ``reward - baseline``, and take one
    optimizer step per iteration. The loop runs until the consumer stops
    iterating (or ``max_try_num`` consecutive empty rollout batches); the
    model and optimizer states are restored on exit.
    """
    with logger.block("_synthesize"):
        assert n_required_output is None
        # Moving-average reward used as the REINFORCE baseline.
        baseline = 0.0
        # TODO handle multi-process evaluation
        # Backup state_dict so the fine-tuning here is transient.
        orig_model_state = {
            key: value.clone()
            for key, value in self.model.state_dict().items()
        }
        try:
            idx = 0
            n_try = 0  # consecutive iterations that produced no rollout
            to_rollout = input.clone_without_supervision()
            to_rollout.to(self.device)
            while True:
                rollouts = []
                reward = 0.0  # sum of rewards over this batch of rollouts
                with logger.block("rollout"):
                    with torch.no_grad():
                        self.model.eval()
                        for rollout in logger.iterable_block(
                                "sample",
                                self.synthesizer(
                                    to_rollout,
                                    n_required_output=self.n_rollout)):
                            # Every sample is surfaced to the caller, even
                            # unfinished ones; only finished rollouts are
                            # used as training data below.
                            yield rollout
                            if not rollout.is_finished:
                                continue
                            # Duplicate the example ``rollout.num`` times so
                            # its sampling multiplicity is reflected in the
                            # training batch.
                            for _ in range(rollout.num):
                                output = input.clone()
                                output["ground_truth"] = rollout.output
                                output.mark_as_supervision("ground_truth")
                                r = self.reward(input.clone(), rollout.output)
                                reward += r
                                # Advantage = reward - running baseline.
                                output["reward"] = torch.tensor(r - baseline)
                                rollouts.append(output)
                if len(rollouts) == 0:
                    logger.warning("No rollout")
                    n_try += 1
                    if n_try >= self.max_try_num:
                        # Give up after too many empty batches.
                        return
                    continue
                if len(rollouts) != self.n_rollout:
                    logger.warning(
                        "#rollout is unexpected: "
                        f"expected={self.n_rollout} actual={len(rollouts)}")
                with logger.block("calculate_baseline"):
                    # Exponential moving average of the batch-mean reward.
                    reward = reward / len(rollouts)
                    m = self.baseline_momentum
                    baseline = (1 - m) * reward + m * baseline
                with logger.block("train"):
                    self.model.train()
                    with logger.block("collate"):
                        batch = self.collate(rollouts)
                    with logger.block("to"):
                        batch.to(self.device)
                    with logger.block("forward"):
                        self.model.train()
                        output = self.model(batch)
                        loss = self.loss_fn(output)
                    with logger.block("backward"):
                        self.model.zero_grad()
                        loss.backward()
                    with logger.block("optimizer.step"):
                        self.optimizer.step()
                if idx % 10 == 0:
                    logger.info(
                        f"idx={idx} reward(avg)={baseline} reward={reward} "
                        f"loss={loss.item()}")
                idx += 1
        finally:
            # Restore state_dict so callers observe no lasting side effect.
            # NOTE(review): optimizer is restored from a snapshot attribute
            # (``self.optimizer_state_dict``) presumably captured at
            # construction time — confirm it exists before first call.
            self.model.load_state_dict(orig_model_state)
            self.optimizer.load_state_dict(self.optimizer_state_dict)
def enumerate_samples_per_state(self,
                                rule_pred: torch.Tensor,
                                token_pred: torch.Tensor,
                                reference_pred: torch.Tensor,
                                next_state: Environment,
                                state: SamplerState[Environment],
                                enumeration: Enumeration,
                                k: Optional[int]) \
        -> Generator[DuplicatedSamplerState[Environment], None, None]:
    """Enumerate up to *k* successor states for one sampler state.

    Depending on the head field's constraint, either tokens (predefined
    vocabulary + references) or rules are enumerated. Candidates with
    invalid types are masked to probability 0 *in place* on the prediction
    tensors; surviving candidates are enumerated per ``enumeration`` mode
    (best-first, in-order, or multinomial sampling) and yielded with their
    log-probability added to the state's score.

    NOTE(review): ``token_pred``/``reference_pred``/``rule_pred`` are
    mutated in place — callers must not reuse them afterwards.
    """
    def indices(pred: torch.Tensor):
        """Yield ``(index, count)`` pairs per the enumeration mode."""
        # 0 is unknown token
        if enumeration == Enumeration.Top:
            # Best-first over all candidates (excluding index 0).
            _, indices = torch.sort(pred[1:], descending=True)
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index + 1, 1
        elif enumeration == Enumeration.Random:
            # In-order enumeration; despite the name, no shuffling here.
            indices = list(range(1, len(pred)))
            if k is not None:
                indices = indices[:k]
            for index in indices:
                yield index, 1
        else:
            # Multinomial sampling of k draws over normalized probs.
            assert k is not None
            with logger.block("normalize_prob"):
                s = pred[1:].sum().item()
                if s < self.eps:
                    # All mass (effectively) zero: nothing to sample.
                    return
                ps = (pred[1:] / s - self.eps).numpy()
                npred = [max(0, p) for p in ps]
            for i, n in enumerate(self.rng.multinomial(k, npred)):
                if n == 0:
                    continue
                yield i + 1, n

    with logger.block("enumerate_samples_per_state"):
        head = state.state["action_sequence"].head
        assert head is not None
        # Field of the expand-tree rule that the next action must fill.
        head_field = \
            cast(ExpandTreeRule, cast(
                ApplyRule,
                state.state["action_sequence"]
                .action_sequence[head.action]
            ).rule).children[head.field][1]
        if head_field.constraint == NodeConstraint.Token:
            # Generate token
            ref_ids = self.encoder.batch_encode_raw_value(
                [x.raw_value for x in state.state["reference"]])
            # Candidate list: vocabulary tokens followed by references,
            # indexed consistently with the concatenated ``pred`` below.
            tokens = list(self.encoder._token_encoder.vocab) + \
                state.state["reference"]
            # the score will be merged into predefined token
            for i, ids in enumerate(ref_ids):
                for ref_id in ids:
                    # merge token and reference pred
                    # Add to unknown probability
                    # if there is not the corresponding token.
                    token_pred[ref_id] += reference_pred[i]
                    if ref_id != 0:
                        # Known token absorbed the mass; zero the reference.
                        reference_pred[i] = 0.0
            pred = torch.cat([token_pred, reference_pred], dim=0)
            # CloseVariadicFieldRule is a candidate if variadic fields
            if head_field.is_variadic:
                close_rule_idx = \
                    self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                p = rule_pred[close_rule_idx].item()
                tokens.append(ApplyRule(CloseVariadicFieldRule()))
                pred = torch.cat([pred, torch.tensor([p])], dim=0)
            with logger.block("exclude_invalid_tokens"):
                # token: zero out vocabulary entries of incompatible kind
                for kind, idxes in self.token_kind_to_idx.items():
                    if kind is not None and \
                            not self.is_subtype(kind, head_field.type_name):
                        pred[idxes] = 0.0
                # reference: same check per reference candidate
                for x, (p, token) in enumerate(
                        zip(pred[len(token_pred):],
                            tokens[len(token_pred):])):
                    # Shift to an index into the concatenated ``pred``.
                    x += len(token_pred)
                    if not isinstance(token, ApplyRule):
                        if isinstance(token, Token):
                            t = token.kind
                        else:
                            t = token[0]
                        if t is not None and \
                                not self.is_subtype(t, head_field.type_name):
                            pred[x] = 0.0
            n_action = 0
            for x, n in logger.iterable_block("sample-tokens",
                                              indices(pred)):
                # Finish enumeration
                if n_action == k:
                    return
                p = pred[x].item()
                token = tokens[x]
                if isinstance(token, ApplyRule):
                    action: Action = token
                elif isinstance(token, Token):
                    action = GenerateToken(token.kind, token.raw_value)
                else:
                    # Raw (kind, value) pair reference.
                    action = GenerateToken(token[0], token[1])
                if p == 0.0:
                    # Masked-out candidate; skip without counting it.
                    continue
                elif p < self.eps:
                    # Clamp tiny probabilities to avoid log underflow.
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)
                n_action += n
                next_state = next_state.clone()
                # TODO we may have to clear outputs
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"], action)
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state), n)
        else:
            # Apply rule
            with logger.block("exclude_invalid_rules"):
                # expand tree rule: keep only rules producing a subtype
                for kind, idxes in self.rule_kind_to_idx.items():
                    if not (kind is not None and
                            self.is_subtype(kind, head_field.type_name)):
                        rule_pred[idxes] = 0.0
                # CloseVariadicField: only legal for variadic fields
                idx = self.encoder._rule_encoder.encode(
                    CloseVariadicFieldRule())
                if not (head_field is not None and
                        head_field.is_variadic):
                    rule_pred[idx] = 0.0
            n_rule = 0
            for x, n in logger.iterable_block("sample-rule",
                                              indices(rule_pred)):
                # Finish enumeration
                if n_rule == k:
                    return
                p = rule_pred[x].item()
                if p == 0.0:
                    # Masked-out rule; skip without counting it.
                    continue
                elif p < self.eps:
                    # Clamp tiny probabilities to avoid log underflow.
                    lp = np.log(self.eps)
                else:
                    lp = np.log(p)
                n_rule += n
                rule = self.encoder._rule_encoder.vocab[x]
                next_state = next_state.clone()
                next_state["action_sequence"] = \
                    LazyActionSequence(
                        state.state["action_sequence"],
                        ApplyRule(rule))
                yield DuplicatedSamplerState(
                    SamplerState(state.score + lp, next_state), n)
def test_clone(self) -> None:
    """Cloning must deep-detach: mutating the original leaves the clone empty."""
    original = Environment()
    snapshot = original.clone()
    original["key"] = 0
    # The clone was taken before the mutation, so it stays empty.
    assert snapshot.to_dict() == {}