def test_joint_sampler(self):
        """Samples one (source, target) pair and checks both terminals appear."""
        cfg_specs = [
            "ROOT => ##A ##B",
            "A => a",
            "A => and ( ##A , ##A )",
            "B => b",
            "B => and ( ##B , ##B )",
        ]
        cfg_rules = [
            target_grammar.TargetCfgRule.from_string(spec) for spec in cfg_specs
        ]
        induced_rules = [
            # Plain concatenation of two nonterminals.
            qcfg_rule.QCFGRule(tuple("NT_1 NT_2".split()),
                               tuple("NT_1 NT_2".split()), 2),
            # Infix "and" on the source side, "and ( , )" on the target side.
            qcfg_rule.QCFGRule(tuple("NT_1 and NT_2".split()),
                               tuple("and ( NT_1 , NT_2 )".split()), 2),
            # Terminal rules.
            qcfg_rule.QCFGRule(("A", ), ("a", ), 0),
            qcfg_rule.QCFGRule(("B", ), ("b", ), 0),
        ]

        sampler = joint_sampler.JointSampler.from_rules(cfg_rules,
                                                        induced_rules)
        source_tokens, target_tokens = sampler.sample()
        joined_source = " ".join(source_tokens)
        joined_target = " ".join(target_tokens)
        print("sampled: %s ### %s" % (joined_source, joined_target))
        # ROOT => ##A ##B forces one A subtree and one B subtree, so both
        # terminals must show up on each side of the sampled pair.
        for token in ("A", "B"):
            self.assertIn(token, source_tokens)
        for token in ("a", "b"):
            self.assertIn(token, target_tokens)
Esempio n. 2
0
def _maybe_make_rule(source, target):
  """Canonicalize NT indexes and return a QCFGRule, or None.

  Returns None when the non-terminals reported by `_get_non_terminals` for
  `source` and `target` do not compare equal. Otherwise the arity is the
  number of those non-terminals, and both sides are passed through
  `rule_utils.canonicalize_nts` before the rule is built.
  """
  nts = _get_non_terminals(source)
  if nts != _get_non_terminals(target):
    return None
  arity = len(nts)
  canon_source, canon_target = rule_utils.canonicalize_nts(
      source, target, arity)
  return qcfg_rule.QCFGRule(tuple(canon_source), tuple(canon_target), arity)
