Пример #1
0
 def batch_k_samples(self, states: List[SamplerState[State]],
                     ks: List[int]) \
         -> Generator[DuplicatedSamplerState[State],
                      None, None]:
     """Sample with the wrapped sampler and rescore samples by the value network.

     Samples produced by ``self.sampler.batch_k_samples(states, ks)`` are
     buffered.  Whenever ``self.batch_size`` of them have accumulated (and once
     more for any remainder at the end), the transformed states are collated
     and passed through ``self.value_network`` under ``torch.no_grad()``.  Each
     buffered sample is re-emitted with its score replaced by the matching
     value-network output (via ``.item()``) and its ``num`` preserved.

     Note: the value network is put into eval mode and is not switched back.
     """
     with logger.block("batch_k_samples"):
         self.value_network.eval()

         def score_batch(pending, network_inputs):
             # Run one batch through the value network, then re-emit every
             # buffered sample with the predicted value as its new score.
             with torch.no_grad(), logger.block("calculate_value"):
                 predicted = self.value_network(
                     self.collate.collate(network_inputs))
             for prediction, sample in zip(predicted, pending):
                 yield DuplicatedSamplerState(
                     SamplerState(prediction.item(), sample.state.state),
                     sample.num)

         outputs = []
         value_network_inputs = []
         for state in self.sampler.batch_k_samples(states, ks):
             transformed = self.transform(state.state.state)
             outputs.append(state)
             value_network_inputs.append(transformed)
             if len(outputs) == self.batch_size:
                 yield from score_batch(outputs, value_network_inputs)
                 outputs = []
                 value_network_inputs = []
         if outputs:
             # Flush the final, partially filled batch.
             yield from score_batch(outputs, value_network_inputs)
Пример #2
0
    def test_rescore(self):
        """Scores from the value network replace those of the wrapped sampler."""
        def to_env(state: str) -> Environment:
            # Encode the numeric state string as a one-element tensor.
            return Environment({"x": torch.tensor([int(state)])})

        collate = Collate(x=CollateOptions(False, 0, 0))
        sampler = SamplerWithValueNetwork(
            MockSampler(), to_env, collate, MockValueNetwork())
        initial = SamplerState(0, sampler.initialize(0))
        actual = list(sampler.batch_k_samples([initial], [3]))
        expected = [
            DuplicatedSamplerState(SamplerState(score, text), 1)
            for score, text in ((0, "00"), (1, "01"), (2, "02"))
        ]
        assert expected == actual
Пример #3
0
 def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \
         -> Generator[Result[Output], None, None]:
     """Delegate synthesis to ``self._search`` from the initial sampler state.

     The search starts from a single ``DuplicatedSamplerState`` that wraps
     ``self.sampler.initialize(input)`` with score ``0.0`` and ``num=1``.
     ``n_required_output`` is accepted for interface compatibility and is not
     used here.
     """
     with logger.block("_synthesize"):
         initial_state = DuplicatedSamplerState(
             SamplerState(0.0, self.sampler.initialize(input)), 1)
         # Re-yield every result; ``yield from`` also forwards send/throw.
         yield from self._search(input, initial_state)
Пример #4
0
    def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \
            -> Generator[Result[Output], None, None]:
        """Beam search over sampler states.

        Starting from ``self.sampler.initialize(input)``, each step expands the
        beam with ``self.sampler.top_k_samples(beam, width)``.  Every candidate
        for which ``create_output`` returns an output is yielded as a
        ``Result`` with ``num=1``.  A finished candidate shrinks the beam width
        by one; an unfinished candidate, or one without an output, stays in the
        beam.  Candidates produced at the last allowed step are reported as
        finished.

        The search stops after ``self.max_step_size`` steps, when the width
        reaches zero, or when the beam becomes empty.  ``n_required_output``
        is accepted for interface compatibility and is not used here.
        """
        with logger.block("_synthesize"):
            # The beam starts as a single empty sequence.
            beam = [SamplerState(0.0, self.sampler.initialize(input))]
            width = self.beam_size
            step = 0
            while step < self.max_step_size and width > 0:
                if not beam:
                    return
                survivors = []
                for candidate in self.sampler.top_k_samples(beam, width):
                    sampled = candidate.state
                    output_opt = self.sampler.create_output(
                        input, sampled.state)
                    if output_opt is None:
                        # No output yet: keep expanding this candidate.
                        survivors.append(sampled)
                        continue
                    output, is_finished = output_opt
                    if step == self.max_step_size - 1:
                        # Nothing can be extended past the final step.
                        is_finished = True
                    yield Result(output, sampled.score, is_finished, 1)
                    if is_finished:
                        width -= 1
                    else:
                        survivors.append(sampled)
                beam = survivors
                step += 1
