def test_joint_sampler(self):
    # Draw one sample from a sampler built over a tiny target CFG plus four
    # QCFG rules, then check which symbols appear on each side.
    cfg_rule_strings = (
        "ROOT => ##A ##B",
        "A => a",
        "A => and ( ##A , ##A )",
        "B => b",
        "B => and ( ##B , ##B )",
    )
    target_cfg_rules = [
        target_grammar.TargetCfgRule.from_string(rule_string)
        for rule_string in cfg_rule_strings
    ]

    # Two rules of arity 2 followed by two rules of arity 0.
    concat_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 NT_2".split()), tuple("NT_1 NT_2".split()), 2)
    and_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 and NT_2".split()),
        tuple("and ( NT_1 , NT_2 )".split()),
        2)
    a_rule = qcfg_rule.QCFGRule(("A", ), ("a", ), 0)
    b_rule = qcfg_rule.QCFGRule(("B", ), ("b", ), 0)
    induced_rules = [concat_rule, and_rule, a_rule, b_rule]

    sampler = joint_sampler.JointSampler.from_rules(
        target_cfg_rules, induced_rules)
    sampled_source, sampled_target = sampler.sample()
    source_text = " ".join(sampled_source)
    target_text = " ".join(sampled_target)
    print("sampled: %s ### %s" % (source_text, target_text))

    # Given the target CFG above, every sample is expected to contain both
    # symbols on the source side and both on the target side.
    self.assertIn("A", sampled_source)
    self.assertIn("B", sampled_source)
    self.assertIn("a", sampled_target)
    self.assertIn("b", sampled_target)
def _maybe_make_rule(source, target):
    """Return a QCFGRule with canonicalized NT indexes, or None.

    None is returned when the non-terminals found in `source` differ from
    those found in `target`. Otherwise the rule's arity is the number of
    non-terminals found.
    """
    nts_in_source = _get_non_terminals(source)
    nts_in_target = _get_non_terminals(target)
    if nts_in_source != nts_in_target:
        # The two sides disagree on their non-terminals: no rule.
        return None

    num_nts = len(nts_in_source)
    canonical = rule_utils.canonicalize_nts(source, target, num_nts)
    canonical_source, canonical_target = canonical
    return qcfg_rule.QCFGRule(
        tuple(canonical_source), tuple(canonical_target), num_nts)
def get_number_rules(source, target):
    """Return number seed rules based on exact match.

    Scans `target` for `Number <value> ` tokens (note the required trailing
    space) and, for each value that also occurs as a substring of `source`,
    emits an arity-0 rule mapping the value to its `Number ...` target form.

    Three passes are made, in this order:
      1. Integers without decimals: `Number 5 ` yields `5 -> Number 5`.
      2. Integers written with a `.0` suffix: `Number 5.0 ` yields
         `5 -> Number 5.0`.
      3. Numbers with decimals: `Number 5.25 ` yields `5.25 -> Number 5.25`.

    Passes 1 and 3 print a diagnostic when a matched value is missing from
    `source`. Pass 2 skips silently, because the same value is reported by
    pass 3 when it is absent.

    Args:
      source: Source string; matched values are looked up in it with `in`.
      target: Target string to scan for `Number ...` tokens.

    Returns:
      A set of `qcfg_rule.QCFGRule` objects, each created with arity 0.
    """
    # Each entry: (regex, template for the target side, report misses?).
    passes = (
        (r"Number ([0-9]+?) ", "Number %s", True),
        # The dot must be escaped: an unescaped `.` matches any character, so
        # `Number 100 ` would be read as `1` + any char + `0` and produce a
        # bogus `1 -> Number 1.0` rule.
        (r"Number ([0-9]+?)\.0 ", "Number %s.0", False),
        (r"Number ([0-9]+\.[0-9]+?) ", "Number %s", True),
    )

    rules = set()
    for pattern, target_template, report_missing in passes:
        for match in re.findall(pattern, target):
            if match not in source:
                if report_missing:
                    print("`%s` is not in `%s`." % (match, source))
                continue
            target_rhs = target_template % match
            source_rhs = match
            rules.add(
                qcfg_rule.QCFGRule(
                    tuple(source_rhs.split(" ")),
                    tuple(target_rhs.split(" ")),
                    0))
    return rules
