Example #1
 def test_is_terminal_empty(self):
     self.assertTrue(
         states.ProductionRulesState(
             production_rules_sequence=[]).is_terminal())
     self.assertFalse(
         states.ProductionRulesState(production_rules_sequence=[],
                                     stack=['S']).is_terminal())
Example #2
 def test_eq(self):
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> T',
       'T -> "x"',
       'T -> "x"',
       constants.DUMMY_PRODUCTION_RULE,
   ])
   state1 = states.ProductionRulesState(production_rules_sequence)
   state2 = states.ProductionRulesState(production_rules_sequence)
   self.assertEqual(state1, state2)
Example #3
    def test_get_new_states_probs(self):
        state = states.ProductionRulesState(
            self._strings_to_production_rules([
                'S -> S "+" T',
                'S -> T',
            ]))
        # The above production rules sequence is parsed as
        #                 S
        #                 |
        #              S '+' T
        #              |
        #              T
        #
        # Since the order of the production rules sequence is the preorder
        # traversal of the parsing tree, the next symbol to parse is the 'T' on
        # the left side of the above parsing tree. Only production rules with
        # left hand side symbol T are valid.
        # Thus, for a grammar with production rules:
        # 'S -> S "+" T'
        # 'S -> T'
        # 'T -> "(" S ")"'
        # 'T -> "x"'
        # Appending either of the first two production rules would create an
        # invalid state, so their prior probabilities are nan. The last two
        # production rules can be appended and create new states with equal
        # prior probabilities.
        expected_new_states = [
            None,
            None,
            states.ProductionRulesState(
                self._strings_to_production_rules([
                    'S -> S "+" T',
                    'S -> T',
                    'T -> "(" S ")"',
                ])),
            states.ProductionRulesState(
                self._strings_to_production_rules([
                    'S -> S "+" T',
                    'S -> T',
                    'T -> "x"',
                ])),
        ]

        policy = policies.ProductionRuleAppendPolicy(grammar=self.grammar)
        new_states, action_probs = policy.get_new_states_probs(state)

        np.testing.assert_allclose(action_probs, [np.nan, np.nan, 0.5, 0.5])
        self.assertEqual(len(new_states), len(expected_new_states))
        for new_state, expected_new_state in zip(new_states,
                                                 expected_new_states):
            self.assertEqual(new_state, expected_new_state)
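
The prior probabilities asserted in the test above follow a masked-uniform pattern: rules whose left hand side does not match the next symbol get nan, and the remaining rules share a uniform probability. A minimal sketch of that pattern (an illustration of the expected output only, not the library's ProductionRuleAppendPolicy implementation):

import numpy as np

def masked_uniform_priors(valid_mask):
    # valid_mask: 1 for rules whose left hand side matches the next symbol
    # to expand, 0 otherwise.
    valid_mask = np.asarray(valid_mask, dtype=float)
    probs = np.full_like(valid_mask, np.nan)
    probs[valid_mask > 0] = 1. / valid_mask.sum()
    return probs

# Only the two rules with left hand side T are valid in the test above.
print(masked_uniform_priors([0, 0, 1, 1]))  # [nan nan 0.5 0.5]
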
Example #4
 def test_init_stack_invalid(self):
     with self.assertRaisesRegexp(
             ValueError, 'stack is expected to be list, '
             'GrammarLhsStack or None, but got '
             '<class \'str\'>'):
         states.ProductionRulesState(production_rules_sequence=[],
                                     stack='foo')
Example #5
 def test_init_stack_list(self):
   state = states.ProductionRulesState(
       production_rules_sequence=[], stack=['T', 'R'])
   # Use assertIs to check exact type rather than assertIsInstance.
   # https://docs.python.org/2/library/unittest.html#unittest.TestCase.assertIsInstance
   self.assertIs(type(state._stack), postprocessor.GrammarLhsStack)
   self.assertEqual(state._stack.to_list(), ['T', 'R'])
Example #6
 def test_generate_history(self):
     production_rules_sequence = self._strings_to_production_rules(
         ['S -> S "+" T', 'S -> T', 'T -> "x"'])
     state = states.ProductionRulesState(
         production_rules_sequence=production_rules_sequence)
     self.assertListEqual(state.generate_history(),
                          ['S + T', 'T + T', 'x + T'])
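
The expected history above replays the derivation one rule at a time, expanding the leftmost matching symbol at each step. A small standalone sketch of that replay (a hypothetical helper written for illustration, not the library's generate_history implementation):

def leftmost_derivation_history(rules, start_symbol='S'):
    # rules: (lhs, rhs) pairs with space-separated right hand sides; terminals
    # are written unquoted here for simplicity.
    symbols = [start_symbol]
    history = []
    for lhs, rhs in rules:
        # Expand the leftmost occurrence of the rule's left hand side symbol.
        index = symbols.index(lhs)
        symbols = symbols[:index] + rhs.split() + symbols[index + 1:]
        history.append(' '.join(symbols))
    return history

# Prints ['S + T', 'T + T', 'x + T'], matching the assertion above.
print(leftmost_derivation_history(
    [('S', 'S + T'), ('S', 'T'), ('T', 'x')]))
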
Example #7
 def test_is_valid_to_append_init_stack(self):
   state = states.ProductionRulesState(
       production_rules_sequence=[], stack=['S'])
   # The current stack is [S], so the next production rule should start
   # with S.
   self.assertTrue(state.is_valid_to_append(
       self.production_rules_dict['S -> T']))
   self.assertFalse(state.is_valid_to_append(
       self.production_rules_dict['T -> "x"']))
Example #8
 def test_repr_empty(self):
     state = states.ProductionRulesState(
         production_rules_sequence=[], stack=nltk.grammar.nonterminals('S'))
     self.assertEqual(
         str(state), 'ProductionRulesState [symbols: , '
         'length_production_rules_sequence: 0, '
         'stack top: S, '
         'num_terminals / num_symbols: 0 / 0, '
         'terminal_ratio:  nan]')
Example #9
 def test_is_valid_to_append(self):
   production_rules_sequence = self._strings_to_production_rules(
       ['S -> S "+" T'])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   # The current stack is [T, S], so the next production rule should start
   # with S.
   self.assertTrue(state.is_valid_to_append(
       self.production_rules_dict['S -> T']))
   self.assertFalse(state.is_valid_to_append(
       self.production_rules_dict['T -> "x"']))
Example #10
  def test_is_terminal_end_without_terminal_rule(self):
    production_rules_sequence = self._strings_to_production_rules([
        'S -> S "+" T',
        'S -> T',
        'T -> "x"',
        'T -> "x"',
    ])

    state = states.ProductionRulesState(
        production_rules_sequence=production_rules_sequence)
    self.assertTrue(state.is_terminal())
Example #11
 def test_copy(self):
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> T',
   ])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   new_state = state.copy()
   self.assertEqual(state, new_state)
   # Change in state will not affect new_state.
   state.append_production_rule(self.production_rules_dict['T -> "x"'])
   self.assertLen(state.production_rules_sequence, 3)
   self.assertLen(new_state.production_rules_sequence, 2)
Example #12
  def test_evaluate_not_terminal_without_default_value(self):
    not_terminal_state = states.ProductionRulesState(
        production_rules_sequence=[])
    not_terminal_state.is_terminal = mock.MagicMock(return_value=False)
    reward = rewards.RewardBase(allow_nonterminal=False, default_value=None)
    with self.assertRaisesRegexp(ValueError,
                                 'allow_nonterminal is False and '
                                 'default_value is None, but state is not '
                                 'terminal'):
      reward.evaluate(not_terminal_state)

    # ValueError will not be raised if default value is set.
    reward.set_default_value(42)
    self.assertAlmostEqual(reward.evaluate(not_terminal_state), 42.)
Example #13
 def test_repr_expression_not_terminal(self):
     production_rules_sequence = self._strings_to_production_rules([
         'S -> S "+" T',
         'S -> T',
         'T -> "x"',
     ])
     state = states.ProductionRulesState(
         production_rules_sequence=production_rules_sequence)
     self.assertEqual(
         str(state), 'ProductionRulesState [symbols: x + T, '
         'length_production_rules_sequence: 3, '
         'stack top: T, '
         'num_terminals / num_symbols: 2 / 3, '
         'terminal_ratio: 0.67]')
Example #14
 def test_repr_without_terminal_rule(self):
     production_rules_sequence = self._strings_to_production_rules([
         'S -> S "+" T',
         'S -> T',
         'T -> "x"',
         'T -> "x"',
     ])
     state = states.ProductionRulesState(
         production_rules_sequence=production_rules_sequence)
     self.assertEqual(
         str(state), 'ProductionRulesState [symbols: x + x, '
         'length_production_rules_sequence: 4, '
         'stack top: Nothing, '
         'num_terminals / num_symbols: 3 / 3, '
         'terminal_ratio: 1.00]')
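
The terminal_ratio in these repr strings (nan in Example #8, 0.67 in Example #13, 1.00 here) appears to be num_terminals divided by num_symbols, rendered with a width-4 float format, which would also explain the double space before nan in Example #8. A tiny sketch of that arithmetic (an assumption inferred from the expected strings, not the library source):

def terminal_ratio(num_terminals, num_symbols):
    # nan for an empty symbol list, otherwise the plain ratio.
    if num_symbols == 0:
        return float('nan')
    return float(num_terminals) / num_symbols

print('%4.2f' % terminal_ratio(2, 3))  # 0.67
print('%4.2f' % terminal_ratio(3, 3))  # 1.00
print('%4.2f' % terminal_ratio(0, 0))  # ' nan' (leading space from width 4)
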
Example #15
 def test_append_production_rule_invalid(self):
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> T',
   ])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   # The current stack is [T, T], so the next production rule should start
   # with T. A ValueError will be raised if the production rule to append
   # does not have left hand side symbol T.
   with self.assertRaisesRegexp(
       ValueError,
       r'The left hand side symbol of production rule S -> T does not match '
       r'the top symbol in the grammar left hand side stack \(T\)'):
     state.append_production_rule(self.production_rules_dict['S -> T'])
Example #16
 def test_init_stack_none(self):
   # _stack attribute should be created from the input
   # production_rules_sequence.
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> R',
   ])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence, stack=None)
   # Use assertIs to check exact type rather than assertIsInstance.
   # https://docs.python.org/2/library/unittest.html#unittest.TestCase.assertIsInstance
   self.assertIs(type(state._stack), postprocessor.GrammarLhsStack)
   # Add 'S -> S "+" T': first push 'T', then push 'S' to the stack.
   # Stack ['T', 'S']
   # Add 'S -> R': pop 'S', then push 'R' to the stack.
   # Stack ['T', 'R']
   self.assertEqual(state._stack.to_list(), ['T', 'R'])
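
The comments above describe the stack bookkeeping: for each appended rule, its left hand side is popped when it sits on top of the stack, and the nonterminal symbols of its right hand side are pushed right to left. A minimal sketch of that update (a hypothetical helper, not the postprocessor.GrammarLhsStack implementation):

def stack_from_sequence(rules, nonterminals=('S', 'T', 'R')):
    # rules: (lhs, rhs_symbols) pairs; terminal symbols in the right hand
    # side are ignored, nonterminals are pushed right to left.
    stack = []
    for lhs, rhs_symbols in rules:
        if stack and stack[-1] == lhs:
            stack.pop()
        for symbol in reversed(rhs_symbols):
            if symbol in nonterminals:
                stack.append(symbol)
    return stack

# Prints ['T', 'R'] (bottom to top), matching the expected stack above.
print(stack_from_sequence([('S', ['S', '+', 'T']), ('S', ['R'])]))
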
Example #17
  def test_is_terminal_end_with_terminal_rule(self):
    production_rules_sequence = self._strings_to_production_rules([
        'S -> S "+" T',
        'S -> T',
        'T -> "x"',
        'T -> "x"',
        # NOTE(leeley): This mimics the procedure in the grammar variational
        # autoencoder, which uses DUMMY_PRODUCTION_RULE as the padding rule.
        # Symbol generation from a production rule sequence stops once all
        # symbols are terminal, so for the grammar rules in this unittest the
        # final dummy rule is actually not used.
        constants.DUMMY_PRODUCTION_RULE,
    ])

    state = states.ProductionRulesState(
        production_rules_sequence=production_rules_sequence)
    self.assertTrue(state.is_terminal())
Example #18
 def test_get_expression_not_terminal(self):
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> T',
       'T -> "x"',
   ])
   # Parsing tree:
   #  S
   #  |
   #  S "+" T
   #  |
   #  T
   #  |
   # "x"
   # Expression (non-terminal):
   # x + T
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   self.assertEqual(state.get_expression(), 'x + T')
   self.assertEqual(state.get_expression(coefficients={'x': 42}), '42 + T')
Example #19
 def test_append_production_rule(self):
   production_rules_sequence = self._strings_to_production_rules([
       'S -> S "+" T',
       'S -> T',
   ])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   self.assertLen(state.production_rules_sequence, 2)
   # The grammar left hand side symbol stack is [T, T], so the next
   # production rule should start with T.
   state.append_production_rule(self.production_rules_dict['T -> "x"'])
   self.assertLen(state.production_rules_sequence, 3)
   # The grammar left hand side symbol stack is [T], so the next production
   # rule should start with T.
   state.append_production_rule(self.production_rules_dict['T -> "x"'])
   self.assertLen(state.production_rules_sequence, 4)
   # The grammar left hand side symbol stack is empty, so the next production
   # rule can only be the dummy production rule.
   state.append_production_rule(
       self.production_rules_dict[constants.DUMMY_PRODUCTION_RULE])
   self.assertLen(state.production_rules_sequence, 5)
Example #20
def generate_expression(sess,
                        grammar,
                        max_length,
                        symbolic_properties_dict=None,
                        numerical_values=None,
                        clip_value_min=None,
                        clip_value_max=None,
                        random_state=None,
                        sampling=False,
                        empirical_distribution_df=None,
                        tail_length=None,
                        partial_sequence=None,
                        input_variable_scope='serving_input'):
  """Generates an expression by a trained partial sequence model.

  Args:
    sess: tf.Session, the session containing the trained model used to predict
        the next production rule from the input partial sequence. If None, each
        step will be selected randomly.
    grammar: arithmetic_grammar.Grammar object.
    max_length: Integer, the max length of production rule sequence.
    symbolic_properties_dict: Dict, the keys are the symbolic properties used as
        conditions. Values are the corresponding desired values of the symbolic
        properties.
    numerical_values: Float numpy array with shape [num_numerical_points], the
        values of the expression evaluated at these points.
    clip_value_min: Float, the minimum value to clip by.
    clip_value_max: Float, the maximum value to clip by.
    random_state: Numpy RandomState. Default None.
    sampling: Boolean, whether to do sampling. If True, the next production rule
        will be sampled from the probabilities predicted by the partial sequence
        model. If False, the generator deterministically chooses the next
        production rule with highest probability at each step.
    empirical_distribution_df: Pandas dataframe recording the empirical
        probability distribution of the next production rule under various
        settings of partial_sequence_indices and conditions. Each row gives the
        probability distribution of the next production rule corresponding to
        one particular partial_sequence (or a tail of it), and conditions such
        as leading_at_0 and leading_at_inf. The partial_sequence (or a tail of
        it) and conditions are placed in the dataframe as multi-indices. The
        columns are the probabilities of the next production rule (the rules are
        represented by indices), e.g.:
        partial_sequence_indices  leading_at_0  leading_at_inf  0  1  2   ...
                1_4_3_5                -1            -1         0  0  0.5 ...
    tail_length: Integer, length of the tail partial sequence used for
        generating the empirical distribution dataframe. If None, the entire
        partial sequence is used.
    partial_sequence: List of integers, the partial sequence to start the
        generation. If None (default), the generation will start from scratch.
    input_variable_scope: String, the variable scope for the tensor in input
        features. Default 'serving_input'. Used when sess is not None.

  Returns:
    Dict with the following keys:
      * 'expression_string': String.
      * 'is_terminal': Boolean, whether all the symbols in the generated
            expression are terminal.
      * 'production_rule_sequence': List of integers, the indices in the
            grammar of the generated sequence of production rules.
      * 'history': List of strings, the history of expression generation.

  Raises:
    ValueError: The proposed probability distribution of the next production
        rule is invalid.
  """
  if sess is None:
    logging.info('Input sess is None, '
                 'each step in the generator will be selected randomly.')

  if random_state is None:
    random_state = np.random.RandomState()

  conditions = {}
  if symbolic_properties_dict is not None:
    conditions.update({
        key: np.array([value], dtype=np.float32)
        for key, value in six.iteritems(symbolic_properties_dict)
    })
  if numerical_values is not None:
    conditions['numerical_values'] = np.atleast_2d(
        np.clip(numerical_values, clip_value_min, clip_value_max)
        ).astype(np.float32)

  partial_sequence = _get_starting_partial_sequence(
      partial_sequence=partial_sequence,
      grammar=grammar,
      random_state=random_state)

  # NOTE(leeley): ProductionRulesState records the (partial) expression as a
  # non-terminal symbol stack and a sequence of production rule objects.
  # partial_sequence records the indices of the production rules in the
  # grammar rather than the production rule objects themselves.
  state = states.ProductionRulesState(
      production_rules_sequence=[
          grammar.prod_rules[production_rule_index]
          for production_rule_index in partial_sequence
      ])

  while len(partial_sequence) < max_length and not state.is_terminal():
    next_production_rule_mask = grammar.masks[
        grammar.lhs_to_index[state.stack_peek()]]

    if sess is None:
      if empirical_distribution_df is None:
        next_production_rule_distribution = next_production_rule_mask
      else:
        current_partial_sequence_indices = '_'.join(map(str, partial_sequence))
        next_production_rule_distribution = (
            get_next_production_rule_distribution(
                empirical_distribution_df,
                tail_length,
                current_partial_sequence_indices,
                symbolic_properties_dict,
                next_production_rule_mask))
        logging.info('Current partial sequence indices: %s.',
                     current_partial_sequence_indices)
        logging.info('Symbolic properties dict: %s.', symbolic_properties_dict)
        logging.info('Next production rule probabilities: %s.',
                     next_production_rule_distribution)
        # If no rule is found, leave the sequence unterminated.
        if next_production_rule_distribution is None:
          break
      next_production_rule = generate_next_production_rule_randomly(
          num_production_rules=grammar.num_production_rules,
          next_production_rule_distribution=next_production_rule_distribution,
          random_state=random_state)
    else:
      next_production_rule = generate_next_production_rule_from_model(
          sess=sess,
          max_length=max_length,
          partial_sequence=partial_sequence,
          next_production_rule_mask=next_production_rule_mask,
          conditions=conditions,
          sampling=sampling,
          random_state=random_state,
          input_variable_scope=input_variable_scope)

    # Update the partial sequence as input features for next round.
    partial_sequence.append(next_production_rule)
    # Update the expression state.
    state.append_production_rule(grammar.prod_rules[next_production_rule])

  return {
      'expression_string': state.get_expression(),
      'is_terminal': state.is_terminal(),
      'production_rule_sequence': partial_sequence,
      'history': state.generate_history(),
  }
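
A hedged usage sketch for the function above. The arithmetic_grammar.Grammar construction below is an assumption (its constructor is not shown in this section); passing sess=None makes the generator select each production rule randomly, as the function logs.

import numpy as np

# Assumed grammar construction; only the rule strings are taken from the
# tests above.
grammar = arithmetic_grammar.Grammar([
    'S -> S "+" T',
    'S -> T',
    'T -> "(" S ")"',
    'T -> "x"',
])

result = generate_expression(
    sess=None,  # No trained model: each step is selected randomly.
    grammar=grammar,
    max_length=10,
    random_state=np.random.RandomState(0))

print(result['expression_string'])
print(result['is_terminal'])
print(result['production_rule_sequence'])
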
Example #21
 def test_evaluate_not_implemented(self):
   state = states.ProductionRulesState(production_rules_sequence=[])
   reward = rewards.RewardBase()
   with self.assertRaisesRegexp(NotImplementedError,
                                'Must be implemented by subclass'):
     reward.evaluate(state)
Example #22
 def test_evaluate_not_terminal_with_default_value(self):
   not_terminal_state = states.ProductionRulesState(
       production_rules_sequence=[])
   not_terminal_state.is_terminal = mock.MagicMock(return_value=False)
   reward = rewards.RewardBase(allow_nonterminal=False, default_value=42)
   self.assertAlmostEqual(reward.evaluate(not_terminal_state), 42)
Example #23
 def test_stack_peek(self):
   production_rules_sequence = self._strings_to_production_rules(
       ['S -> S "+" T'])
   state = states.ProductionRulesState(
       production_rules_sequence=production_rules_sequence)
   self.assertEqual(state.stack_peek(), 'S')
Example #24
 def test_stack_peek_init_stack(self):
   state = states.ProductionRulesState(
       production_rules_sequence=[], stack=['S'])
   self.assertEqual(state.stack_peek(), 'S')