Example #1
0
def latent_store_stuff() -> Tuple[LatentStore, TypeContext, List[Tuple[
    int, int, AstObjectChoiceSet]]]:
    torch.manual_seed(1)
    tc = TypeContext()
    ft = AInixType(tc, "FT")
    fo1 = AInixObject(tc, "FO1", "FT")
    fo2 = AInixObject(tc, "FO2", "FT")
    fo3 = AInixObject(tc, "FO3", "FT")
    bt = AInixType(tc, "BT")
    bo1 = AInixObject(tc, "BO1", "BT")
    AInixObject(tc, "BO2", "BT")
    tc.finalize_data()

    builder = TorchLatentStore.get_builder(tc.get_type_count(), 3)
    valid_choices = []
    oc = ObjectChoiceNode(ft, ObjectNode(fo1))
    builder.add_example(0, oc)
    s1 = AstObjectChoiceSet(ft)
    s1.add(oc, True, 1, 1)
    valid_choices.append((0, 0, s1))

    oc = ObjectChoiceNode(ft, ObjectNode(fo2))
    builder.add_example(1, oc)
    s1 = AstObjectChoiceSet(ft)
    s1.add(oc, True, 1, 1)
    valid_choices.append((1, 0, s1))

    builder.add_example(2, ObjectChoiceNode(ft, ObjectNode(fo3)))
    return builder.produce_result(), tc, valid_choices
Example #2
0
 def predict(
         self, x_string: str, y_type_name: str, use_only_train_data: bool
 ) -> AstNode:  # TODO (DNGros): change to set
     root_type = self.type_context.get_type_by_name(y_type_name)
     root_node = ObjectChoiceNode(root_type)
     self._predict_step(x_string, root_node, root_node, root_type,
                        use_only_train_data, 0)
     root_node.freeze()
     return root_node
Example #3
0
 def __init__(self, tree_path: ObjectChoiceNode):
     super().__init__()
     self.path_stack: List[AstIterPointer] = list(
         reversed(list(tree_path.depth_first_iter())))
     self.current_y_ind = 0
     self.lattents_log = []
     self.y_inds_log = []
Example #4
0
def _get_unparse_intervals_of_inds(
    dfs_inds_to_include: Sequence[int],
    ast: ObjectChoiceNode,
    unparse: UnparseResult
) -> IntervalTree:
    """Given some indicies we wish include, find the intervals of the total
    unparse string which are covered by those indicies"""
    include_set = set(dfs_inds_to_include)
    interval_tree = IntervalTree()
    currently_including = False
    for ind, pointer in enumerate(ast.depth_first_iter()):
        if ind % 2 != 0:
            # Only take into account the choice nodes. Skip the object nodes
            continue
        assert isinstance(pointer.cur_node, ObjectChoiceNode)
        func_need_to_do_here = None
        if ind in include_set:
            if not currently_including:
                func_need_to_do_here = lambda start, end: interval_tree.add(Interval(start, end))
                currently_including = True
        else:
            if currently_including:
                func_need_to_do_here = lambda start, end: interval_tree.chop(start, end)
                currently_including = False
        if func_need_to_do_here:
            span = unparse.pointer_to_span(pointer)
            if span is None or span[1] - span[0] == 0:
                continue
            start, end = span
            func_need_to_do_here(start, end)
    interval_tree.merge_overlaps()
    return interval_tree
Example #5
0
def _create_gt_compare_result(
    example_ast: ObjectChoiceNode,
    current_leaf: ObjectChoiceNode,
    ground_truth_set: AstObjectChoiceSet
) -> ComparerResult:
    """Creates a `CompareResult` based off some ground truth

    Args:
        example_ast : A parsed AST of the y_text inside the example we are creating
            the ground truth for.
        current_leaf : What we are currently generating for. Used to determine
            what kind of choice we are making.
        ground_truth_set : Our ground truth which we are making the result based
            off of.
    """
    if ground_truth_set.type_to_choose_name != current_leaf.get_type_to_choose_name():
        raise ValueError(f"Unexpected leaf to create gt set. Current leaf has type "
                         f"'{current_leaf.get_type_to_choose_name()}' but the ground_truth "
                         f"set has type '{ground_truth_set.type_to_choose_name}'")
    choices_in_this_example = get_type_choice_nodes(
        example_ast, current_leaf.type_to_choose.name)
    right_choices = [e for e, depth in choices_in_this_example
                     if ground_truth_set.is_known_choice(e.get_chosen_impl_name())]
    in_this_example_impl_name_set = {c.get_chosen_impl_name()
                                     for c, depth in choices_in_this_example}
    right_choices_impl_name_set = {c.get_chosen_impl_name() for c in right_choices}
    this_example_right_prob = 1 if len(right_choices) > 0 else 0
    if right_choices:
        expected_impl_scores = [(1 if impl_name in right_choices_impl_name_set else 0, impl_name)
                                for impl_name in in_this_example_impl_name_set]
        expected_impl_scores.sort()
        expected_impl_scores = tuple(expected_impl_scores)
    else:
        expected_impl_scores = None
    return ComparerResult(this_example_right_prob, expected_impl_scores)
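
The expected_impl_scores construction above boils down to scoring each impl name seen in the example by whether the ground-truth set also accepts it; a plain-Python sketch with made-up names:

in_this_example = {"FO1", "FO2", "FO3"}   # impls chosen somewhere in the example AST
right = {"FO2"}                           # impls the ground-truth set accepts
expected_impl_scores = tuple(sorted(
    (1 if name in right else 0, name) for name in in_this_example))
assert expected_impl_scores == ((0, "FO1"), (0, "FO3"), (1, "FO2"))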
Example #6
0
    def to_string(self,
                  ast: ObjectChoiceNode,
                  copy_from_str: str = None,
                  root_parser_name: str = None) -> 'UnparseResult':
        """The main method used to convert an AST into a string

        Args:
            ast: The ast to convert into a string form
            copy_from_str: The string to refer to when unparsing copy nodes
            root_parser_name: Overrides the unparsing to use a certain type
                parser for the root of the ast.
        """
        if copy_from_str:
            _, tokenized_metadata = self.input_str_tokenizer.tokenize(
                copy_from_str)
        else:
            tokenized_metadata = None

        root_parser = _get_root_parser(self._type_context,
                                       ast.get_type_to_choose_name(),
                                       root_parser_name)
        if not root_parser:
            raise ValueError(
                f"Unable to get a parser type {ast.get_type_to_choose_name()} and "
                f"root_parser {root_parser_name}")
        result_builder = _UnparseResultBuilder(ast, tokenized_metadata)
        self._unparse_object_choice_node(ast, root_parser, result_builder, 0,
                                         tuple())
        return result_builder.as_result()
Example #7
0
 def _inference_object_step(
     self,
     current_leaf: ObjectNode,
     internal_state,  # For lstm the (h_0, c_0) tensors
     memory_tokens: torch.Tensor,
     valid_for_copy_mask: torch.LongTensor,
     cur_depth,
     override_selector: ActionSelector = None
 ) -> Tuple[torch.Tensor, TypeTranslatePredictMetadata]:
     """makes one step for ObjectNodes. Returns last hidden state"""
     if cur_depth > self.MAX_DEPTH:
         raise ModelSafePredictError("Max length exceeded")
     latest_internal_state = internal_state
     metad = TypeTranslatePredictMetadata.create_empty()
     my_features = self.object_embeddings(
         torch.LongTensor([current_leaf.implementation.ind]))
      for arg in current_leaf.implementation.children:
          if arg.next_choice_type is None:
              continue
          new_node = ObjectChoiceNode(arg.next_choice_type)
          current_leaf.set_arg_value(arg.name, new_node)
         latest_internal_state, child_metad = self._inference_objectchoice_step(
             new_node, latest_internal_state, my_features, memory_tokens,
             valid_for_copy_mask, cur_depth + 1, override_selector)
         metad = metad.concat(child_metad)
     return latest_internal_state, metad
Example #8
0
 def compare(
     self,
      gen_query: str,  # maybe we should just pass in the already tokenized version
     gen_ast_current_root: ObjectChoiceNode,
     gen_ast_current_leaf: ObjectChoiceNode,
     current_gen_depth: int,
     example_query: str,
     example_ast_root: AstNode,
 ) -> ComparerResult:
     # TODO (DNGros): It would be nice to also pass in other features into
     # this such as the root type of the example vs the root type of the gen
     out_prob_score = self._compare_internal(
         gen_query,
         gen_ast_current_root,
         gen_ast_current_leaf,
         current_gen_depth,
         example_query,
         example_ast_root,
     )
      out_probs = torch.sigmoid(out_prob_score)
      # for now assume just one input
      out_prob = out_probs[0]
     # just use the rulebased thing for selecting which one for now
     potential_type_choice_nodes = treeutil.get_type_choice_nodes(
         example_ast_root, gen_ast_current_leaf.get_type_to_choose_name())
     depth_diffs = SimpleRulebasedComparer.get_impl_depth_difference(
         potential_type_choice_nodes, current_gen_depth)
     ranked_options = sorted([(-score, name)
                              for name, score in depth_diffs.items()],
                             reverse=True)
      return ComparerResult(out_prob, tuple(ranked_options))
Example #9
0
 def _get_actual_example_from_index(self, gen_query: str,
                                    gen_ast_current_leaf: ObjectChoiceNode):
     lookup_results = self.index.get_nearest_examples(
         gen_query, gen_ast_current_leaf.get_type_to_choose_name(), max_results=25)
     for result in lookup_results:
         if result.xquery == gen_query:
             return result
     raise ValueError(f"Oracle unable to find result for {gen_query}")
Example #10
0
    def _make_node_for_arg(
        self, arg: typecontext.AInixArgument,
        arg_data: ainix_common.parsing.parse_primitives.ObjectParseArgData,
        delegation_to_node_map: Dict[ParseDelegationReturnMetadata,
                                     ObjectChoiceNode]
    ) -> Tuple[ObjectChoiceNode, Optional[ParseDelegationReturnMetadata]]:
        """
        After running an Object parser we get results back for each arg.
        This handles creating the nodes for a specific arg's results.

        Args:
            arg: The arg we just parsed
            arg_data: The data we got back as a result of the parse
            delegation_to_node_map: A map going from any delegations that already
                have been done to their results.

        Returns:
            new_node: A new object node that represents the data of the arg
            child_metadata: stringparse metadata gotten while parsing the new
                node. If the arg is not present, then it will be None.
        """
        arg_is_present = arg_data is not None
        arg_has_already_been_delegated = arg_is_present and \
            arg_data.set_from_delegation is not None
        if arg_has_already_been_delegated:
            done_delegation = arg_data.set_from_delegation
            return delegation_to_node_map[done_delegation], done_delegation

        if not arg.required:
            arg_string_metadata = None
            if arg_is_present:
                arg_has_substructure_to_parse = arg.type_name is not None
                if arg_has_substructure_to_parse:
                    # TODO (DNGros): add back in stripping for slice_string?
                    inner_arg_node, arg_string_metadata = self._parse_object_choice_node(
                        arg_data.slice_string, arg.type_parser, arg.type)
                    arg_map = pmap({
                        typecontext.OPTIONAL_ARGUMENT_NEXT_ARG_NAME:
                        inner_arg_node
                    })
                else:
                    arg_map = pmap({})
                object_choice = ObjectNode(arg.is_present_object, arg_map)
            else:
                object_choice = ObjectNode(arg.not_present_object, pmap({}))
            return ObjectChoiceNode(arg.present_choice_type,
                                    object_choice), arg_string_metadata
        else:
            inner_arg_node, arg_string_metadata = self._parse_object_choice_node(
                arg_data.slice_string, arg.type_parser, arg.type)
            if arg_string_metadata.remaining_right_starti != len(
                    arg_data.slice_string):
                # I don't know if we actually want to error here???
                raise AInixParseError(
                    f"Expected to fully consume string arg delegation {arg_data.slice_string}"
                    f"but only consumed {arg_string_metadata.remaining_right_starti} out of"
                    f"{len(arg_data.slice_string)}")
            return inner_arg_node, arg_string_metadata
Example #11
0
def test_touch_set(all_the_stuff_context):
    x_str = 'set the last mod time of out.txt to now'
    tc = all_the_stuff_context
    parser = StringParser(tc)
    string = "touch out.txt"
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string

    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    new_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(new_ast)

    tokenizer = NonLetterTokenizer()
    _, tok_metadata = tokenizer.tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary complicated reconstruction of something that broke it.
    # could be made into a simpler unit test in copy_tools
    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    other_copy = ObjectChoiceNode(
        tc.get_type_by_name("Program"),
        ObjectNode(
            touch_o,
            pmap({
                "r":
                ObjectChoiceNode(r_arg.present_choice_type,
                                 ObjectNode(r_arg.not_present_object, pmap())),
                "m":
                ObjectChoiceNode(m_arg.present_choice_type,
                                 ObjectNode(m_arg.not_present_object, pmap())),
                "file_list":
                ObjectChoiceNode(file_list, CopyNode(file_list, 12, 14))
            })))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)
Example #12
0
 def _train_step(
     self, x_query: str, expected: AstObjectChoiceSet,
     current_gen_root: ObjectChoiceNode, current_gen_leaf: ObjectChoiceNode,
     teacher_force_path: ObjectChoiceNode, current_depth: int
 ):  # TODO This should likely eventually return the new current_gen ast
     self.type_predictor.train(x_query, current_gen_root, current_gen_leaf,
                               expected, current_depth)
     # figure out where going next
     next_expected_node = expected.get_next_node_for_choice(
         teacher_force_path.get_chosen_impl_name()).next_node
     assert next_expected_node is not None, "Teacher force path not in expected ast set!"
     next_object_node = ObjectNode(
         teacher_force_path.next_node.implementation)
     current_gen_leaf.set_choice(next_object_node)
     self._train_obj_node_step(x_query, next_expected_node,
                               current_gen_root, next_object_node,
                               teacher_force_path.next_node,
                               current_depth + 1)
Example #13
0
 def _get_class_ind_for_node(self, node: ObjectChoiceNode,
                             add_if_not_present: bool) -> Optional[int]:
     name_str = "~COPY~" if node.copy_was_chosen else node.get_chosen_impl_name(
     )
     if name_str not in self._object_name_to_ind:
         if add_if_not_present:
             self._object_name_to_ind[name_str] = self._num_classes_seen
             self._num_classes_seen += 1
         else:
             return None
     return self._object_name_to_ind[name_str]
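
The method above is a grow-on-demand name-to-index mapping where all copies collapse onto a single "~COPY~" class. A self-contained sketch of the same pattern, using a hypothetical ClassIndexer in place of the surrounding model:

from typing import Dict, Optional

class ClassIndexer:
    """Assigns each new name the next free integer index."""
    def __init__(self):
        self._name_to_ind: Dict[str, int] = {}

    def get_ind(self, name: str, add_if_not_present: bool) -> Optional[int]:
        if name not in self._name_to_ind:
            if not add_if_not_present:
                return None
            self._name_to_ind[name] = len(self._name_to_ind)
        return self._name_to_ind[name]

indexer = ClassIndexer()
assert indexer.get_ind("~COPY~", add_if_not_present=True) == 0
assert indexer.get_ind("FO1", add_if_not_present=True) == 1
assert indexer.get_ind("~COPY~", add_if_not_present=False) == 0
assert indexer.get_ind("unseen", add_if_not_present=False) is None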
Example #14
0
 def _search(
     self,
     x_query,
     current_leaf: ObjectChoiceNode,
     use_only_training_data: bool
 ) -> List[XValue]:
     type_name = current_leaf.get_type_to_choose_name()
     split_filter = (DataSplits.TRAIN,) if use_only_training_data else None
     return list(self.index.get_nearest_examples(
         x_query, choose_type_name=type_name, filter_splits=split_filter,
         max_results=self.max_examples_to_compare))
Example #15
0
 def update_fn(new_summary, ast: ObjectChoiceNode):
     for pointer in ast.depth_first_iter():
         node = pointer.cur_node
         if isinstance(node, ObjectChoiceNode):
             type_name = node.type_to_choose.name
             if type_name not in type_name_to_nb_model:
                 type_name_to_nb_model[type_name] = ObjectChoiceModel(
                     node.type_to_choose)
             new_summary_and_depth = torch.cat(
                 (new_summary, torch.tensor([float(pointer.get_depth())])))
             type_name_to_nb_model[type_name].add_examples(
                 new_summary_and_depth.unsqueeze(0), [node])
Example #16
0
    def _delegate_object_arg_parse(
        self, delegation: ArgParseDelegation
    ) -> Tuple[ParseDelegationReturnMetadata, Optional[ObjectChoiceNode]]:
        """Called to parse an argument which a parser asked to delegate and get
        the resulting remaining string from.

        Returns:
            ParseDelegationReturnMetadata: The return value to send back into
                the calling parser.
            ObjectChoiceNode: The parsed value for the arg. If parsing failed,
                this will be None.
        """
        arg = delegation.arg

        # Do parsing if needed.
        if arg.type is not None:
            try:
                arg_type_choice, parse_metadata = self._parse_object_choice_node(
                    delegation.string_to_parse, arg.type_parser, arg.type)
                out_delegation_return = ParseDelegationReturnMetadata(
                    parse_metadata.parse_success, parse_metadata.string_parsed,
                    delegation.slice_to_parse[0], arg,
                    parse_metadata.remaining_right_starti,
                    parse_metadata.fail_reason)
            except AInixParseError as e:
                # TODO (DNGros): Use the metadata rather than exceptions to manage this
                metadata = ParseDelegationReturnMetadata(
                    False, delegation.string_to_parse,
                    delegation.slice_to_parse[0], arg, None, str(e))
                return metadata, None
        else:
            # If the arg has a None type, then we don't have to do any parsing.
            # Assume it was a success and that the arg is present.
            arg_type_choice = None
            out_delegation_return = ParseDelegationReturnMetadata.make_for_unparsed_string(
                delegation.string_to_parse, arg, delegation.slice_to_parse[0])

        # Figure out the actual node we need to output
        if not arg.required:
            # If it is an optional arg, wrap it in an "is present" node.
            if arg.type is None:
                object_choice = ObjectNode(arg.is_present_object, pmap({}))
            else:
                parsed_v_as_arg = pmap({
                    typecontext.OPTIONAL_ARGUMENT_NEXT_ARG_NAME:
                    arg_type_choice
                })
                object_choice = ObjectNode(arg.is_present_object,
                                           parsed_v_as_arg)
            out_node = ObjectChoiceNode(arg.present_choice_type, object_choice)
        else:
            out_node = arg_type_choice
        return out_delegation_return, out_node
Example #17
0
    def _inference_objectchoice_step(
        self,
        current_leaf: ObjectChoiceNode,
        internal_state,  # For an lstm cell this is the (h_0, c_0) tensors
        parent_node_features: Optional[torch.Tensor],
        memory_tokens: torch.Tensor,
        valid_for_copy_mask: torch.LongTensor,
        cur_depth: int,
        override_action_selector: ActionSelector = None
    ) -> Tuple[torch.Tensor, TypeTranslatePredictMetadata]:
        if cur_depth > self.MAX_DEPTH:
            raise ModelException()
        outs, internal_state = self.rnn_cell(
            internal_state=internal_state,
            type_to_predict_features=self._get_obj_choice_features(
                current_leaf),
            parent_node_features=parent_node_features,
            parent_node_hidden=None,
            memory_tokens=memory_tokens)
        if len(outs) != 1:
            raise NotImplementedError("Batches not implemented")

        use_selector = override_action_selector or self.action_selector
        predicted_action, my_metad = use_selector.infer_predict(
            outs, memory_tokens, valid_for_copy_mask,
            current_leaf.type_to_choose)
        if isinstance(predicted_action, CopyAction):
            current_leaf.set_choice(
                CopyNode(current_leaf.type_to_choose, predicted_action.start,
                         predicted_action.end))
        elif isinstance(predicted_action, ProduceObjectAction):
            new_node = ObjectNode(predicted_action.implementation)
            current_leaf.set_choice(new_node)
            internal_state, child_metad = self._inference_object_step(
                new_node, internal_state, memory_tokens, valid_for_copy_mask,
                cur_depth + 1, override_action_selector)
            my_metad = my_metad.concat(child_metad)
        else:
            raise ValueError()
        return internal_state, my_metad
Example #18
0
 def add_example(self, example_id: int, ast: ObjectChoiceNode):
     ind_map = []
     for y_depth, pointer in enumerate(ast.depth_first_iter()):
         if isinstance(pointer.cur_node, ObjectChoiceNode):
             assert y_depth % 2 == 0
             type_ind = pointer.cur_node.type_to_choose.ind
             ind_map.append(len(self.type_ind_to_example_id[type_ind]))
             self.type_ind_to_example_id[type_ind].append(example_id)
             chosen_impl_id = COPY_IND if pointer.cur_node.copy_was_chosen \
                 else pointer.cur_node.next_node_not_copy.implementation.ind
             self.type_ind_to_impl_choice[type_ind].append(chosen_impl_id)
             self.type_ind_to_y_dfsdeth[type_ind].append(y_depth)
     assert example_id not in self.example_id_to_inds_in_type
     self.example_id_to_inds_in_type[example_id] = ind_map
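
add_example above maintains parallel per-type lists plus a reverse map from example id to positions within those lists; a stripped-down sketch of that bookkeeping with defaultdict (identifiers here are illustrative, not the project's):

from collections import defaultdict

type_ind_to_example_id = defaultdict(list)
type_ind_to_impl_choice = defaultdict(list)
example_id_to_inds_in_type = {}

def add_example(example_id, choices):
    # choices: (type_ind, chosen_impl_ind) pairs in depth-first order
    ind_map = []
    for type_ind, impl_ind in choices:
        ind_map.append(len(type_ind_to_example_id[type_ind]))
        type_ind_to_example_id[type_ind].append(example_id)
        type_ind_to_impl_choice[type_ind].append(impl_ind)
    example_id_to_inds_in_type[example_id] = ind_map

add_example(0, [(0, 2), (1, 0)])
add_example(1, [(0, 1)])
assert type_ind_to_example_id[0] == [0, 1]    # both examples made a type-0 choice
assert example_id_to_inds_in_type[1] == [1]   # example 1 is at position 1 in type 0's lists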
Example #19
0
 def _predict_step(self, x_query: str, current_root: ObjectChoiceNode,
                   current_leaf: AstNode, root_y_type: AInixType,
                   use_only_train_data: bool, current_depth: int):
     if isinstance(current_leaf, ObjectChoiceNode):
         predicted_impl = self.type_predictor.predict(
             x_query, current_root, current_leaf, use_only_train_data,
             current_depth)
          new_node = ObjectNode(predicted_impl)
          current_leaf.set_choice(new_node)
          self._predict_step(x_query, current_root, new_node,
                             root_y_type, use_only_train_data,
                             current_depth + 1)
     elif isinstance(current_leaf, ObjectNode):
         # TODO (DNGros): this is messy. Should have better iteration based
         # off next unfilled node rather than having to mutate state.
         #####
         # Actually it would probably be better to just store a pointer to
         # the root and a pointer to the leaf. Then you call an add_to_leaf()
         # on the root which does a copying add, with structure sharing of
         # any arg before the arg that leaf is on. This is also nice because
          # then AstNodes can be made purely immutable and beam search becomes
         # nearly trivial to do
         for arg in current_leaf.implementation.children:
             if not arg.required:
                 new_node = ObjectChoiceNode(arg.present_choice_type)
             elif arg.type is not None:
                 new_node = ObjectChoiceNode(arg.type)
             else:
                 continue
             current_leaf.set_arg_value(arg.name, new_node)
             self._predict_step(x_query, current_root, new_node,
                                root_y_type, use_only_train_data,
                                current_depth + 1)
     else:
         raise ValueError(f"leaf node {current_leaf} not predictable")
Example #20
0
def add_copies_to_ast_set(ast: ObjectChoiceNode,
                          ast_set: AstObjectChoiceSet,
                          unparser: AstUnparser,
                          token_metadata: StringTokensMetadata,
                          copy_node_weight: float = 1) -> None:
    """Takes in an AST that has been parsed and adds copynodes where appropriate
    to an AstSet that contains that AST"""
    unparse = unparser.to_string(ast)
    df_ast_pointers = list(ast.depth_first_iter())
    df_ast_nodes = [pointer.cur_node for pointer in ast.depth_first_iter()]
    df_ast_set = list(
        depth_first_iterate_ast_set_along_path(ast_set, df_ast_nodes))
    assert len(df_ast_nodes) == len(df_ast_set)
    for pointer, cur_set in zip(df_ast_pointers, df_ast_set):
        if isinstance(pointer.cur_node, ObjectChoiceNode):
            # TODO (DNGros): Figure out if we are handling weight and probability right
            # I think works fine now if known valid
            _try_add_copy_node_at_object_choice(pointer, cur_set, True,
                                                copy_node_weight, 1, unparse,
                                                token_metadata)
        elif isinstance(pointer.cur_node, ObjectNode):
            pass
        else:
            raise ValueError("Unrecognized node?")
Example #21
0
 def _parse_object_choice_node(
     self, string: str, parser: TypeParser, type: typecontext.AInixType
 ) -> Tuple[ObjectChoiceNode, ParseDelegationReturnMetadata]:
     """Parses a string into a ObjectChoiceNode. This is more internal use.
     For the more user friendly method see create_parse_tree()"""
     result, delegation_map = self._run_type_parser_with_delegations(
         string, type, parser)
     if result._accepted_delegation is not None:
         next_object_node, child_string_metadata = delegation_map[
             result._accepted_delegation]
     else:
         next_object_node, child_string_metadata = self._parse_object_node(
             result.get_implementation(), result.get_next_string(),
             result.next_parser,
             result.get_next_slice()[0])
     metadata = self._object_choice_result_to_string_metadata(
         result, child_string_metadata)
     return ObjectChoiceNode(type, next_object_node), metadata
Example #22
0
 def compare(
     self,
     gen_query: str,
     gen_ast_current_root: ObjectChoiceNode,
     gen_ast_current_leaf: ObjectChoiceNode,
     current_gen_depth: int,
     example_query: str,
     example_ast_root: AstNode,
 ) -> ComparerResult:
     potential_type_choice_nodes = get_type_choice_nodes(
         example_ast_root, gen_ast_current_leaf.get_type_to_choose_name())
     depth_diffs = self.get_impl_depth_difference(
         potential_type_choice_nodes, current_gen_depth)
     ranked_options = sorted(
         [(1/(depth_diff+1), name) for name, depth_diff in depth_diffs.items()],
         reverse=True
     )
     return ComparerResult(1, tuple(ranked_options))
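
The rule-based comparer scores each candidate impl as 1 / (depth_diff + 1), so smaller depth differences rank higher; a quick standalone check of that transform with invented values:

depth_diffs = {"FO1": 0, "FO2": 2, "FO3": 1}
ranked_options = sorted(
    ((1 / (depth_diff + 1), name) for name, depth_diff in depth_diffs.items()),
    reverse=True)
assert ranked_options[0] == (1.0, "FO1")
assert [name for _, name in ranked_options] == ["FO1", "FO3", "FO2"]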
Example #23
0
    def _train_objectchoice_step(
            self,
            last_internal_state,  # For an lstm cell this is the (h_0, c_0) tensors
            memory_tokens: torch.Tensor,
            valid_for_copy_mask: torch.LongTensor,
            parent_node_features: Optional[torch.Tensor],
            expected: AstObjectChoiceSet,
            teacher_force_path: ObjectChoiceNode,
            num_parents_with_a_copy_option: int,
            example_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        outs, internal_state = self.rnn_cell(
            internal_state=last_internal_state,
            type_to_predict_features=self._get_obj_choice_features(
                teacher_force_path),
            parent_node_features=parent_node_features,
            parent_node_hidden=None,
            memory_tokens=memory_tokens)

        loss = self.action_selector.forward_train(
            latent_vec=outs,
            memory_tokens=memory_tokens,
            valid_for_copy_mask=valid_for_copy_mask,
            types_to_select=[teacher_force_path.type_to_choose],
            expected=expected,
            num_of_parents_with_copy_option=num_parents_with_a_copy_option,
            example_inds=[example_id])

        child_loss = 0
        if not expected.copy_is_known_choice():
            # Don't descend if copy? Not sure if this is good. Need to do something
            # to ensure the altering of internal state is the same between train and
            # test, and this is an easy fix.
            next_expected_set = expected.get_next_node_for_choice(
                impl_name_chosen=teacher_force_path.get_chosen_impl_name()
            ).next_node
            assert next_expected_set is not None, "Teacher force path not in expected ast set!"
            next_object_node = teacher_force_path.next_node_not_copy
            internal_state, child_loss = self._train_objectnode_step(
                internal_state, memory_tokens, valid_for_copy_mask,
                next_expected_set, next_object_node,
                num_parents_with_a_copy_option +
                (1 if expected.copy_is_known_choice() else 0), example_id)
        return internal_state, loss + child_loss
Example #24
0
 def forward_predict(
     self,
     query_summary: torch.Tensor,
     memory_encoding: torch.Tensor,
     actual_tokens: List[List[ModifiedStringToken]],
     root_type: AInixType,
     override_action_selector: ActionSelector = None
 ) -> Tuple[ObjectChoiceNode, TypeTranslatePredictMetadata]:
     if self.training:
          raise ValueError(
              "Expected to not be in training mode during inference.")
     # TODO (DNGros): make steps this not mutate state and iterative
     prediction_root_node = ObjectChoiceNode(root_type)
     valid_for_copy_mask = get_valid_for_copy_mask(actual_tokens)
     internal_state, metad = self._inference_objectchoice_step(
         prediction_root_node,
         self.rnn_cell.create_first_state(query_summary), None,
         memory_encoding, valid_for_copy_mask, 0, override_action_selector)
     return prediction_root_node, metad
Example #25
0
 def _train_obj_node_step(self, x_query: str, expected: ObjectNodeSet,
                          current_gen_root: ObjectChoiceNode,
                          current_gen_leaf: ObjectNode,
                          teacher_force_path: ObjectNode,
                          current_depth: int):
     arg_set_data = expected.get_arg_set_data(
         teacher_force_path.as_childless_node())
     assert arg_set_data is not None, "Teacher force path not in expected ast set!"
     for arg in teacher_force_path.implementation.children:
         #if arg.type is None:
         #    continue
         next_choice_set = arg_set_data.arg_to_choice_set[arg.name]
         # TODO (DNGros): This is currently somewhat gross as it relies on the _train_step
         # call mutating state. Once it is changed to make changes on current_gen_root
         # this shouldn't be an issue.
         next_gen_leaf = ObjectChoiceNode(arg.next_choice_type)
         current_gen_leaf.set_arg_value(arg.name, next_gen_leaf)
         self._train_step(
             x_query, next_choice_set, current_gen_root, next_gen_leaf,
             teacher_force_path.get_choice_node_for_arg(arg.name),
             current_depth + 1)
Example #26
0
def _try_add_copy_node_at_object_choice(
    pointer: AstIterPointer,
    ast_set: AstObjectChoiceSet,
    known_valid: bool,
    max_weight: float,
    max_probability: float,
    unparse: UnparseResult,
    token_metadata: StringTokensMetadata,
):
    node: ObjectChoiceNode = pointer.cur_node
    if is_obj_choice_a_not_present_node(node):
        return
    if (pointer.get_child_nums_here(),
            pointer.cur_node) not in unparse.child_path_and_node_to_span:
        return  # This might be a terrible idea since we won't know when a parser is bad...
    this_node_str = unparse.pointer_to_string(pointer)
    copy_pos = string_in_tok_list(this_node_str, token_metadata)
    if copy_pos:
        copy_node = CopyNode(node.type_to_choose, copy_pos[0], copy_pos[1])
        object_node_set: ObjectNodeSet = ast_set.parent
        ast_set.add_node_when_copy(copy_node, known_valid, max_weight,
                                   max_probability)
        if object_node_set is not None:
            # If we have a parent object node, it needs to know about the copy too
            # This will probably screw up probabilities and weights and stuff
            # This is really awkward tree surgery which could likely be cleaner
            object_node: ObjectNode = pointer.parent.cur_node
            object_node, _ = object_node.path_clone([
                object_node,
                object_node.get_nth_child(pointer.parent_child_ind)
            ])
            object_node.set_arg_value(
                object_node.implementation.children[
                    pointer.parent_child_ind].name,
                ObjectChoiceNode(node.type_to_choose, copy_node))
            object_node_set.add(object_node, known_valid, max_weight,
                                max_probability)
Example #27
0
def get_paths_to_all_copies(ast: ObjectChoiceNode) -> Tuple[Tuple[int, ...], ...]:
    all_copy_paths = []
    for pointer in ast.depth_first_iter():
        if isinstance(pointer.cur_node, CopyNode):
            all_copy_paths.append(pointer.get_child_nums_here())
    return tuple(all_copy_paths)
Example #28
0
def test_train_retriever_selector_copy():
    torch.manual_seed(1)
    tc = TypeContext()
    ft = AInixType(tc, "FT")
    fo1 = AInixObject(tc, "FO1", "FT")
    fo2 = AInixObject(tc, "FO2", "FT")
    fo3 = AInixObject(tc, "FO3", "FT")
    bt = AInixType(tc, "BT")
    bo1 = AInixObject(tc, "BO1", "BT")
    AInixObject(tc, "BO2", "BT")
    tc.finalize_data()
    latent_size = 3

    builder = TorchLatentStore.get_builder(tc.get_type_count(), latent_size)
    valid_choices = []
    oc = ObjectChoiceNode(ft, ObjectNode(fo1))
    builder.add_example(0, oc)
    s1 = AstObjectChoiceSet(ft)
    s1.add(oc, True, 1, 1)
    valid_choices.append((0, 0, s1))

    oc = ObjectChoiceNode(ft, CopyNode(ft, 0, 3))
    builder.add_example(1, oc)
    s1 = AstObjectChoiceSet(ft)
    s1.add(oc, True, 1, 1)
    valid_choices.append((1, 0, s1))
    latent_store = builder.produce_result()
    #
    embed = torch.nn.Embedding(len(valid_choices), latent_size)
    inputs = [(torch.LongTensor([i]), c, torch.randn(1, 5, latent_size))
              for i, c in enumerate(valid_choices)]
    instance = RetrievalActionSelector(latent_store, tc, retrieve_dropout_p=0)
    instance.start_train_session()
    params = itertools.chain(instance.parameters(), embed.parameters())
    optim = torch.optim.Adam(params, lr=1e-2)
    print(inputs)

    def do_train():
        optim.zero_grad()
        loss = 0
        for x, (example_id, step, astset), mem_tokens in inputs:
            x_v = embed(x)
            loss += instance.forward_train(x_v, mem_tokens,
                                           [tc.get_type_by_name("FT")], astset,
                                           0, [example_id])
        loss.backward()
        optim.step()
        return loss

    for e in range(100):
        loss = do_train()
        #print("LOSS", loss)
        #s: TorchLatentStore = instance.latent_store
        #print("LATENTS", s.type_ind_to_latents)

    for x, (example_id, step, astset), mem_tokens in inputs:
        x_v = embed(x)
        pred = instance.infer_predict(x_v, mem_tokens,
                                      tc.get_type_by_name("FT"))
        #print("x", x)
        #print("pred", pred)
        if isinstance(pred, ProduceObjectAction):
            assert astset.is_known_choice(pred.implementation.name)
        elif isinstance(pred, CopyAction):
            n = ObjectChoiceNode(ft, CopyNode(ft, pred.start, pred.end))
            assert astset.is_node_known_valid(n)
        else:
            raise ValueError()
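
The training loop in this test is the standard zero_grad/backward/step cycle; a minimal generic sketch with a toy module (nothing AInix-specific is assumed):

import torch

model = torch.nn.Linear(3, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(8, 3), torch.randn(8, 1)
for _ in range(100):
    optim.zero_grad()                                  # clear stale gradients
    loss = torch.nn.functional.mse_loss(model(x), y)   # toy objective
    loss.backward()                                    # accumulate gradients
    optim.step()                                       # apply the update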
Example #29
0
 def train(self, x_string: str, y_ast: AstObjectChoiceSet,
           teacher_force_path: ObjectChoiceNode) -> None:
     current_gen_node = ObjectChoiceNode(teacher_force_path.type_to_choose)
     self._train_step(x_string, y_ast, current_gen_node, current_gen_node,
                      teacher_force_path, 0)