Code example #1
    def test_parse_expressions_to_indices_sequences(self):
        grammar_rules = [
            'S -> S "+" T',  # index 0
            'S -> T',  # index 1
            'T -> "(" S ")"',  # index 2
            'T -> "x"',  # index 3
        ]  # padding rule index 4

        grammar = arithmetic_grammar.Grammar(grammar_rules)
        indices_sequences = grammar.parse_expressions_to_indices_sequences(
            expression_strings=['x + ( x )'], max_length=8)

        np.testing.assert_equal(
            indices_sequences,
            [
                # Expression string: 'x + ( x )'
                # Preorder traversal of parsing tree.
                # S
                # |
                # S '+' T
                # |     |
                # T    '(' S ')'
                # |        |
                # 'x'     'x'
                [
                    0,  # 'S -> S "+" T'
                    1,  # 'S -> T'
                    3,  # 'T -> "x"'
                    2,  # 'T -> "(" S ")"'
                    1,  # 'S -> T'
                    3,  # 'T -> "x"'
                    4,  # Padding dummy production rule.
                    4,  # Padding dummy production rule.
                ]
            ])
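
The index sequence above is just the preorder traversal of the parse tree. As an illustration, the following self-contained sketch (not part of arithmetic_grammar; decode_preorder is a hypothetical helper written for this listing) replays the rule indices against a stack of pending symbols and recovers the original expression:

def decode_preorder(rule_strings, indices, padding_index):
    # Split each rule 'LHS -> rhs tokens' into its left-hand side and RHS tokens.
    rules = []
    for rule in rule_strings:
        lhs, rhs = rule.split(' -> ')
        rules.append((lhs, rhs.split()))
    indices = [i for i in indices if i != padding_index]  # Drop padding entries.

    stack = [rules[indices[0]][0]]  # Start from the LHS of the first rule.
    output = []
    position = 0
    while stack:
        symbol = stack.pop()
        if symbol.startswith('"'):  # Terminal: emit its literal text.
            output.append(symbol.strip('"'))
            continue
        _, rhs = rules[indices[position]]  # Nonterminal: apply the next rule.
        position += 1
        stack.extend(reversed(rhs))  # Expand the leftmost symbol first (preorder).
    return ' '.join(output)

grammar_rules = [
    'S -> S "+" T',    # index 0
    'S -> T',          # index 1
    'T -> "(" S ")"',  # index 2
    'T -> "x"',        # index 3
]
print(decode_preorder(grammar_rules, [0, 1, 3, 2, 1, 3, 4, 4], padding_index=4))
# x + ( x )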
Code example #2
    def setUp(self):
        super(GetNextProductionRuleMaskBatchTest, self).setUp()

        self.grammar = arithmetic_grammar.Grammar(
            [
                'S -> S "+" T',  # index 1
                'S -> T',  # index 2
                'T -> "x"',  # index 3
                'T -> "1"',  # index 4
            ],
            padding_at_end=False)  # padding rule index 0
        self.partial_sequences = np.array([
            [1, 0, 0, 0, 0, 0],  # expression 'S + T',
            # the next production rule should start with S.
            [1, 2, 3, 0, 0, 0],  # expression 'x + T'
            # the next production rule should start with T.
            [2, 0, 0, 0, 0, 0],  # expression 'T'
            # the next production rule should start with T.
        ])
        self.partial_sequence_lengths = np.array([1, 3, 1])

        self.expected_next_production_rule_masks = np.array([
            # Only allow rules starting with S.
            [False, True, True, False, False],
            # Only allow rules starting with T.
            [False, False, False, True, True],
            # Only allow rules starting with T.
            [False, False, False, True, True],
        ])
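
As an illustrative cross-check in plain Python (not the library API): each expected mask row marks exactly the production rules whose left-hand side is the nonterminal that must be expanded next, with the padding rule 'Nothing -> None' occupying index 0 because padding_at_end=False:

rule_lhs = ['Nothing', 'S', 'S', 'T', 'T']  # Left-hand side of rules 0..4.

def mask_for(symbol):
    # A rule is allowed next iff its left-hand side matches the pending symbol.
    return [lhs == symbol for lhs in rule_lhs]

assert mask_for('S') == [False, True, True, False, False]
assert mask_for('T') == [False, False, False, True, True]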
Code example #3
    def test_basic_production_rules(self):
        grammar_rules = [
            'S -> S "+" T',
            'S -> T',
            'T -> "(" S ")"',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules)

        self.assertLen(grammar.prod_rules, 5)
        self.assertEqual(grammar.num_production_rules, 5)
        self.assertEqual(grammar.padding_rule_index, 4)
        self.assertEqual(grammar.start_index.symbol(), 'S')
        self.assertEqual(str(grammar.start_rule), "S -> S '+' T")
        self.assertEqual(grammar.unique_lhs, ['Nothing', 'S', 'T'])
        self.assertEqual(grammar.num_unique_lhs, 3)
        np.testing.assert_allclose(
            grammar.masks,
            [[0., 0., 0., 0., 1.], [1., 1., 0., 0., 0.], [0., 0., 1., 1., 0.]])
        np.testing.assert_allclose(grammar.prod_rule_index_to_lhs_index,
                                   [1, 1, 2, 2, 0])
        self.assertEqual(grammar.prod_rule_rhs_indices,
                         [[1, 2], [2], [1], [], []])
        self.assertEqual(grammar.max_rhs_indices_size, 2)
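
The expected masks and prod_rule_index_to_lhs_index are two views of the same mapping: masks[i][j] is 1.0 exactly when production rule j has left-hand side unique_lhs[i]. A minimal sketch of that assumed relationship, built only from the expected values in the test above:

import numpy as np

prod_rule_index_to_lhs_index = np.array([1, 1, 2, 2, 0])
num_rules = len(prod_rule_index_to_lhs_index)
masks = np.zeros((3, num_rules))  # One row per unique LHS: 'Nothing', 'S', 'T'.
masks[prod_rule_index_to_lhs_index, np.arange(num_rules)] = 1.
np.testing.assert_allclose(
    masks,
    [[0., 0., 0., 0., 1.], [1., 1., 0., 0., 0.], [0., 0., 1., 1., 0.]])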
Code example #4
 def test_input_grammar_rules_contain_padding_dummy_production_rule(self):
     # If the dummy production rule already exists in the input grammar rules, it
     # will duplicate the dummy production rule appended by arithmetic_grammar, so
     # the rules are no longer unique.
     with self.assertRaisesRegex(
             ValueError, 'The grammar production rules are not unique.'):
         arithmetic_grammar.Grammar(['foo', 'Nothing -> None'])
Code example #5
    def test_parse_production_rule_sequence_batch(self):
        grammar = arithmetic_grammar.Grammar(
            [
                'S -> S "+" T',  # index 1
                'S -> T',  # index 2
                'T -> "x"',  # index 3
                'T -> "1"',  # index 4
            ],
            padding_at_end=False)  # padding rule index 0
        input_features_tensor = {
            'expression_string':
            tf.constant([
                'x',  # Can be parsed into
                # 'S -> T'        index 2
                # 'T -> "x"'      index 3
                '1 + x',  # Can be parsed into
                # 'S -> S "+" T'  index 1
                # 'S -> T'        index 2
                # 'T -> "1"'      index 4
                # 'T -> "x"'      index 3
            ])
        }
        output_features_tensor = input_ops.parse_production_rule_sequence_batch(
            features=input_features_tensor, max_length=5, grammar=grammar)

        with self.test_session():
            self.assertAllEqual(output_features_tensor['expression_sequence'],
                                [[2, 3, 0, 0, 0], [1, 2, 4, 3, 0]])
            self.assertAllEqual(
                output_features_tensor['expression_sequence_mask'],
                [[True, True, False, False, False],
                 [True, True, True, True, False]])
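
A sketch of the relationship suggested by the expected values (inferred here, not taken from the input_ops implementation): with padding_at_end=False the padding rule has index 0, so the sequence mask is simply "index != 0":

import numpy as np

sequences = np.array([[2, 3, 0, 0, 0], [1, 2, 4, 3, 0]])
np.testing.assert_array_equal(
    sequences != 0,
    [[True, True, False, False, False], [True, True, True, True, False]])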
Code example #6
 def setUp(self):
   super(NextProductionRuleInfoBatchTextSummaryTest, self).setUp()
   self.grammar = arithmetic_grammar.Grammar(
       [
           'S -> S "+" T',  # index 1
           'S -> T',        # index 2
           'T -> "x"',      # index 3
           'T -> "1"',      # index 4
       ],
       padding_at_end=False)  # padding rule index 0
Code example #7
    def test_grammar_to_string(self, padding_at_end, indent, expected_string):
        grammar_rules = [
            'S -> T',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules,
                                             padding_at_end=padding_at_end)

        self.assertEqual(grammar.grammar_to_string(indent=indent),
                         expected_string)
Code example #8
    def test_parse_expressions_to_indices_sequences_invalid_expression_string(
            self):
        grammar_rules = [
            'S -> S "+" T',
            'S -> T',
            'T -> "(" S ")"',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules)
        with self.assertRaisesRegex(ValueError,
                                    'cannot be parsed to production rules'):
            grammar.parse_expressions_to_indices_sequences(
                expression_strings=['x x'], max_length=8)
Code example #9
    def test_parse_expressions_to_indices_sequences_short_max_length(self):
        grammar_rules = [
            'S -> S "+" T',
            'S -> T',
            'T -> "(" S ")"',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules)

        with self.assertRaisesRegex(
                ValueError,
                r'The number of production rules to parse expression .* '
                'can not be greater than max_length'):
            grammar.parse_expressions_to_indices_sequences(
                expression_strings=['x + ( x )'], max_length=2)
Code example #10
 def setUp(self):
     super(ExpressionContextFreeGrammarPostprocessorTest, self).setUp()
     grammar_rules = [
         'S -> S "+" T',
         'S -> T',
         'T -> "(" S ")"',
         'T -> "x"',
     ]
     # Get list of nltk.grammar.Production objects.
     self.grammar = arithmetic_grammar.Grammar(grammar_rules)
     self.prod_rules_dict = {
         k: v
         for k, v in zip(grammar_rules + [constants.DUMMY_PRODUCTION_RULE],
                         self.grammar.prod_rules)
     }
     self.delimiter = ' '
Code example #11
    def test_parse_expressions_to_indices_sequences_input_not_list(self):
        grammar_rules = [
            'S -> S "+" T',
            'S -> T',
            'T -> "(" S ")"',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules)

        with self.assertRaisesRegex(
                ValueError,
                'expression_strings is expected to be list, but got'):
            grammar.parse_expressions_to_indices_sequences(
                # Note: the input expression_strings is a string, not a list of strings.
                expression_strings='x + ( x )',
                max_length=8)
Code example #12
 def setUp(self):
   super(ProductionRulesStateTest, self).setUp()
   grammar_rules = [
       'S -> S "+" T',
       'S -> T',
       'T -> "(" S ")"',
       'T -> "x"',
       'S -> R',
       'R -> "y"',
   ]
   # Get list of nltk.grammar.Production objects.
   self.grammar = arithmetic_grammar.Grammar(grammar_rules)
   self.production_rules_dict = {
       k: v
       for k, v in zip(grammar_rules + [constants.DUMMY_PRODUCTION_RULE],
                       self.grammar.prod_rules)
   }
Code example #13
 def test_grammar_with_callables(self):
     grammar_rules = [
         'S -> S "+" S',  # index 0
         'S -> S "-" S',  # index 1
         'S -> "FUNCTION1(" P ")"',  # index 2
         'P -> T',  # index 3
         'P -> "1" "+" T',  # index 4
         'S -> T',  # index 5
         'T -> "FUNCTION2(" "x" "," "c" ")"',  # index 6
     ]  # padding rule index 7
     grammar = arithmetic_grammar.Grammar(grammar_rules)
     indices_sequences = grammar.parse_expressions_to_indices_sequences(
         expression_strings=[
             'FUNCTION1( FUNCTION2( x , c ) ) - '
             'FUNCTION2( x , c ) + FUNCTION2( x , c )'
         ],
         max_length=10)
     np.testing.assert_equal(
         indices_sequences,
         [
             # Preorder traversal of parsing tree.
             # S
             # |
             # S                        '+'             S
             # |                                        |
             # S         '-'             S              T
             # |                         |              |
             # 'FUNCTION1(' P ')'        T      'FUNCTION2( x , c )'
             #              |            |
             #              T     'FUNCTION2( x , c )'
             #              |
             # 'FUNCTION2( x , c )'
             [
                 0,  # 'S -> S "+" S'
                 1,  # 'S -> S "-" S'
                 2,  # 'S -> "FUNCTION1(" P ")"'
                 3,  # 'P -> T'
                 6,  # 'T -> "FUNCTION2(" "x" "," "c" ")"'
                 5,  # 'S -> T'
                 6,  # 'T -> "FUNCTION2(" "x" "," "c" ")"'
                 5,  # 'S -> T'
                 6,  # 'T -> "FUNCTION2(" "x" "," "c" ")"'
                 7,  # Padding dummy production rule.
             ]
         ])
Code example #14
    def test_parse_expressions_to_indices_sequences_pad_front_unique_start(
            self):
        grammar_rules = [
            'S -> S "+" T',  # index 2
            'S -> T',  # index 3
            'T -> "(" S ")"',  # index 4
            'T -> "x"',  # index 5
        ]  # padding rule index 0
        # 'O -> S' will be added with index 1.

        grammar = arithmetic_grammar.Grammar(
            grammar_rules,
            padding_at_end=False,
            add_unique_production_rule_to_start=True)
        indices_sequences = grammar.parse_expressions_to_indices_sequences(
            expression_strings=['x + ( x )'], max_length=8)

        np.testing.assert_equal(
            indices_sequences,
            [
                # Expression string: 'x + ( x )'
                # Preorder traversal of parsing tree.
                # O
                # |
                # S
                # |
                # S '+' T
                # |     |
                # T    '(' S ')'
                # |        |
                # 'x'     'x'
                [
                    1,  # 'O -> S'
                    2,  # 'S -> S "+" T'
                    3,  # 'S -> T'
                    5,  # 'T -> "x"'
                    4,  # 'T -> "(" S ")"'
                    3,  # 'S -> T'
                    5,  # 'T -> "x"'
                    0,  # Padding dummy production rule.
                ]
            ])
Code example #15
def load_grammar(grammar_path):
    """Loads context-free grammar from file.

  The grammar used for symbolic regularization has specific setup. The padding
  production rule is the 0-th index production rule. And an unique starting
  production rule O -> S is added as 1-st index production rule. The production
  rules in grammar_path are added after this two rules.

  Args:
    grammar_path: String, the path to the grammar file.

  Returns:
    arithmetic_grammar.Grammar or TokenGrammar object.
  """
    grammar = arithmetic_grammar.Grammar(
        arithmetic_grammar.read_grammar_from_file(filename=grammar_path,
                                                  return_list=True),
        padding_at_end=False,
        add_unique_production_rule_to_start=True)
    return grammar
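
A hypothetical usage sketch; the file name and contents below are assumptions for illustration only, not part of the original code:

# grammar.txt would hold one production rule per line, e.g.:
#   S -> S "+" T
#   S -> T
#   T -> "x"
grammar = load_grammar('grammar.txt')
# With padding_at_end=False and add_unique_production_rule_to_start=True, the
# padding rule is at index 0, 'O -> S' is at index 1, and the rules from the
# file follow from index 2 onward.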
Code example #16
    def test_parse_expressions_to_tensor_padding_at_end_false(self):
        grammar_rules = [
            'S -> S "+" T',
            'S -> T',
            'T -> "(" S ")"',
            'T -> "x"',
        ]

        grammar = arithmetic_grammar.Grammar(grammar_rules,
                                             padding_at_end=False)

        expression_tensor = grammar.parse_expressions_to_tensor(
            expression_strings=['x + ( x )'], max_length=8)

        np.testing.assert_allclose(
            expression_tensor,
            [
                # Expression string: 'x + ( x )'
                # Preorder traversal of parsing tree.
                # S
                # |
                # S '+' T
                # |     |
                # T    '(' S ')'
                # |        |
                # 'x'     'x'
                [
                    [0., 1., 0., 0., 0.],  # 'S -> S "+" T'
                    [0., 0., 1., 0., 0.],  # 'S -> T'
                    [0., 0., 0., 0., 1.],  # 'T -> "x"'
                    [0., 0., 0., 1., 0.],  # 'T -> "(" S ")"'
                    [0., 0., 1., 0., 0.],  # 'S -> T'
                    [0., 0., 0., 0., 1.],  # 'T -> "x"'
                    [1., 0., 0., 0., 0.],  # Padding dummy production rule.
                    [1., 0., 0., 0., 0.],  # Padding dummy production rule.
                ]
            ])
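
The expected tensor is consistent with being the one-hot encoding of the preorder rule-index sequence under the same grammar with padding_at_end=False (padding at index 0, the four rules at indices 1-4). A quick sketch of that assumed relationship:

import numpy as np

indices = [1, 2, 4, 3, 2, 4, 0, 0]  # Preorder rule indices for 'x + ( x )'.
one_hot = np.eye(5)[indices]        # Reproduces the eight expected rows above.
np.testing.assert_allclose(one_hot[0], [0., 1., 0., 0., 0.])   # 'S -> S "+" T'
np.testing.assert_allclose(one_hot[-1], [1., 0., 0., 0., 0.])  # Padding rule.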
Code example #17
 def test_invalid_grammar_string_no_space_before_arrow(self):
     with self.assertRaisesRegex(ValueError, 'Unable to parse'):
         # No space between arrow and left hand side symbol.
         arithmetic_grammar.Grammar(['a-> b'])
Code example #18
 def test_invalid_grammar_string_no_space_after_arrow(self):
     # No space between arrow and right hand side symbol.
      # This is a valid input and should not raise an error.
     arithmetic_grammar.Grammar(['a ->b'])
Code example #19
 def test_input_grammar_rules_not_change(self):
     grammar_rules = ['S -> T', 'T -> "x"']
     arithmetic_grammar.Grammar(grammar_rules)
     self.assertListEqual(grammar_rules, ['S -> T', 'T -> "x"'])
Code example #20
 def test_invalid_grammar_string_no_arrow(self):
     with self.assertRaisesRegex(ValueError, 'Unable to parse'):
         # Invalid input with no arrow.
         arithmetic_grammar.Grammar(['a b'])
Code example #21
 def test_input_grammar_rules_not_unique(self):
     with self.assertRaisesRegex(
             ValueError, 'The grammar production rules are not unique.'):
         arithmetic_grammar.Grammar(['foo', 'foo'])
Code example #22
 def test_input_grammar_rules_not_list(self):
     with self.assertRaisesRegex(ValueError,
                                 'The input grammar_rules should be list.'):
         arithmetic_grammar.Grammar('foo')
Code example #23
 def test_invalid_grammar_string_no_left_hand_side_symbol(self):
     with self.assertRaisesRegex(ValueError, 'Unable to parse'):
         # Invalid input with no left hand side symbol.
         arithmetic_grammar.Grammar([' -> c'])
Code example #24
 def test_invalid_grammar_string_empty_right_hand_side_symbol(self):
     # No right hand side symbol.
      # This is a valid input and should not raise an error.
     arithmetic_grammar.Grammar(['a -> '])