Example #1
0
    def forward(self, entry: Environment) -> Environment:
        entry = cast(Environment, entry.clone())
        if "test_cases" in entry:
            return entry
        # No test cases were given: use the code itself as the single
        # test-case input.
        query = entry["code"]
        entry["test_cases"] = [(query, None)]
        return entry
Example #2
0
    def forward(self, entry: Environment) -> Environment:
        entry = cast(Environment, entry.clone())
        state = entry["interpreter_state"]
        inputs = state.context
        code = inputs[0]
        entry["code"] = code

        return entry
Example #3
0
    def __call__(self, entry: Environment) -> Environment:
        entry = cast(Environment, entry.clone())
        # Determine whether this is a training pass; a supervised
        # action_sequence takes precedence over the explicit "train" flag.
        train = True
        if "train" in entry:
            train = entry["train"]
        if "action_sequence" in entry:
            train = entry.is_supervision("action_sequence")
        # Reset the value to its initial state during training or when the
        # key has not been set yet.
        if train or self.key not in entry:
            entry[self.key] = self.initial

        return entry
Example #4
0
    def forward(self, entry: Environment) -> List[Environment]:
        ground_truth = entry["ground_truth"]
        inputs = [input for input, _ in entry["test_cases"]]

        retval: List[Environment] = []
        state = self.interpreter.create_state(inputs)
        # Create one environment per expanded code fragment, each carrying
        # the interpreter state from before that fragment runs.
        for code in self.expander.expand(ground_truth):
            xs = cast(Environment, entry.clone())
            xs["ground_truth"] = code
            xs["reference"] = []
            xs["variables"] = []
            xs["interpreter_state"] = state
            state = self.interpreter.execute(code, state)
            retval.append(xs)
        return retval
Example #5
0
    def forward(self, entry: Environment) -> List[Environment]:
        ground_truth = entry["ground_truth"]
        test_cases = entry["test_cases"]
        inputs = [input for input, _ in test_cases]

        retval: List[Environment] = []
        state = self.interpreter.create_state(inputs)
        for code in self.expander.expand(ground_truth):
            xs = cast(Environment, entry.clone())
            xs["reference"] = [
                Token(state.type_environment[v], v, v)
                for v in state.environment.keys()
            ]
            xs["variables"] = [
                state.environment[token.value]
                for token in xs["reference"]
            ]
            xs["ground_truth"] = code
            xs.mark_as_supervision("ground_truth")
            state = self.interpreter.execute(code, state)
            retval.append(xs)

        return retval
Example #6
0
    def _synthesize(self, input: Environment, n_required_output: Optional[int] = None) \
            -> Generator[Result[Output], None, None]:
        with logger.block("_synthesize"):
            assert n_required_output is None

            baseline = 0.0

            # TODO handle multi-process evaluation
            # Backup state_dict
            orig_model_state = {
                key: value.clone()
                for key, value in self.model.state_dict().items()
            }

            try:
                idx = 0
                n_try = 0

                to_rollout = input.clone_without_supervision()
                to_rollout.to(self.device)

                while True:
                    rollouts = []
                    reward = 0.0
                    with logger.block("rollout"):
                        with torch.no_grad():
                            self.model.eval()
                            for rollout in logger.iterable_block(
                                    "sample",
                                    self.synthesizer(
                                        to_rollout,
                                        n_required_output=self.n_rollout)):
                                yield rollout
                                if not rollout.is_finished:
                                    continue
                                for _ in range(rollout.num):
                                    output = input.clone()
                                    output["ground_truth"] = rollout.output
                                    output.mark_as_supervision("ground_truth")
                                    r = self.reward(input.clone(),
                                                    rollout.output)
                                    reward += r
                                    # Store the reward relative to the
                                    # moving-average baseline.
                                    output["reward"] = torch.tensor(r -
                                                                    baseline)
                                    rollouts.append(output)

                    if len(rollouts) == 0:
                        logger.warning("No rollout")
                        n_try += 1
                        if n_try >= self.max_try_num:
                            return
                        continue
                    if len(rollouts) != self.n_rollout:
                        logger.warning(
                            "#rollout is unexpected: "
                            f"expected={self.n_rollout} actual={len(rollouts)}"
                        )

                    with logger.block("calculate_baseline"):
                        reward = reward / len(rollouts)
                        m = self.baseline_momentum
                        baseline = (1 - m) * reward + m * baseline

                    with logger.block("train"):
                        self.model.train()
                        with logger.block("collate"):
                            batch = self.collate(rollouts)
                        with logger.block("to"):
                            batch.to(self.device)
                        with logger.block("forward"):
                            self.model.train()
                            output = self.model(batch)
                            loss = self.loss_fn(output)
                        with logger.block("backward"):
                            self.model.zero_grad()
                            loss.backward()
                        with logger.block("optimizer.step"):
                            self.optimizer.step()

                    if idx % 10 == 0:
                        logger.info(
                            f"idx={idx} reward(avg)={baseline} reward={reward} "
                            f"loss={loss.item()}")

                    idx += 1
            finally:
                # Restore state_dict
                self.model.load_state_dict(orig_model_state)
                self.optimizer.load_state_dict(self.optimizer_state_dict)
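The loop above keeps a moving-average baseline of the mean rollout reward and stores each rollout's reward relative to it. A minimal sketch of the baseline update on its own, assuming a momentum of 0.9 and made-up reward values (the class reads the momentum from self.baseline_momentum):

    baseline = 0.0
    baseline_momentum = 0.9  # assumed value for illustration
    for mean_reward in [0.0, 1.0, 1.0]:  # mean reward of each rollout batch
        # Exponential moving average, as in the calculate_baseline block.
        baseline = (1 - baseline_momentum) * mean_reward \
            + baseline_momentum * baseline
    print(baseline)  # ~0.19
    # In _synthesize, each rollout's output["reward"] uses the baseline from
    # the preceding iterations, i.e. before this update is applied.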
Example #7
0
    def enumerate_samples_per_state(self,
                                    rule_pred: torch.Tensor,
                                    token_pred: torch.Tensor,
                                    reference_pred: torch.Tensor,
                                    next_state: Environment,
                                    state: SamplerState[Environment],
                                    enumeration: Enumeration,
                                    k: Optional[int]) \
            -> Generator[DuplicatedSamplerState[Environment], None, None]:
        def indices(pred: torch.Tensor):
            # 0 is unknown token
            if enumeration == Enumeration.Top:
                _, indices = torch.sort(pred[1:], descending=True)
                if k is not None:
                    indices = indices[:k]
                for index in indices:
                    yield index + 1, 1
            elif enumeration == Enumeration.Random:
                indices = list(range(1, len(pred)))
                # Visit the candidate indices in random order.
                self.rng.shuffle(indices)
                if k is not None:
                    indices = indices[:k]
                for index in indices:
                    yield index, 1
            else:
                assert k is not None
                with logger.block("normalize_prob"):
                    s = pred[1:].sum().item()
                    if s < self.eps:
                        return
                    ps = (pred[1:] / s - self.eps).numpy()
                    npred = [max(0, p) for p in ps]
                for i, n in enumerate(self.rng.multinomial(k, npred)):
                    if n == 0:
                        continue
                    yield i + 1, n

        with logger.block("enumerate_samples_per_state"):
            head = state.state["action_sequence"].head
            assert head is not None
            head_field = \
                cast(ExpandTreeRule, cast(
                    ApplyRule,
                    state.state["action_sequence"]
                    .action_sequence[head.action]
                ).rule).children[head.field][1]
            if head_field.constraint == NodeConstraint.Token:
                # Generate token
                ref_ids = self.encoder.batch_encode_raw_value(
                    [x.raw_value for x in state.state["reference"]])
                tokens = list(self.encoder._token_encoder.vocab) + \
                    state.state["reference"]
                # Merge each reference score into the score of the matching
                # predefined token; scores without a corresponding token are
                # added to the unknown-token probability (index 0).
                for i, ids in enumerate(ref_ids):
                    for ref_id in ids:
                        token_pred[ref_id] += reference_pred[i]
                        if ref_id != 0:
                            reference_pred[i] = 0.0
                pred = torch.cat([token_pred, reference_pred], dim=0)

                # CloseVariadicFieldRule is a candidate when the field is variadic
                if head_field.is_variadic:
                    close_rule_idx = \
                        self.encoder._rule_encoder.encode(
                            CloseVariadicFieldRule())
                    p = rule_pred[close_rule_idx].item()
                    tokens.append(ApplyRule(CloseVariadicFieldRule()))
                    pred = torch.cat([pred, torch.tensor([p])], dim=0)

                with logger.block("exclude_invalid_tokens"):
                    # token
                    for kind, idxes in self.token_kind_to_idx.items():
                        if kind is not None and \
                                not self.is_subtype(kind,
                                                    head_field.type_name):
                            pred[idxes] = 0.0
                    # reference
                    for x, (p, token) in enumerate(
                            zip(pred[len(token_pred):],
                                tokens[len(token_pred):])):
                        x += len(token_pred)
                        if not isinstance(token, ApplyRule):
                            if isinstance(token, Token):
                                t = token.kind
                            else:
                                t = token[0]
                            if t is not None and \
                                    not self.is_subtype(t,
                                                        head_field.type_name):
                                pred[x] = 0.0

                n_action = 0
                for x, n in logger.iterable_block("sample-tokens",
                                                  indices(pred)):
                    # Finish enumeration
                    if n_action == k:
                        return

                    p = pred[x].item()
                    token = tokens[x]

                    if isinstance(token, ApplyRule):
                        action: Action = token
                    elif isinstance(token, Token):
                        action = GenerateToken(token.kind, token.raw_value)
                    else:
                        action = GenerateToken(token[0], token[1])

                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_action += n
                    next_state = next_state.clone()
                    # TODO we may have to clear outputs
                    next_state["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"], action)
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, next_state), n)
            else:
                # Apply rule
                with logger.block("exclude_invalid_rules"):
                    # expand tree rule
                    for kind, idxes in self.rule_kind_to_idx.items():
                        if not (kind is not None and self.is_subtype(
                                kind, head_field.type_name)):
                            rule_pred[idxes] = 0.0
                    # CloseVariadicField
                    idx = self.encoder._rule_encoder.encode(
                        CloseVariadicFieldRule())
                    if not (head_field is not None and head_field.is_variadic):
                        rule_pred[idx] = 0.0

                n_rule = 0
                for x, n in logger.iterable_block("sample-rule",
                                                  indices(rule_pred)):
                    # Finish enumeration
                    if n_rule == k:
                        return

                    p = rule_pred[x].item()
                    if p == 0.0:
                        continue
                    elif p < self.eps:
                        lp = np.log(self.eps)
                    else:
                        lp = np.log(p)

                    n_rule += n
                    rule = self.encoder._rule_encoder.vocab[x]

                    next_state = next_state.clone()
                    next_state["action_sequence"] = \
                        LazyActionSequence(
                            state.state["action_sequence"],
                            ApplyRule(rule))
                    yield DuplicatedSamplerState(
                        SamplerState(state.score + lp, next_state), n)
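The last branch of indices() above draws how many times each candidate index is sampled from a multinomial over the normalized probabilities. A standalone sketch of that pattern (omitting the eps adjustment), with made-up scores and a plain numpy RandomState standing in for self.rng:

    import numpy as np

    rng = np.random.RandomState(0)
    pred = np.array([0.05, 0.50, 0.30, 0.15])  # index 0 is the unknown token
    k = 4
    probs = pred[1:] / pred[1:].sum()
    for i, n in enumerate(rng.multinomial(k, probs)):
        if n == 0:
            continue
        # (candidate index, number of draws); +1 skips the unknown token
        print(i + 1, n)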
Example #8
0
    def test_clone(self) -> None:
        e = Environment()
        e2 = e.clone()
        e["key"] = 0
        assert e2.to_dict() == {}
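A companion check in the same style, assuming that clone_without_supervision (used in _synthesize above) drops every value previously marked as supervision; the asserted behavior is an assumption, not confirmed by the snippets here:

    def test_clone_without_supervision(self) -> None:
        e = Environment()
        e["ground_truth"] = 0
        e.mark_as_supervision("ground_truth")
        # Assumption: supervision entries are removed from the clone.
        e2 = e.clone_without_supervision()
        assert "ground_truth" not in e2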