Example #1
0
    def __init__(self,
                 init_tree_w_dummy_reduce: AbstractSyntaxTree,
                 bool_copy_subtree=False,
                 tree=None,
                 memory=None,
                 memory_type='all_init_joint',
                 init_code_tokens=None,
                 length_norm=False):
        """Initialise the edit state around an initial tree.

        Args:
            init_tree_w_dummy_reduce: the initial tree; stored as-is and
                copied to become the working tree when ``tree`` is not given.
            bool_copy_subtree: when truthy and ``memory`` is ``None``, the
                subtree memory is derived here according to ``memory_type``.
            tree: working tree to use directly; ``None`` means "copy the
                initial tree".
            memory: pre-built subtree memory. Stored unchanged whenever it is
                not ``None`` or ``bool_copy_subtree`` is falsy.
            memory_type: one of ``'all_init_joint'``, ``'all_init_distinct'``
                or ``'deleted_distinct'``.
            init_code_tokens: stored unchanged on the instance.
            length_norm: stored unchanged on the instance.

        Raises:
            AssertionError: if ``memory_type`` is not one of the three
                accepted values.
        """
        self.init_tree_w_dummy_reduce = init_tree_w_dummy_reduce
        self.bool_copy_subtree = bool_copy_subtree

        known_memory_types = ('all_init_joint', 'all_init_distinct',
                              'deleted_distinct')
        assert memory_type in known_memory_types
        self.memory_type = memory_type
        self.init_code_tokens = init_code_tokens
        self.length_norm = length_norm

        # Work on the caller's tree when one is supplied, otherwise on a
        # private copy of the initial tree.
        self.tree = (init_tree_w_dummy_reduce.copy()
                     if tree is None else tree)

        if memory is not None or not bool_copy_subtree:
            # Nothing to derive: keep whatever was passed (possibly None).
            self.memory = memory
        elif memory_type == 'all_init_joint':
            # Every subtree reported by stack_subtrees, duplicates included.
            self.memory = stack_subtrees(
                self.init_tree_w_dummy_reduce.root_node)
        elif memory_type == 'all_init_distinct':
            # Same subtrees, but each distinct one (by `in` comparison) is
            # kept only once, in first-seen order.
            self.memory = []
            subtrees = stack_subtrees(
                self.init_tree_w_dummy_reduce.root_node)
            for subtree in subtrees:
                if subtree in self.memory:
                    continue
                self.memory.append(subtree)
        else:
            # 'deleted_distinct': memory starts out empty.
            self.memory = []
        # self.set_tree_all_finish() # redundant?

        # Edit history and scoring.
        self.edits = []
        self.score_per_edit = []
        self.score = 0.

        # Frontier bookkeeping; these exist before update_frontier_info()
        # is invoked below.
        self.repr2field = {}
        # nodes that may be deleted
        self.open_del_node_and_ids = []
        # fields that may receive new nodes
        self.open_add_fields = []
        # fields (esp. single-cardinality ones) the grammar requires filling
        self.restricted_frontier_fields = []
        self.update_frontier_info()

        # Time-step tracking.
        self.last_edit_field_node = None  # most recent edit location
        self.t = 0
        self.stop_t = None
