def test_production_rules_sequence_to_stack_invalid(self):
  """Checks that a rule with a mismatching left hand side is rejected.

  The sequence applies 'S -> S "+" T' followed by 'T -> "x"'. The expected
  error message states that the left hand side of the second rule does not
  match the symbol in the stack (S).
  """
  prod_rule_strings = [
      'S -> S "+" T',
      'T -> "x"',
  ]
  prod_rules_sequence = [
      self.prod_rules_dict[prod_rule_string]
      for prod_rule_string in prod_rule_strings
  ]
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
  # assertRaisesRegex is the supported spelling.
  with self.assertRaisesRegex(
      ValueError,
      "Left hand side symbol of production rule T -> 'x' does not match "
      r'the symbol in the stack \(S\)'):
    postprocessor.production_rules_sequence_to_stack(prod_rules_sequence)
def _get_next_production_rule_mask_batch(partial_sequences,
                                         partial_sequence_lengths,
                                         grammar):
  """Computes the next-production-rule mask for each partial sequence.

  Args:
    partial_sequences: Integer numpy array with shape [batch_size, max_length].
        Batch of partial sequences of the expression sequences.
    partial_sequence_lengths: Integer numpy array with shape [batch_size].
        The actual length of each partial sequence, excluding padding.
    grammar: arithmetic_grammar.Grammar.

  Returns:
    Boolean numpy array with shape [batch_size, num_production_rules], where
    num_production_rules is the number of production rules in grammar.
  """
  batch_size = len(partial_sequences)
  mask_batch = np.zeros(
      (batch_size, grammar.num_production_rules), dtype=bool)
  rows = zip(partial_sequences, partial_sequence_lengths)
  for row, (sequence, length) in enumerate(rows):
    # Drop the padding, then translate rule indices into production rules.
    rules = [grammar.prod_rules[rule_index] for rule_index in sequence[:length]]
    lhs_stack = postprocessor.production_rules_sequence_to_stack(rules)
    # The symbol taken from the top of the stack selects the grammar mask.
    next_lhs = lhs_stack.pop()
    mask_batch[row] = grammar.masks[grammar.lhs_to_index[next_lhs]]
  return mask_batch
def __init__(self, production_rules_sequence, stack=None):
  """Initializer.

  For an initial state that has no production rules yet, pass a list holding
  a single symbol string as the stack argument. That forces the next
  production rule appended to start with this symbol.

  Args:
    production_rules_sequence: List of nltk.grammar.Production objects,
        obtained by a preorder traversal of the context-free grammar parsing
        tree.
    stack: GrammarLhsStack object, list, or None. Holds the left hand side
        symbol strings; a valid production rule to append must have a left
        hand side matching the top element. When a list is given, its last
        element is the top of the stack. When None, the stack is derived
        from production_rules_sequence.

  Raises:
    ValueError: If stack is not list, GrammarLhsStack or None.
  """
  self._production_rules_sequence = production_rules_sequence

  if stack is None:
    # No stack supplied: rebuild it from the rules applied so far.
    lhs_stack = postprocessor.production_rules_sequence_to_stack(
        production_rules_sequence)
  elif isinstance(stack, list):
    lhs_stack = postprocessor.GrammarLhsStack(stack)
  elif isinstance(stack, postprocessor.GrammarLhsStack):
    # Store the result of copy() rather than the caller's object.
    lhs_stack = stack.copy()
  else:
    message = (
        'stack is expected to be list, GrammarLhsStack or None, but got %s.'
        % type(stack))
    raise ValueError(message)
  self._stack = lhs_stack

  # Log the string representation of the newly created state.
  logging.info('Create %s', self)
def test_production_rules_sequence_to_stack(self):
  """Checks the stack built from a valid two-rule sequence.

  After 'S -> S "+" T' followed by 'S -> T', the underlying list of the
  resulting stack is expected to be ['T', 'T'].
  """
  rule_strings = ['S -> S "+" T', 'S -> T']
  # Look up the production rule objects from the test fixture.
  rules = [self.prod_rules_dict[rule_string] for rule_string in rule_strings]
  resulting_stack = postprocessor.production_rules_sequence_to_stack(rules)
  self.assertEqual(resulting_stack._stack, ['T', 'T'])
def get_number_valid_next_step(prod_rules_sequence_indices, grammar):
  """Counts the production rules that may validly come next.

  Args:
    prod_rules_sequence_indices: A 1-D Numpy array of indices of a production
        rule sequence.
    grammar: A grammar object.

  Returns:
    An integer, the number of valid next production rules.
  """
  applied_rules = [
      grammar.prod_rules[rule_index]
      for rule_index in prod_rules_sequence_indices
  ]
  lhs_stack = postprocessor.production_rules_sequence_to_stack(applied_rules)
  # The symbol on top of the stack selects the mask; summing the mask entries
  # gives the count.
  next_lhs = lhs_stack.peek()
  valid_mask = grammar.masks[grammar.lhs_to_index[next_lhs]]
  return int(sum(valid_mask))
def test_production_rules_sequence_to_stack_empty_sequence(self):
  """Checks that an empty rule sequence gives a stack with an empty list."""
  no_rules = []
  resulting_stack = postprocessor.production_rules_sequence_to_stack(no_rules)
  self.assertEqual(resulting_stack._stack, [])