Пример #5
0
 def top_k_samples(self, states: List[SamplerState[Tuple[str, List[str]]]],
                   k: int):
     """Mock top-k sampler over ``(target, tokens)`` states.

     Step 1: for each of the first ``k`` states whose joined tokens are still
     shorter than ``target``, yield the state extended by the next
     ground-truth character of ``target``, with its score unchanged.

     Step 2: when ``k`` exceeds ``len(states)``, pad with
     ``k - len(states)`` extra candidates derived from ``states[0]``.  The
     ``i``-th appends ``chr(ord('0') + i)`` and has score
     ``states[0].score - i - 1``.

     An empty ``states`` yields nothing.
     """
     if not states:
         # Nothing to extend and no state to pad from.
         return
     for s in states[:k]:
         elems = len("".join(s.state[1]))
         if elems < len(s.state[0]):
             gt = s.state[0][elems]
             yield DuplicatedSamplerState(
                 SamplerState(s.score + 0.0,
                              (s.state[0], s.state[1] + [gt])),
                 1)
     s = states[0]
     for i in range(k - len(states)):
         x = chr(i + ord('0'))
         yield DuplicatedSamplerState(
             SamplerState(s.score - i - 1,
                          (s.state[0], s.state[1] + [x])),
             1)
Пример #6
0
 def test_ast_set_sample(self):
     """Each synthesized AST becomes one sample with the expected environment."""
     asts = ["c0", "c1", "c2"]
     sampler = SequentialProgramSampler(MockSynthesizer(asts),
                                        transform_input, Collate(),
                                        MockEncoder(), MockExpander(),
                                        MockInterpreter())
     zero = SamplerState(0, sampler.initialize([(None, None)]))
     samples = list(sampler.batch_k_samples([zero], [3]))
     samples.sort(key=lambda sample: -sample.state.score)
     assert 3 == len(samples)

     def expected(index, score):
         # Sample expected for ``asts[index]`` with the given score.
         name = str(asts[index])
         var = "#" + name
         return DuplicatedSamplerState(
             SamplerState(
                 score,
                 Environment({
                     "test_cases": [(None, None)],
                     "reference": [Token(None, name, name)],
                     "variables": [[var]],
                     "interpreter_state":
                     BatchedState({name: None}, {name: [var]},
                                  [name], [var])
                 })), 1)

     assert samples[0] == expected(0, 1)
     assert expected(1, 0.5) == samples[1]
     assert expected(2, 1.0 / 3) == samples[2]