Example #2
0
class Hypothesis(object):
    """A partially built syntax tree plus the actions that produced it.

    Attributes are plain instance fields:
        tree: root of the tree built so far, or ``None`` before the first
            action is applied.
        actions: every action passed to ``apply_action``, in order.
        score: accumulated score; only stored and copied by this class.
        frontier_node / frontier_field: the node and field the next action
            applies to, or ``None`` for both when nothing is left open.
        t: number of actions applied so far.
    """

    def __init__(self):
        self.tree = None
        self.actions = []
        self.score = 0.
        self.frontier_node = None
        self.frontier_field = None
        # tokens collected for a string field until its stop signal arrives
        self._value_buffer = []

        # record the current time step
        self.t = 0

    def apply_action(self, action):
        """Apply ``action`` at the current frontier, in place.

        The first action must be an ``ApplyRuleAction``; it creates the root
        of the tree. Later actions are dispatched on whether the frontier
        field is composite or primitive. If the tree exists but there is no
        frontier node, the tree is left untouched. In every case the action
        is recorded and the time step advances.

        Raises:
            AssertionError: if the first action is not an ``ApplyRuleAction``
                or a Reduce targets a field that is neither optional nor
                multiple.
            ValueError: if the action type is not valid for the frontier
                field.
        """
        if self.tree is None:
            # Wrap the action in a 1-tuple so the message formats correctly
            # even when the action itself is a tuple.
            assert isinstance(action, ApplyRuleAction), (
                'Invalid action [%s], only ApplyRule action is valid '
                'at the beginning of decoding' % (action,))

            self.tree = AbstractSyntaxTree(action.production)
            self.update_frontier_info()
        elif self.frontier_node:
            if isinstance(self.frontier_field.type, ASDLCompositeType):
                self._apply_to_composite_field(action)
            else:
                self._apply_to_primitive_field(action)

        self.t += 1
        self.actions.append(action)

    def _apply_to_composite_field(self, action):
        """Expand or close the composite frontier field."""
        if isinstance(action, ApplyRuleAction):
            field_value = AbstractSyntaxTree(action.production)
            field_value.created_time = self.t
            self.frontier_field.add_value(field_value)
            self.update_frontier_info()
        elif isinstance(action, ReduceAction):
            self._apply_reduce()
        else:
            raise ValueError('Invalid action [%s] on field [%s]' %
                             (action, self.frontier_field))

    def _apply_to_primitive_field(self, action):
        """Fill or close the primitive frontier field."""
        if isinstance(action, GenTokenAction):
            field = self.frontier_field
            # only field of type string requires termination signal
            # </primitive>; its tokens are buffered until that signal.
            end_primitive = False
            if field.type.name == 'string':
                if action.is_stop_signal():
                    field.add_value(' '.join(self._value_buffer))
                    self._value_buffer = []
                    end_primitive = True
                else:
                    self._value_buffer.append(action.token)
            else:
                field.add_value(action.token)
                end_primitive = True

            # A field with 'multiple' cardinality stays open after a value
            # is added; it is closed only by a Reduce action.
            if end_primitive and field.cardinality in ('single', 'optional'):
                field.set_finish()
                self.update_frontier_info()
        elif isinstance(action, ReduceAction):
            self._apply_reduce()
        else:
            raise ValueError(
                'Can only invoke GenToken or Reduce actions on primitive fields'
            )

    def _apply_reduce(self):
        """Mark the frontier field finished and move the frontier on."""
        assert self.frontier_field.cardinality in ('optional', 'multiple'), (
            'Reduce action can only be applied on field with optional or '
            'multiple cardinality')
        self.frontier_field.set_finish()
        self.update_frontier_info()

    def update_frontier_info(self):
        """Recompute ``frontier_node`` and ``frontier_field`` from the tree.

        Fields are visited depth-first in declaration order. For a composite
        field that already holds values, its children are searched first;
        the field itself is chosen only if no child has an unfinished field
        and the field is not finished. Both attributes become ``None`` when
        nothing unfinished is found or there is no tree.
        """
        def _find_frontier_node_and_field(tree_node):
            if not tree_node:
                return None

            for field in tree_node.fields:
                # if it's an intermediate node, check its children
                if isinstance(field.type, ASDLCompositeType) and field.value:
                    if field.cardinality in ('single', 'optional'):
                        iter_values = [field.value]
                    else:
                        iter_values = field.value

                    for child_node in iter_values:
                        result = _find_frontier_node_and_field(child_node)
                        if result:
                            return result

                # now all its possible children are checked
                if not field.finished:
                    return tree_node, field

            return None

        frontier_info = _find_frontier_node_and_field(self.tree)
        if frontier_info:
            self.frontier_node, self.frontier_field = frontier_info
        else:
            self.frontier_node, self.frontier_field = None, None

    def clone_and_apply_action(self, action):
        """Return a copy of this hypothesis with ``action`` applied to it.

        This hypothesis itself is not modified.
        """
        new_hyp = self.copy()
        new_hyp.apply_action(action)

        return new_hyp

    def copy(self):
        """Return an independent copy of this hypothesis.

        The tree is duplicated through its own ``copy()``; the action list
        and the value buffer are copied as new lists. The frontier is
        recomputed so that it points into the copied tree.
        """
        new_hyp = Hypothesis()
        if self.tree:
            new_hyp.tree = self.tree.copy()

        new_hyp.actions = list(self.actions)
        new_hyp.score = self.score
        new_hyp._value_buffer = list(self._value_buffer)
        new_hyp.t = self.t

        new_hyp.update_frontier_info()

        return new_hyp

    @property
    def completed(self):
        """``True`` when a tree exists and no field is left open."""
        # bool() keeps the original truthiness but never leaks None or the
        # tree object itself to callers.
        return bool(self.tree and self.frontier_field is None)