def beam(tree_node, att_ids):
    """Expand tree_node by attaching its beam_width most probable successors.

    Relies on run_network / self / BeamSearchTreeNode from the enclosing scope.
    """
    prev_state, prev_att_states, prev_att_counts = tree_node.state
    net_out, new_state, new_att_states, new_att_counts = run_network(
        tree_node.token_id, prev_state, prev_att_states, att_ids, prev_att_counts)
    probs, predict_ids = net_out[0], net_out[1]
    # All children share the post-step network state.
    successor_state = (new_state, new_att_states, new_att_counts)
    for rank in range(self.beam_width):
        child = BeamSearchTreeNode(
            predict_ids[0, rank], None, successor_state, probs[0, rank])
        tree_node.add_child(child)
def __call__(self, session, depth):
    """Run beam search over every test case and print the predictions.

    For each test case, the full context (all tokens but the last) is fed
    through the network to build up state, then a beam-search tree of the
    given ``depth`` is grown from the final token and the most likely path
    is printed alongside the root's ranked children.

    Args:
        session: TF session used for all network evaluations.
        depth: number of beam-search levels to expand below the root.
    """
    def is_identifier(token_id):
        # 1 if the token's surface form starts with any AST-type prefix,
        # else 0 (used as the child node's identifier flag below).
        token = self.inv_map[token_id]
        if any(token.startswith(p) for p in astwalker.possible_types()):
            return 1
        return 0

    def run_network(token_id, mask, state, att_states, att_ids, att_counts):
        # Feed a single token through the model and return the prediction
        # results plus the updated recurrent/attention state.
        att_mask = attention_masks(self.attns, np.array([mask]), 1)
        data = (np.array([[token_id]]), np.array([[1]]), np.array([att_mask]), np.array([1]))
        feed_dict = construct_feed_dict(self.model, data, state, att_states, att_ids, att_counts)
        results = session.run(evals, feed_dict)
        # alpha_states and lambda_state are extracted but intentionally unused here.
        results, state, att_states, att_ids, alpha_states, att_counts, lambda_state = extract_results(results, evals, 2, self.model)
        return results, state, att_states, att_ids, att_counts

    def beam(tree_node):
        # Populate the children of tree_node
        init_state, init_att_states, init_att_ids, init_att_counts = tree_node.state
        results, state, att_states, att_ids, att_counts = run_network(tree_node.token_id, tree_node.mask, init_state, init_att_states, init_att_ids, init_att_counts)
        # results[0]: per-candidate probabilities; results[1]: candidate token ids.
        probs = results[0]
        predict_ids = results[1]
        for i in range(self.beam_width):
            tree_node.add_child(BeamSearchTreeNode(predict_ids[0, i], is_identifier(predict_ids[0, i]), (state, att_states, att_ids, att_counts), probs[0, i]))

    def beam_search_recursive(tree, current_depth):
        # Depth-first expansion of every child until the requested depth.
        if current_depth < depth:
            for child in tree.children:
                beam(child)
                beam_search_recursive(child, current_depth+1)

    to_eval = [self.prediction_op[0], self.prediction_op[1]]
    evals = get_evals(to_eval, self.model)
    for testcase in all_test_cases:
        # Pass the context through the network
        state, att_states, att_ids, att_counts = get_initial_state(self.model)
        for i, token in enumerate(testcase[:-1]):
            token_id, mask = map_token(self.map, token)
            results, state, att_states, att_ids, att_counts = run_network(token_id, mask, state, att_states, att_ids, att_counts)
        # print([self.inv_map[id] for id in att_ids[0].tolist()[0]])
        # Seed the beam-search tree with the final token and the accumulated state.
        token_id, mask = map_token(self.map, testcase[-1])
        root = BeamSearchTreeNode(token_id, mask, (state, att_states, att_ids, att_counts), 1)
        beam(root)
        beam_search_recursive(root, 1)
        # find_path(...)[0] is the single most likely path through the tree.
        path = find_path(root)[0]
        print(" ".join([self.inv_map[map_token(self.map, t)[0]] for t in testcase]))
        for child in sorted(root.children, key=lambda c: c.probability, reverse=True):
            print("%s ; %f" % (self.inv_map[child.token_id].replace("\n", "<newline>"), child.probability))
        print()
        print(" ".join([self.inv_map[t].replace("\n", "<newline>") for t in path]))
        print("\n\n§§§§§§§§§§§§§§\n\n")