def test_joint_sampler_with_score_fn(self):
    # Same grammar as the plain sampler test, but sampling is driven by a
    # score function backed by an InferenceWrapper's application scores.
    cfg_rule_strings = (
        "ROOT => ##A ##B",
        "A => a",
        "A => and ( ##A , ##A )",
        "B => b",
        "B => and ( ##B , ##B )",
    )
    target_cfg_rules = [
        target_grammar.TargetCfgRule.from_string(rule_string)
        for rule_string in cfg_rule_strings
    ]

    concat_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 NT_2".split()), tuple("NT_1 NT_2".split()), 2)
    and_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 and NT_2".split()),
        tuple("and ( NT_1 , NT_2 )".split()),
        2)
    a_rule = qcfg_rule.QCFGRule(("A", ), ("a", ), 0)
    b_rule = qcfg_rule.QCFGRule(("B", ), ("b", ), 0)
    qcfg_rules = [concat_rule, and_rule, a_rule, b_rule]

    config = test_utils.get_test_config()
    wrapper = inference_wrapper.InferenceWrapper(
        qcfg_rules, config, target_cfg_rules)
    wrapper.compute_application_scores(temperature=1, nonterminal_bias=1)

    def score_fn(parent_rule, nt_idx, child_rule):
        # Exponentiate the wrapper's scores and read the entry indexed by
        # the child's LHS index and the (parent, NT position) RHS index.
        exp_scores = np.exp(wrapper.application_scores)
        column = wrapper.rhs_emb_idx_map[(parent_rule, nt_idx)]
        row = wrapper.lhs_emb_idx_map[child_rule]
        return exp_scores[row, column]

    sampler = joint_sampler.JointSampler.from_rules(
        target_cfg_rules, qcfg_rules)
    sampled_source, sampled_target = sampler.sample(score_fn=score_fn)
    source_text = " ".join(sampled_source)
    target_text = " ".join(sampled_target)
    print("sampled: %s ### %s" % (source_text, target_text))

    # Given the target CFG above, every sample is expected to contain both
    # symbols on the source side and both on the target side.
    self.assertIn("A", sampled_source)
    self.assertIn("B", sampled_source)
    self.assertIn("a", sampled_target)
    self.assertIn("b", sampled_target)
def test_joint_sampler_save_and_load(self):
    # Save a sampler to a JSON file in the test temp dir, load it back, and
    # compare the string forms of the two samplers.
    cfg_rule_strings = (
        "ROOT => ##A ##B",
        "A => a",
        "A => and ( ##A , ##A )",
        "B => b",
        "B => and ( ##B , ##B )",
    )
    target_cfg_rules = [
        target_grammar.TargetCfgRule.from_string(rule_string)
        for rule_string in cfg_rule_strings
    ]

    concat_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 NT_2".split()), tuple("NT_1 NT_2".split()), 2)
    and_rule = qcfg_rule.QCFGRule(
        tuple("NT_1 and NT_2".split()),
        tuple("and ( NT_1 , NT_2 )".split()),
        2)
    a_rule = qcfg_rule.QCFGRule(("A", ), ("a", ), 0)
    b_rule = qcfg_rule.QCFGRule(("B", ), ("b", ), 0)
    induced_rules = [concat_rule, and_rule, a_rule, b_rule]

    original_sampler = joint_sampler.JointSampler.from_rules(
        target_cfg_rules, induced_rules)

    sampler_filepath = os.path.join(self.get_temp_dir(), "sampler.json")
    original_sampler.save(sampler_filepath)
    restored_sampler = joint_sampler.JointSampler.from_file(sampler_filepath)

    self.assertEqual(str(original_sampler), str(restored_sampler))
def get_datetime_exact_match(source, target):
    """Return seed rules for date and time arguments found in both strings.

    For each `:date ( ... )` and `:time ( ... )` span in `target`, the inner
    text is passed through `string_utils.format_source`. If the result
    occurs in `source` (compared case-insensitively), an arity-0 rule is
    added mapping the formatted text to the raw inner text.

    Args:
      source: Source string.
      target: Target string to scan.

    Returns:
      A set of `qcfg_rule.QCFGRule` objects.
    """
    found_rules = set()
    for arg_name in ("date", "time"):
        pattern = r":%s \( (.+?) \)" % arg_name
        for hit in re.finditer(pattern, target):
            inner_text = hit.group(1)
            formatted = string_utils.format_source(inner_text)
            if formatted.lower() in source.lower():
                source_tokens = tuple(formatted.split(" "))
                target_tokens = tuple(inner_text.split(" "))
                found_rules.add(
                    qcfg_rule.QCFGRule(source_tokens, target_tokens, 0))
    return found_rules