def get_number_rules(source, target):
  """Return number seed rules based on exact match.

  Scans `target` for "Number <value> " spans. For each value that also occurs
  as a substring of `source`, an arity-0 QCFGRule is added that maps the value
  to the target "Number ..." tokens. Three passes are made:

    1. integers, e.g. "Number 12 "    -> rule `12 => Number 12`;
    2. integers written with a ".0" suffix in the target,
       e.g. "Number 12.0 "            -> rule `12 => Number 12.0`;
    3. decimals, e.g. "Number 1.5 "   -> rule `1.5 => Number 1.5`.

  Values missing from `source` are skipped. Passes 1 and 3 print a message
  for each skipped value; pass 2 skips silently.

  Args:
    source: Source string.
    target: Target string; each number is expected to be followed by a space.

  Returns:
    A set of arity-0 QCFGRules.
  """
  rules = set()
  # Each pass is (regex, target template, whether to report misses).
  passes = (
      (r"Number ([0-9]+?) ", "Number %s", True),
      # The dot must be escaped. Unescaped, "." matches any character, so
      # e.g. "Number 1200 " would wrongly produce `12 => Number 12.0`.
      (r"Number ([0-9]+?)\.0 ", "Number %s.0", False),
      (r"Number ([0-9]+\.[0-9]+?) ", "Number %s", True),
  )
  for pattern, target_template, report_misses in passes:
    for match in re.findall(pattern, target):
      if match not in source:
        if report_misses:
          print("`%s` is not in `%s`." % (match, source))
        continue
      source_rhs = match
      target_rhs = target_template % match
      rules.add(
          qcfg_rule.QCFGRule(
              tuple(source_rhs.split(" ")), tuple(target_rhs.split(" ")), 0))
  return rules
    def test_joint_sampler_with_score_fn(self):
        """Samples with a score_fn backed by InferenceWrapper application scores."""
        target_cfg_rules = [
            target_grammar.TargetCfgRule.from_string("ROOT => ##A ##B"),
            target_grammar.TargetCfgRule.from_string("A => a"),
            target_grammar.TargetCfgRule.from_string("A => and ( ##A , ##A )"),
            target_grammar.TargetCfgRule.from_string("B => b"),
            target_grammar.TargetCfgRule.from_string("B => and ( ##B , ##B )"),
        ]
        qcfg_rule_1 = qcfg_rule.QCFGRule(tuple("NT_1 NT_2".split()),
                                         tuple("NT_1 NT_2".split()), 2)
        qcfg_rule_2 = qcfg_rule.QCFGRule(tuple("NT_1 and NT_2".split()),
                                         tuple("and ( NT_1 , NT_2 )".split()),
                                         2)
        qcfg_rule_3 = qcfg_rule.QCFGRule(("A", ), ("a", ), 0)
        qcfg_rule_4 = qcfg_rule.QCFGRule(("B", ), ("b", ), 0)
        qcfg_rules = [qcfg_rule_1, qcfg_rule_2, qcfg_rule_3, qcfg_rule_4]

        config = test_utils.get_test_config()
        wrapper = inference_wrapper.InferenceWrapper(qcfg_rules, config,
                                                     target_cfg_rules)
        wrapper.compute_application_scores(temperature=1, nonterminal_bias=1)
        # Exponentiate once, up front. score_fn is handed to the sampler and
        # may be called many times; previously each call recomputed np.exp
        # over the full score matrix.
        scores = np.exp(wrapper.application_scores)

        def score_fn(parent_rule, nt_idx, child_rule):
            # Row: the child rule's LHS embedding index.
            # Column: the (parent rule, nonterminal index) RHS embedding index.
            rhs_idx = wrapper.rhs_emb_idx_map[(parent_rule, nt_idx)]
            lhs_idx = wrapper.lhs_emb_idx_map[child_rule]
            return scores[lhs_idx, rhs_idx]

        sampler = joint_sampler.JointSampler.from_rules(
            target_cfg_rules, qcfg_rules)
        sampled_source, sampled_target = sampler.sample(score_fn=score_fn)
        print("sampled: %s ### %s" %
              (" ".join(sampled_source), " ".join(sampled_target)))
        # The generated string should always contain these two characters based
        # on the target CFG.
        self.assertIn("A", sampled_source)
        self.assertIn("B", sampled_source)
        self.assertIn("a", sampled_target)
        self.assertIn("b", sampled_target)
    def test_joint_sampler_save_and_load(self):
        """Round-trips a JointSampler through save() and from_file()."""
        target_cfg_rules = [
            target_grammar.TargetCfgRule.from_string(rule_str)
            for rule_str in (
                "ROOT => ##A ##B",
                "A => a",
                "A => and ( ##A , ##A )",
                "B => b",
                "B => and ( ##B , ##B )",
            )
        ]
        induced_rules = [
            qcfg_rule.QCFGRule(tuple("NT_1 NT_2".split()),
                               tuple("NT_1 NT_2".split()), 2),
            qcfg_rule.QCFGRule(tuple("NT_1 and NT_2".split()),
                               tuple("and ( NT_1 , NT_2 )".split()), 2),
            qcfg_rule.QCFGRule(("A", ), ("a", ), 0),
            qcfg_rule.QCFGRule(("B", ), ("b", ), 0),
        ]

        original = joint_sampler.JointSampler.from_rules(target_cfg_rules,
                                                         induced_rules)
        path = os.path.join(self.get_temp_dir(), "sampler.json")
        original.save(path)
        restored = joint_sampler.JointSampler.from_file(path)
        # String forms are compared as a proxy for equality of the samplers.
        self.assertEqual(str(original), str(restored))
def get_datetime_exact_match(source, target):
  """Return seed rules based on exact match for dates and times.

  For every ":date ( ... )" or ":time ( ... )" span in `target`, the inner
  text is formatted with `string_utils.format_source`. If the formatted text
  occurs, case-insensitively, as a substring of `source`, an arity-0 rule is
  added that maps it to the inner target text (both split on single spaces).

  Returns:
    A set of QCFGRules.
  """
  rules = set()
  for kind in ("date", "time"):
    for inner in re.findall(r":%s \( (.+?) \)" % kind, target):
      formatted = string_utils.format_source(inner)
      if formatted.lower() in source.lower():
        rules.add(
            qcfg_rule.QCFGRule(
                tuple(formatted.split(" ")), tuple(inner.split(" ")), 0))
  return rules