def __call__(self, session, depth):
    """Measure beam-search prediction accuracy over all test cases.

    For each position in each test case the network state is advanced one
    token, a beam-search tree of ``depth`` levels is grown, and the most
    likely path is compared against the next ``depth`` actual tokens.
    Prints per-token predictions and a final accuracy ratio.

    Args:
        session: TF session used for all network evaluations.
        depth: number of future tokens to predict and compare.
    """
    def run_network(token_id, state, att_states, att_ids, att_counts):
        # Feed a single token through the model and return the prediction
        # results plus the updated recurrent/attention state.
        att_mask = attention_masks(self.attns, [0], 1)
        data = (np.array([[token_id]]), np.array([[1]]), np.array([att_mask]), np.array([1]), np.array([1]))
        feed_dict, _ = construct_feed_dict(self.model, data, state, att_states, att_ids, att_counts)
        results = session.run(evals, feed_dict)
        results, state, att_states, _, _, att_counts, _ = extract_results(results, evals, 2, self.model)
        return results, state, att_states, att_counts

    def beam(tree_node, att_ids):
        # Populate the children of tree_node
        init_state, init_att_states, init_att_counts = tree_node.state
        results, state, att_states, att_counts = run_network(
            tree_node.token_id, init_state, init_att_states, att_ids, init_att_counts)
        probs = results[0]
        predict_ids = results[1]
        for i in range(self.beam_width):
            tree_node.add_child(BeamSearchTreeNode(
                predict_ids[0, i], None, (state, att_states, att_counts), probs[0, i]))

    def beam_search_recursive(tree, current_depth, att_ids):
        # Depth-first expansion of every child until the requested depth.
        if current_depth < depth:
            for child in tree.children:
                beam(child, att_ids)
                beam_search_recursive(child, current_depth+1, att_ids)

    to_eval = [self.prediction_op[0], self.prediction_op[1]]
    evals = get_evals(to_eval, self.model)
    count = 0
    accurate = 0
    for testcase in all_test_cases:
        state, att_states, att_ids, att_counts = get_initial_state(self.model)
        for i, token in enumerate(testcase[:-depth]):
            # NOTE(review): the broad per-iteration try only guards against
            # out-of-vocabulary tokens (map_token raising KeyError); such
            # positions are skipped and excluded from the accuracy count.
            try:
                results, state, att_states, att_counts = run_network(
                    map_token(self.map, token)[0], state, att_states, att_ids, att_counts)
                # NOTE(review): the root is seeded with the test case's LAST
                # token rather than the current one — looks like a carry-over
                # from the printing variant; confirm this is intended.
                root = BeamSearchTreeNode(
                    map_token(self.map, testcase[-1])[0],
                    None, (state, att_states, att_counts), 1)
                beam(root, att_ids)
                beam_search_recursive(root, 1, att_ids)
                path = find_path(root)[0]  # The most likely path
                actual = [map_token(self.map, t)[0] for t in testcase[i+1:i+depth+1]]
                print("Token: %s" % token)
                print("Predicted:")
                print(" ".join([self.inv_map[t].replace("\n", "<newline>") for t in path]))
                print("Actual:")
                print(" ".join([self.inv_map[t].replace("\n", "<newline>") for t in actual]))
                print("\n")
                count += 1
                if path == actual:
                    accurate += 1
            except KeyError as e:
                pass
    # Fix: guard against ZeroDivisionError when no position was evaluated
    # (e.g. every token was out-of-vocabulary or all_test_cases is empty).
    if count > 0:
        print("Accuracy: %f" % (accurate/count))
    else:
        print("Accuracy: n/a (no predictions evaluated)")
def beam(tree_node):
    """Attach the beam_width most probable successor tokens to tree_node.

    Runs one network step seeded with the node's token and state, then adds
    one child per top-k candidate, each carrying the new state and its
    probability.
    """
    inputs = {
        self.model.input_data: np.array([np.array([tree_node.token_id])]),
        self.model.initial_state: tree_node.state,
    }
    dist, next_state = session.run(
        [self.model.predict, self.model.final_state], inputs)
    scores = dist[0]
    for candidate in best_k(scores, self.beam_width):
        tree_node.add_child(
            BeamSearchTreeNode(candidate, next_state, scores[candidate]))
def beam_search_k(self, session, token_id, state, k):
    """Return the k most likely token paths found by beam search.

    Seeds the search tree with the given token/state at probability 1,
    grows it via beam_search_tree, and extracts the k best paths.
    """
    seed = BeamSearchTreeNode(token_id, state, 1)
    search_tree = self.beam_search_tree(session, seed)
    return find_path(search_tree, k)