def test_production_rules_sequence_to_stack_invalid(self):
  """Checks that a rule whose left hand side mismatches the stack is rejected.

  The sequence applies 'S -> S "+" T' and then 'T -> "x"'. The expected error
  message reports that the stack holds S at that point, so the second rule
  must make production_rules_sequence_to_stack raise ValueError.
  """
  prod_rule_strings = [
      'S -> S "+" T',
      'T -> "x"',
  ]
  prod_rules_sequence = [
      self.prod_rules_dict[prod_rule_string]
      for prod_rule_string in prod_rule_strings
  ]
  # assertRaisesRegexp is a deprecated alias that was removed from
  # unittest.TestCase in Python 3.12; assertRaisesRegex is the supported name.
  with self.assertRaisesRegex(
      ValueError,
      'Left hand side symbol of production rule T -> \'x\' does not match '
      r'the symbol in the stack \(S\)'):
    postprocessor.production_rules_sequence_to_stack(prod_rules_sequence)
Пример #2
0
def _get_next_production_rule_mask_batch(partial_sequences,
                                         partial_sequence_lengths, grammar):
    """Builds the next-production-rule mask for every partial sequence.

    For each row, the unpadded prefix of rule indices is converted to
    production rules, turned into a left-hand-side stack, and the symbol taken
    from that stack selects the matching row of grammar.masks.

    Args:
      partial_sequences: Integer numpy array with shape
          [batch_size, max_length]. Batch of partial sequences of the
          expression sequences.
      partial_sequence_lengths: Integer numpy array with shape [batch_size].
          The actual length of partial sequences without padding.
      grammar: arithmetic_grammar.Grammar.

    Returns:
      Boolean numpy array with shape [batch_size, num_production_rules], where
      num_production_rules is the number of production rules in grammar.
    """
    batch_size = len(partial_sequences)
    masks = np.zeros((batch_size, grammar.num_production_rules), dtype=bool)
    paired = zip(partial_sequences, partial_sequence_lengths)
    for row, (sequence, length) in enumerate(paired):
        # Only the first `length` entries are real rules; the rest is padding.
        applied_rules = [
            grammar.prod_rules[rule_index] for rule_index in sequence[:length]
        ]
        lhs_stack = postprocessor.production_rules_sequence_to_stack(
            applied_rules)
        next_symbol = lhs_stack.pop()
        masks[row] = grammar.masks[grammar.lhs_to_index[next_symbol]]
    return masks
Пример #3
0
  def __init__(self, production_rules_sequence, stack=None):
    """Initializer.

    When no stack is given, it is derived from production_rules_sequence. For
    an initial state without any production rules, pass a one-element list
    holding the start symbol as stack, so that the next production rule to
    append has to start with that symbol.

    Args:
      production_rules_sequence: List of nltk.grammar.Production objects. This
          sequence is obtained by a preorder traversal of the context-free
          grammar parsing tree.
      stack: GrammarLhsStack object, list or None. Stores the left hand side
          symbol strings; the left hand side symbol of a valid production rule
          to append must match the top element. For a list input, the last
          element is the top of the stack. A GrammarLhsStack input is copied.

    Raises:
      ValueError: If stack is not list, GrammarLhsStack or None.
    """
    self._production_rules_sequence = production_rules_sequence

    # Reject unsupported stack types up front, then dispatch on the valid ones.
    accepted_types = (list, postprocessor.GrammarLhsStack)
    if stack is not None and not isinstance(stack, accepted_types):
      raise ValueError('stack is expected to be list, GrammarLhsStack or '
                       'None, but got %s.' % type(stack))

    if stack is None:
      lhs_stack = postprocessor.production_rules_sequence_to_stack(
          production_rules_sequence)
    elif isinstance(stack, list):
      lhs_stack = postprocessor.GrammarLhsStack(stack)
    else:
      lhs_stack = stack.copy()
    self._stack = lhs_stack

    # Log the state information defined in __repr__.
    logging.info('Create %s', self)
Пример #4
0
 def test_production_rules_sequence_to_stack(self):
     """Checks the stack contents produced by a valid two-rule sequence."""
     rule_strings = ('S -> S "+" T', 'S -> T')
     rules = [self.prod_rules_dict[rule_string] for rule_string in rule_strings]
     resulting_stack = postprocessor.production_rules_sequence_to_stack(rules)
     # The assertion inspects the underlying list of the returned stack.
     self.assertEqual(resulting_stack._stack, ['T', 'T'])
def get_number_valid_next_step(prod_rules_sequence_indices, grammar):
  """Counts the production rules that may legally come next.

  Args:
    prod_rules_sequence_indices: A 1-D Numpy array of indices of a production
        rule sequence.
    grammar: A grammar object.

  Returns:
    An integer, the number of valid next production rules.
  """
  lhs_stack = postprocessor.production_rules_sequence_to_stack(
      [grammar.prod_rules[index] for index in prod_rules_sequence_indices])
  # peek() reads the symbol on top of the stack; that symbol selects the mask.
  top_symbol = lhs_stack.peek()
  valid_rule_mask = grammar.masks[grammar.lhs_to_index[top_symbol]]
  return int(sum(valid_rule_mask))
Пример #6
0
 def test_production_rules_sequence_to_stack_empty_sequence(self):
     """An empty production rule sequence should produce an empty stack."""
     no_rules = []
     resulting_stack = postprocessor.production_rules_sequence_to_stack(
         no_rules)
     self.assertEqual(resulting_stack._stack, [])