Пример #7
0
 def all_samples(self,
                 states: List[SamplerState[Tuple[str, List[int]]]],
                 sorted: bool = True):
     """Mock exhaustive sampler over ``(target, picks)`` states.

     For a state with ``m`` picks, yields ``3 - m`` children.  Child ``i``
     appends ``i`` to the picks and adds ``3 - i`` to the score.  ``sorted``
     is accepted for interface compatibility and is ignored.
     """
     for parent in states:
         remaining = 3 - len(parent.state[1])
         for i in range(remaining):
             child_score = parent.score + (3 - i)
             child = (parent.state[0], parent.state[1] + [i])
             yield DuplicatedSamplerState(
                 SamplerState(child_score, child), 1)
    def test_rule(self):
        """Top-k, random and exhaustive sampling with a rule-only decoder."""
        # Rule columns: unknown, close variadic field, Root2X, Root2Y,
        # X2Y_list, Ysub2Str.
        rule_prob = torch.tensor([
            [[1.0, 1.0, 0.2, 0.1, 1.0, 1.0]],
            [[1.0, 1.0, 1.0, 1.0, 0.5, 1.0]],
        ])
        token_prob = torch.tensor([[[]], [[]]])
        reference_prob = torch.tensor([[[]], [[]]])
        encoder = create_encoder()
        input_transform = create_transform_input([])
        model = Module(encoder_module,
                       DecoderModule(rule_prob, token_prob, reference_prob))
        sampler = ActionSequenceSampler(encoder, is_subtype, input_transform,
                                        transform_action_sequence, collate,
                                        model)
        initial = SamplerState(0.0, sampler.initialize(Environment()))

        # Top-1 picks Root2X (p=0.2).
        top1 = list(sampler.top_k_samples([initial], 1))
        assert 1 == len(top1)
        assert 1 == top1[0].state.state["length"].item()
        assert np.allclose(log(0.2), top1[0].state.score)

        # A random sample is either Root2X or Root2Y.
        sampled = list(sampler.batch_k_samples([initial], [1]))
        assert 1 == len(sampled)
        assert 1 == sampled[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= sampled[0].state.score <= log(0.2) + 1e-5

        # Exhaustive enumeration with the default ordering.
        everything = list(sampler.all_samples([initial]))
        assert 2 == len(everything)
        assert 1 == everything[0].state.state["length"].item()
        assert np.allclose(log(0.2), everything[0].state.score)
        assert np.allclose(log(0.1), everything[1].state.score)

        unordered = list(sampler.all_samples([initial], sorted=False))
        assert 1 == unordered[0].state.state["length"].item()
        assert \
            log(0.1) - 1e-5 <= unordered[0].state.score <= log(0.2) + 1e-5

        # Extending the top-1 result applies X2Y_list (p=0.5).
        step2 = list(sampler.top_k_samples(
            [result.state for result in top1], 1))[0]
        assert 2 == step2.state.state["length"].item()
        assert np.allclose(log(0.2) + log(0.5), step2.state.score)
Пример #9
0
 def batch_k_samples(self, states: List[SamplerState[Tuple[str, str]]],
                     ks: List[int]):
     """Mock random sampler over ``(target, prefix)`` string states.

     For each ``(state, k)`` pair, draws ``k // len(states)`` characters from
     ``['x', 'y', '0', '1']`` with ``self.rng.choice``.  Each draw is
     appended to the prefix.  A draw equal to the next ground-truth
     character of the target (if any) keeps the score; any other draw
     subtracts ``1.0``.
     """
     for current, k in zip(states, ks):
         target = current.state[0]
         cursor = len(current.state[1])
         gt = target[cursor] if len(target) > cursor else None
         for _ in range(k // len(states)):
             x = self.rng.choice(['x', 'y', '0', '1'])
             if gt == x:
                 delta = 0.0
             else:
                 delta = -1.0
             extended = (target, current.state[1] + x)
             yield DuplicatedSamplerState(
                 SamplerState(current.score + delta, extended), 1)
Пример #10
0
    def batch_k_samples(self, states: List[SamplerState[Environment]],
                        ks: List[int]) \
            -> Generator[DuplicatedSamplerState[Environment],
                         None, None]:
        """Run the synthesizer on each state and execute the generated programs.

        For every ``(state, k)`` pair with ``k != 0``, up to ``k`` results are
        taken from ``self.synthesizer(state.state, n_required_output=k)``.
        For each result:

        - ``result.output`` is executed by ``self.interpreter`` against the
          state's ``"interpreter_state"``;
        - the returned interpreter state is stored in a fresh clone of the
          input environment;
        - ``"reference"`` and ``"variables"`` are rebuilt from that
          interpreter state's ``environment`` and ``type_environment``.

        Yields:
            ``DuplicatedSamplerState`` with the result's score and ``num``.

        Raises:
            AssertionError: if any input environment has supervisions.
        """
        # Every input environment must be free of supervisions; checking each
        # state individually also accepts an empty ``states`` list.
        assert all(len(state.state._supervisions) == 0 for state in states)

        # Snapshot the inputs; each yielded state starts from a fresh clone of
        # its snapshot (presumably so that changes the synthesizer makes to
        # ``state.state`` do not leak into the outputs — confirm).
        originals = [state.state.clone() for state in states]

        with logger.block("batch_k_samples"):
            for original, state, k in zip(originals, states, ks):
                if k == 0:
                    continue

                cnt = 0
                for result in self.synthesizer(state.state,
                                               n_required_output=k):
                    new_state = original.clone()
                    # Clear reference and variables
                    new_state["reference"] = []
                    new_state["variables"] = []
                    new_state["interpreter_state"] = \
                        self.interpreter.execute(
                            result.output,
                            state.state["interpreter_state"]
                    )
                    # One typed Token and one variable value per code in the
                    # executed environment.
                    for code in \
                            new_state["interpreter_state"].environment:
                        new_state["reference"].append(
                            Token[Kind, Code](
                                new_state["interpreter_state"]
                                .type_environment[code],
                                code, code)
                        )
                        new_state["variables"].append(
                            new_state["interpreter_state"]
                            .environment[code]
                        )
                    yield DuplicatedSamplerState(
                        SamplerState(result.score, new_state), result.num)
                    cnt += 1
                    # Stop after k results even if the synthesizer would
                    # produce more.
                    if cnt == k:
                        break
 def test_reference(self):
     """Top-k, random and exhaustive sampling at the third decoding step."""
     torch.manual_seed(0)
     # Rule columns: unknown, close variadic field, Root2X, Root2Y,
     # X2Y_list, Ysub2Str.
     rule_prob = torch.tensor([
         [[1.0, 1.0, 0.1, 0.2, 1.0, 1.0]],
         [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]],
         [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
     ])
     # Token columns: Unknown, x, 1.
     token_prob = torch.tensor([
         [[0.0, 0.0, 0.0]],
         [[0.0, 0.0, 0.0]],
         [[1.0, 0.8, 0.2]],
     ])
     reference_prob = torch.tensor(
         [[[0.0, 0.0]], [[0.0, 0.0]], [[0.1, 0.1]]])
     encoder = create_encoder()
     references = [Token("Str", "x", "x"), Token(None, "x", "x")]
     input_transform = create_transform_input(references)
     model = Module(encoder_module,
                    DecoderModule(rule_prob, token_prob, reference_prob))
     sampler = ActionSequenceSampler(
         encoder, is_subtype, input_transform, transform_action_sequence,
         collate, model, rng=np.random.RandomState(0))

     initial = SamplerState(0.0, sampler.initialize(Environment()))
     # Advance two steps greedily before checking the third one.
     first = [r.state for r in sampler.top_k_samples([initial], 1)]
     prefixes = [r.state for r in sampler.top_k_samples(first, 1)]

     expected_score = log(0.2) + log(1.0)

     def assert_single_length3(found):
         # Exactly one sample, of length 3, with the expected score.
         assert 1 == len(found)
         assert 3 == found[0].state.state["length"].item()
         assert np.allclose(expected_score, found[0].state.score)

     assert_single_length3(list(sampler.top_k_samples(prefixes, 1)))
     assert_single_length3(list(sampler.batch_k_samples(prefixes[:1], [1])))
     assert_single_length3(list(sampler.all_samples(prefixes)))
     assert_single_length3(list(sampler.all_samples(prefixes, sorted=False)))
Пример #12
0
    def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \
            -> Generator[Result[Output], None, None]:
        """Sequential Monte Carlo synthesis with restarts.

        Each trial starts ``n_initial_particle`` particles at
        ``self.sampler.initialize(input)``.  That count is ``n_required_output``
        when given, otherwise ``self.initial_particle_size``.  A trial runs for
        at most ``self.max_step_size`` steps while particles remain.  Each
        step:

        1. expands all particles with ``self.sampler.batch_k_samples``,
           dropping samples with ``num == 0`` and merging samples that share
           ``self.to_key`` (their ``num`` values are summed);
        2. if nothing was sampled, yields the current particles' outputs as
           finished and ends the trial;
        3. otherwise resamples ``n_particle`` particles multinomially with
           weights proportional to ``num * exp(score)``.  It yields a
           ``Result`` for every merged sample that has an output (including
           samples that received zero particles); results at the last step
           are marked finished.  Finished particles leave the population.

        After each trial the initial particle count is multiplied by
        ``self.factor``.  The outer loop stops once ``self.max_try_num``
        trials have run.
        """
        with logger.block("_synthesize"):
            if n_required_output is None:
                n_initial_particle = self.initial_particle_size
            else:
                n_initial_particle = n_required_output
            initial_state = self.sampler.initialize(input)
            i = 0
            while True:
                logger.debug(
                    f"start {i} th trial: n_particle={n_initial_particle}")
                i += 1
                # Initialize state
                n_particle = n_initial_particle
                particles = [
                    DuplicatedSamplerState(SamplerState(0.0, initial_state),
                                           n_particle)
                ]
                step = 0
                while step < self.max_step_size and n_particle > 0:
                    # Generate particles, merging duplicates by key
                    samples: Dict[Key, DuplicatedSamplerState[State]] = {}
                    for sample in self.sampler.batch_k_samples(
                            [state.state for state in particles],
                            [state.num for state in particles]):
                        if sample.num == 0:
                            # This sample does not exist
                            continue
                        key = self.to_key(sample.state.state)
                        if key in samples:
                            state = samples[key]
                            samples[key] = DuplicatedSamplerState(
                                state.state, state.num + sample.num)
                        else:
                            samples[key] = sample

                    if not samples:
                        # Output last particle with is_finished=True
                        for state in particles:
                            output_opt = \
                                self.sampler.create_output(input,
                                                           state.state.state)
                            if output_opt is not None:
                                output, _ = output_opt
                                yield Result(output, state.state.score, True,
                                             state.num)
                        break

                    # Resample.  Weights are num * exp(score), computed in log
                    # space and shifted by the maximum for numerical
                    # stability.  The maximum and the normalizer are computed
                    # once each, so this step is linear in the sample count.
                    list_samples = list(samples.values())
                    log_weights = [
                        math.log(state.num) + state.state.score
                        for state in list_samples
                    ]
                    max_log_weight = max(log_weights)
                    weights = [
                        math.exp(log_weight - max_log_weight)
                        for log_weight in log_weights
                    ]
                    total_weight = sum(weights)
                    probs = [weight / total_weight for weight in weights]
                    particles = []
                    resampled = self.rng.multinomial(n_particle, probs)
                    for state, n in zip(list_samples, resampled):
                        # Create output
                        output_opt = \
                            self.sampler.create_output(input,
                                                       state.state.state)
                        if output_opt is not None:
                            output, is_finished = output_opt
                            if step == self.max_step_size - 1:
                                is_finished = True
                            yield Result(output, state.state.score,
                                         is_finished, n)
                        else:
                            is_finished = False

                        if is_finished:
                            # Exclude finished particles
                            n_particle -= n

                        if n > 0:
                            particles.append(
                                DuplicatedSamplerState(state.state, n))
                    step += 1

                n_initial_particle *= self.factor
                if i == self.max_try_num:
                    break
Пример #13
0
 def batch_k_samples(self, states: List[SamplerState[str]], n: List[int]):
     """Mock sampler that extends each string state with consecutive digits.

     A state whose entry in ``n`` is ``k`` receives ``k // len(states)``
     children.  Child ``i`` appends ``str(i)`` and is scored by the length of
     the resulting string.
     """
     for state, k in zip(states, n):
         per_state = k // len(states)
         for suffix in map(str, range(per_state)):
             child = state.state + suffix
             yield DuplicatedSamplerState(SamplerState(len(child), child), 1)