def batch_k_samples(self, states: List[SamplerState[State]], ks: List[int]) \
        -> Generator[DuplicatedSamplerState[State], None, None]:
    with logger.block("batch_k_samples"):
        self.value_network.eval()
        outputs = []
        value_network_inputs = []
        for state in self.sampler.batch_k_samples(states, ks):
            input = self.transform(state.state.state)
            outputs.append(state)
            value_network_inputs.append(input)
            if len(outputs) == self.batch_size:
                # Re-score a full batch of sampled states with the value network.
                with torch.no_grad(), logger.block("calculate_value"):
                    value = self.value_network(
                        self.collate.collate(value_network_inputs))
                for value, output in zip(value, outputs):
                    yield DuplicatedSamplerState(
                        SamplerState(value.item(), output.state.state),
                        output.num)
                outputs = []
                value_network_inputs = []
        if len(outputs) != 0:
            # Flush the remaining states that did not fill a full batch.
            with torch.no_grad(), logger.block("calculate_value"):
                value = self.value_network(
                    self.collate.collate(value_network_inputs))
            for value, output in zip(value, outputs):
                yield DuplicatedSamplerState(
                    SamplerState(value.item(), output.state.state),
                    output.num)
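# Minimal standalone sketch (hypothetical helper, not part of the class above)
# of the batching pattern used in batch_k_samples: items are buffered until a
# batch is full, evaluated in one call, and the remainder is flushed after the
# loop so no sampled state is dropped.
def evaluate_in_batches(items, batch_size, evaluate):
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == batch_size:
            # Evaluate a full batch and emit its results in order.
            yield from evaluate(buffer)
            buffer = []
    if buffer:
        # Flush the incomplete final batch.
        yield from evaluate(buffer)


# Example: batches of size 2 over five items.
print(list(evaluate_in_batches(range(5), 2, lambda batch: [x * 10 for x in batch])))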
def test_rescore(self):
    def transform(state: str) -> Environment:
        return Environment({"x": torch.tensor([int(state)])})

    collate = Collate(x=CollateOptions(False, 0, 0))
    sampler = SamplerWithValueNetwork(MockSampler(), transform, collate,
                                      MockValueNetwork())

    zero = SamplerState(0, sampler.initialize(0))
    samples = list(sampler.batch_k_samples([zero], [3]))
    assert [
        DuplicatedSamplerState(SamplerState(0, "00"), 1),
        DuplicatedSamplerState(SamplerState(1, "01"), 1),
        DuplicatedSamplerState(SamplerState(2, "02"), 1)
    ] == samples
def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \ -> Generator[Result[Output], None, None]: with logger.block("_synthesize"): initial_state = DuplicatedSamplerState( SamplerState(0.0, self.sampler.initialize(input)), 1) for output in self._search(input, initial_state): yield output
def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \ -> Generator[Result[Output], None, None]: with logger.block("_synthesize"): # Start from empty sequence states = [SamplerState(0.0, self.sampler.initialize(input))] k = self.beam_size steps = 0 while steps < self.max_step_size and k > 0: if len(states) == 0: return next_states = [] for next_state in self.sampler.top_k_samples(states, k): output_opt = self.sampler.create_output( input, next_state.state.state) if output_opt is not None: output, is_finished = output_opt if steps == self.max_step_size - 1: # The step is last is_finished = True yield Result(output, next_state.state.score, is_finished, 1) if is_finished: k -= 1 else: next_states.append(next_state.state) else: next_states.append(next_state.state) states = next_states steps += 1
def top_k_samples(self, states: List[SamplerState[Tuple[str, List[str]]]],
                  k: int):
    for s in states[:k]:
        elems = len("".join(s.state[1]))
        if elems < len(s.state[0]):
            gt = s.state[0][elems]
            yield DuplicatedSamplerState(
                SamplerState(s.score + 0.0,
                             (s.state[0], s.state[1] + [gt])),
                1)
    s = states[0]
    for i in range(k - len(states)):
        x = chr(i + ord('0'))
        yield DuplicatedSamplerState(
            SamplerState(s.score - i - 1,
                         (s.state[0], s.state[1] + [x])),
            1)
def test_ast_set_sample(self):
    asts = ["c0", "c1", "c2"]
    sampler = SequentialProgramSampler(
        MockSynthesizer(asts),
        transform_input,
        Collate(),
        MockEncoder(),
        MockExpander(),
        MockInterpreter())
    zero = SamplerState(0, sampler.initialize([(None, None)]))
    samples = list(sampler.batch_k_samples([zero], [3]))
    samples.sort(key=lambda x: -x.state.score)
    assert 3 == len(samples)
    assert samples[0] == DuplicatedSamplerState(
        SamplerState(
            1,
            Environment({
                "test_cases": [(None, None)],
                "reference": [Token(None, str(asts[0]), str(asts[0]))],
                "variables": [["#" + str(asts[0])]],
                "interpreter_state": BatchedState(
                    {str(asts[0]): None},
                    {str(asts[0]): ["#" + str(asts[0])]},
                    [str(asts[0])],
                    ["#" + str(asts[0])])
            })),
        1)
    assert DuplicatedSamplerState(
        SamplerState(
            0.5,
            Environment({
                "test_cases": [(None, None)],
                "reference": [Token(None, str(asts[1]), str(asts[1]))],
                "variables": [["#" + str(asts[1])]],
                "interpreter_state": BatchedState(
                    {str(asts[1]): None},
                    {str(asts[1]): ["#" + str(asts[1])]},
                    [str(asts[1])],
                    ["#" + str(asts[1])])
            })),
        1) == samples[1]
    assert DuplicatedSamplerState(
        SamplerState(
            1.0 / 3,
            Environment({
                "test_cases": [(None, None)],
                "reference": [Token(None, str(asts[2]), str(asts[2]))],
                "variables": [["#" + str(asts[2])]],
                "interpreter_state": BatchedState(
                    {str(asts[2]): None},
                    {str(asts[2]): ["#" + str(asts[2])]},
                    [str(asts[2])],
                    ["#" + str(asts[2])])
            })),
        1) == samples[2]
def all_samples(self, states: List[SamplerState[Tuple[str, List[int]]]],
                sorted: bool = True):
    for s in states:
        elems = len(s.state[1])
        for i in range(3 - elems):
            yield DuplicatedSamplerState(
                SamplerState(s.score + (3 - i),
                             (s.state[0], s.state[1] + [i])),
                1)
def test_rule(self):
    rule_prob = torch.tensor([
        [[
            1.0,  # unknown
            1.0,  # close variadic field
            0.2,  # Root2X
            0.1,  # Root2Y
            1.0,  # X2Y_list
            1.0,  # Ysub2Str
        ]],
        [[
            1.0,  # unknown
            1.0,  # close variadic field
            1.0,  # Root2X
            1.0,  # Root2Y
            0.5,  # X2Y_list
            1.0,  # Ysub2Str
        ]]])
    token_prob = torch.tensor([[[]], [[]]])
    reference_prob = torch.tensor([[[]], [[]]])
    sampler = ActionSequenceSampler(
        create_encoder(),
        is_subtype,
        create_transform_input([]),
        transform_action_sequence,
        collate,
        Module(encoder_module,
               DecoderModule(rule_prob, token_prob, reference_prob))
    )
    s = SamplerState(0.0, sampler.initialize(Environment()))
    topk_results = list(sampler.top_k_samples([s], 1))
    assert 1 == len(topk_results)
    assert 1 == topk_results[0].state.state["length"].item()
    assert np.allclose(log(0.2), topk_results[0].state.score)
    random_results = list(sampler.batch_k_samples([s], [1]))
    assert 1 == len(random_results)
    assert 1 == random_results[0].state.state["length"].item()
    assert \
        log(0.1) - 1e-5 <= random_results[0].state.score <= log(0.2) + 1e-5
    all_results = list(sampler.all_samples([s]))
    assert 2 == len(all_results)
    assert 1 == all_results[0].state.state["length"].item()
    assert np.allclose(log(0.2), all_results[0].state.score)
    assert np.allclose(log(0.1), all_results[1].state.score)
    all_results = list(sampler.all_samples([s], sorted=False))
    assert 1 == all_results[0].state.state["length"].item()
    assert \
        log(0.1) - 1e-5 <= all_results[0].state.score <= log(0.2) + 1e-5
    next = list(sampler.top_k_samples(
        [s.state for s in topk_results], 1))[0]
    assert 2 == next.state.state["length"].item()
    assert np.allclose(log(0.2) + log(0.5), next.state.score)
def batch_k_samples(self, states: List[SamplerState[Tuple[str, str]]],
                    ks: List[int]):
    for state, k in zip(states, ks):
        elems = len(state.state[1])
        if len(state.state[0]) > elems:
            gt: Optional[str] = state.state[0][elems]
        else:
            gt = None
        for _ in range(k // len(states)):
            x = self.rng.choice(['x', 'y', '0', '1'])
            score = 0.0 if gt == x else -1.0
            yield DuplicatedSamplerState(
                SamplerState(state.score + score,
                             (state.state[0], state.state[1] + x)),
                1)
def batch_k_samples(self, states: List[SamplerState[Environment]],
                    ks: List[int]) \
        -> Generator[DuplicatedSamplerState[Environment], None, None]:
    assert all(len(state.state._supervisions) == 0 for state in states)
    originals = [state.state.clone() for state in states]
    with logger.block("batch_k_samples"):
        for original, state, k in zip(originals, states, ks):
            if k == 0:
                continue
            cnt = 0
            for result in self.synthesizer(state.state,
                                           n_required_output=k):
                new_state = original.clone()
                # Clear reference and variables
                new_state["reference"] = []
                new_state["variables"] = []
                new_state["interpreter_state"] = \
                    self.interpreter.execute(
                        result.output,
                        state.state["interpreter_state"]
                    )
                for code in new_state["interpreter_state"].environment:
                    new_state["reference"].append(
                        Token[Kind, Code](
                            new_state["interpreter_state"]
                            .type_environment[code],
                            code, code)
                    )
                    new_state["variables"].append(
                        new_state["interpreter_state"].environment[code]
                    )
                yield DuplicatedSamplerState(
                    SamplerState(result.score, new_state), result.num)
                cnt += 1
                if cnt == k:
                    break
def test_reference(self):
    torch.manual_seed(0)
    rule_prob = torch.tensor([
        [[
            1.0,  # unknown
            1.0,  # close variadic field
            0.1,  # Root2X
            0.2,  # Root2Y
            1.0,  # X2Y_list
            1.0,  # Ysub2Str
        ]],
        [[
            1.0,  # unknown
            1.0,  # close variadic field
            1.0,  # Root2X
            1.0,  # Root2Y
            1.0,  # X2Y_list
            1.0,  # Ysub2Str
        ]],
        [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]])
    token_prob = torch.tensor([
        [[0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0]],
        [[
            1.0,  # Unknown
            0.8,  # x
            0.2,  # 1
        ]]])
    reference_prob = torch.tensor(
        [[[0.0, 0.0]], [[0.0, 0.0]], [[0.1, 0.1]]])
    sampler = ActionSequenceSampler(
        create_encoder(),
        is_subtype,
        create_transform_input([Token("Str", "x", "x"),
                                Token(None, "x", "x")]),
        transform_action_sequence,
        collate,
        Module(encoder_module,
               DecoderModule(rule_prob, token_prob, reference_prob)),
        rng=np.random.RandomState(0)
    )
    s = SamplerState(0.0, sampler.initialize(Environment()))
    results = [s.state for s in sampler.top_k_samples([s], 1)]
    results = [s.state for s in sampler.top_k_samples(results, 1)]
    topk_results = list(sampler.top_k_samples(results, 1))
    assert 1 == len(topk_results)
    assert 3 == topk_results[0].state.state["length"].item()
    assert np.allclose(log(0.2) + log(1.0), topk_results[0].state.score)
    random_results = list(sampler.batch_k_samples(results[:1], [1]))
    assert 1 == len(random_results)
    assert 3 == random_results[0].state.state["length"].item()
    assert np.allclose(log(0.2) + log(1.0), random_results[0].state.score)
    all_results = list(sampler.all_samples(results))
    assert 1 == len(all_results)
    assert 3 == all_results[0].state.state["length"].item()
    assert np.allclose(log(0.2) + log(1.0), all_results[0].state.score)
    all_results = list(sampler.all_samples(results, sorted=False))
    assert 1 == len(all_results)
    assert 3 == all_results[0].state.state["length"].item()
    assert np.allclose(log(0.2) + log(1.0), all_results[0].state.score)
def _synthesize(self, input: Input, n_required_output: Optional[int] = None) \ -> Generator[Result[Output], None, None]: with logger.block("_synthesize"): if n_required_output is None: n_initial_particle = self.initial_particle_size else: n_initial_particle = n_required_output initial_state = self.sampler.initialize(input) i = 0 while True: logger.debug( f"start {i} th trial: n_particle={n_initial_particle}") i += 1 # Initialize state n_particle = n_initial_particle particles = [ DuplicatedSamplerState(SamplerState(0.0, initial_state), n_particle) ] step = 0 while step < self.max_step_size and n_particle > 0: # Generate particles samples: Dict[Key, DuplicatedSamplerState[State]] = {} for sample in self.sampler.batch_k_samples( [state.state for state in particles], [state.num for state in particles]): if sample.num == 0: # This sample does not exist continue key = self.to_key(sample.state.state) if key in samples: state = samples[key] samples[key] = \ DuplicatedSamplerState( state.state, state.num + sample.num ) else: samples[key] = sample if len(samples) == 0: # Output last particle with is_finished=True for state in particles: output_opt = \ self.sampler.create_output(input, state.state.state) if output_opt is not None: output, _ = output_opt yield Result(output, state.state.score, True, state.num) break # Resample list_samples = [state for state in samples.values()] log_weights = [ math.log(state.num) + state.state.score for state in list_samples ] probs = [ math.exp(log_weight - max(log_weights)) for log_weight in log_weights ] probs = [p / sum(probs) for p in probs] particles = [] resampled = self.rng.multinomial(n_particle, probs) for state, n in zip(list_samples, resampled): # Create output output_opt = \ self.sampler.create_output(input, state.state.state) if output_opt is not None: output, is_finished = output_opt if step == self.max_step_size - 1: is_finished = True yield Result(output, state.state.score, is_finished, n) else: is_finished = False if is_finished: # Exclude finished particles n_particle -= n if n > 0: particles.append( DuplicatedSamplerState(state.state, n)) step += 1 n_initial_particle *= self.factor if i == self.max_try_num: break
def batch_k_samples(self, states: List[SamplerState[str]], n: List[int]):
    for state, k in zip(states, n):
        for i in range(k // len(states)):
            x = state.state + str(i)
            yield DuplicatedSamplerState(SamplerState(len(x), x), 1)