def get_exact_match_rules(dataset):
    """Return set of rules for terminal sequences in both source and target.

    Args:
      dataset: Iterable of (source_str, target_str) pairs. Each string is
        tokenized with str.split() before being passed to
        `_find_exact_matches`.

    Returns:
      A set of arity-0 QCFGRules whose source and target are the same token
      tuple, built from each distinct match collected across the dataset.
    """
    shared_spans = set()
    for source_str, target_str in dataset:
        shared_spans.update(
            _find_exact_matches(source_str.split(), target_str.split()))
    return {
        qcfg_rule.QCFGRule(source=tuple(span), target=tuple(span), arity=0)
        for span in shared_spans
    }
def get_string_rules(source, target):
  """Return string seed rules based on exact match.

  For each `<Prefix> " ... "` span in `target`, where Prefix is one of
  PersonName, String, LocationKeyphrase, Month or DayOfWeek, the quoted text
  is formatted with `string_utils.format_source`. If the formatted text
  occurs, case-insensitively, as a substring of `source`, an arity-0 rule is
  added that maps it to the full `<Prefix> " ... "` target tokens. Otherwise
  a message is printed and the span is skipped.

  Returns:
    A set of QCFGRules.
  """
  rules = set()
  prefixes = ("PersonName", "String", "LocationKeyphrase", "Month",
              "DayOfWeek")
  for prefix in prefixes:
    for quoted in re.findall(r'%s " (.+?) "' % prefix, target):
      source_rhs = string_utils.format_source(quoted)
      if source_rhs.lower() in source.lower():
        target_rhs = '%s " %s "' % (prefix, quoted)
        rules.add(
            qcfg_rule.QCFGRule(
                tuple(source_rhs.split(" ")), tuple(target_rhs.split(" ")), 0))
      else:
        print("`%s` is not in `%s`." % (source_rhs, source))
  return rules
 def test_joint_rule_converter_2(self):
     """Checks CFG nonterminal tuples for an induced rule that swaps NTs."""
     cfg_rules = [
         target_grammar.TargetCfgRule.from_string(line) for line in (
             "ROOT => ##A ##B",
             "A => a",
             "A => and ( ##A , ##B )",
             "B => b",
             "B => and ( ##B , ##A )",
         )
     ]
     # NT_1 precedes NT_2 on the source side; the target swaps their order.
     swapped_rule = qcfg_rule.QCFGRule(tuple("NT_1 and NT_2".split()),
                                       tuple("and ( NT_2 , NT_1 )".split()),
                                       2)
     converter = joint_sampler.JointRuleConverter(cfg_rules)
     joint_rule = converter.convert(swapped_rule)
     self.assertEqual(joint_rule.qcfg_rule, swapped_rule)
     expected_nts = {("A", "B", "A"), ("B", "A", "B")}
     self.assertEqual(joint_rule.cfg_nts_set, expected_nts)
Esempio n. 10
0
def _example_to_rule(source_str, target_str):
    """Wrap a whole (source, target) example as an arity-0 QCFGRule.

    Both strings are tokenized with str.split(), so any run of whitespace
    separates tokens.
    """
    source_tokens = tuple(source_str.split())
    target_tokens = tuple(target_str.split())
    return qcfg_rule.QCFGRule(source_tokens, target_tokens, arity=0)
Esempio n. 11
0
def _make_rule(nts, source, target):
  """Canonicalize NT indexes and return a QCFGRule.

  Args:
    nts: Collection of non-terminals. Only its length is used, as the arity.
    source: Source-side token sequence.
    target: Target-side token sequence.

  Returns:
    A QCFGRule built from the output of `rule_utils.canonicalize_nts`.
  """
  arity = len(nts)
  canonical = rule_utils.canonicalize_nts(source, target, arity)
  canon_source, canon_target = canonical
  return qcfg_rule.QCFGRule(tuple(canon_source), tuple(canon_target), arity)