def get_exact_match_rules(dataset):
    """Return the set of arity-0 rules whose source and target are equal.

    Each (source, target) string pair in `dataset` is split on whitespace
    and passed to `_find_exact_matches`. Every distinct match found becomes
    one rule with that match as both its source and its target.

    Args:
      dataset: Iterable of (source_str, target_str) pairs.

    Returns:
      A set of `qcfg_rule.QCFGRule` objects.
    """
    shared_sequences = set()
    for source_str, target_str in dataset:
        source_tokens = source_str.split()
        target_tokens = target_str.split()
        shared_sequences.update(
            _find_exact_matches(source_tokens, target_tokens))

    return {
        qcfg_rule.QCFGRule(
            source=tuple(sequence), target=tuple(sequence), arity=0)
        for sequence in shared_sequences
    }
def get_string_rules(source, target):
    """Return string seed rules based on exact match.

    Looks in `target` for quoted spans of the form `<prefix> " ... "`, where
    the prefix is one of PersonName, String, LocationKeyphrase, Month or
    DayOfWeek. The quoted text is passed through
    `string_utils.format_source`. If the result occurs in `source` (compared
    case-insensitively), an arity-0 rule is added mapping it to the full
    `<prefix> " ... "` span. Otherwise a diagnostic line is printed.

    Args:
      source: Source string.
      target: Target string to scan.

    Returns:
      A set of `qcfg_rule.QCFGRule` objects.
    """
    prefixes = (
        "PersonName", "String", "LocationKeyphrase", "Month", "DayOfWeek")
    found_rules = set()
    for prefix in prefixes:
        pattern = r'%s " (.+?) "' % prefix
        for hit in re.finditer(pattern, target):
            quoted_text = hit.group(1)
            formatted = string_utils.format_source(quoted_text)
            if formatted.lower() in source.lower():
                target_rhs = '%s " %s "' % (prefix, quoted_text)
                found_rules.add(
                    qcfg_rule.QCFGRule(
                        tuple(formatted.split(" ")),
                        tuple(target_rhs.split(" ")),
                        0))
            else:
                print("`%s` is not in `%s`." % (formatted, source))
    return found_rules
def test_joint_rule_converter_2(self):
    # Convert a single induced rule whose target side swaps NT_1 and NT_2,
    # and check the CFG non-terminal tuples attached to the joint rule.
    cfg_rule_strings = (
        "ROOT => ##A ##B",
        "A => a",
        "A => and ( ##A , ##B )",
        "B => b",
        "B => and ( ##B , ##A )",
    )
    target_cfg_rules = [
        target_grammar.TargetCfgRule.from_string(rule_string)
        for rule_string in cfg_rule_strings
    ]

    source_side = tuple("NT_1 and NT_2".split())
    target_side = tuple("and ( NT_2 , NT_1 )".split())
    induced_rule = qcfg_rule.QCFGRule(source_side, target_side, 2)

    converter = joint_sampler.JointRuleConverter(target_cfg_rules)
    joint_rule = converter.convert(induced_rule)

    expected_cfg_nts = {
        ("A", "B", "A"),
        ("B", "A", "B"),
    }
    self.assertEqual(joint_rule.qcfg_rule, induced_rule)
    self.assertEqual(joint_rule.cfg_nts_set, expected_cfg_nts)
def _example_to_rule(source_str, target_str):
    """Build an arity-0 QCFGRule from a whole (source, target) example.

    Both strings are split on whitespace and the resulting tokens become the
    rule's source and target tuples.
    """
    source_tokens = tuple(source_str.split())
    target_tokens = tuple(target_str.split())
    return qcfg_rule.QCFGRule(source_tokens, target_tokens, arity=0)
def _make_rule(nts, source, target):
    """Canonicalize NT indexes and return the resulting QCFGRule.

    The rule's arity is `len(nts)`. `source` and `target` are passed through
    `rule_utils.canonicalize_nts` before being stored as tuples.
    """
    num_nts = len(nts)
    canonical = rule_utils.canonicalize_nts(source, target, num_nts)
    canonical_source, canonical_target = canonical
    return qcfg_rule.QCFGRule(
        tuple(canonical_source), tuple(canonical_target), num_nts)