def test_multi_aspects_per_edu(self):
     """Check the rules extracted after ``_multi_aspect_per_edu_tree`` setup.

     The extractor is built with ``only_hierarchical_relations=False`` and
     run over EDU ids 559-563 for ``doc_id=1``; the result must equal a
     dict keyed by that document id holding the ten relations listed below.
     """
     # NOTE(review): presumably this prepares self.discourse_tree - confirm.
     self._multi_aspect_per_edu_tree()
     extractor = EDUTreeRulesExtractor(only_hierarchical_relations=False)
     edu_ids = [559, 560, 561, 562, 563]
     rules = extractor.extract(self.discourse_tree, edu_ids, doc_id=1)

     # One row per expected relation: (edu1, edu2, relation_type, gerani).
     # Row order matters because the relations are compared as a list.
     expected_rows = [
         (560, 559, 'Elaboration', 0.63),
         (561, 559, 'Elaboration', 0.53),
         (562, 559, 'same-unit', 0.3),
         (563, 559, 'same-unit', 0.2),
         (561, 560, 'Elaboration', 0.75),
         (562, 560, 'same-unit', 0.4),
         (563, 560, 'same-unit', 0.3),
         (562, 561, 'same-unit', 0.5),
         (563, 561, 'same-unit', 0.4),
         (563, 562, 'Elaboration', 0.75),
     ]
     expected_rules = {
         1: [
             EDURelation(edu1=edu1,
                         edu2=edu2,
                         relation_type=relation_type,
                         gerani=gerani)
             for edu1, edu2, relation_type, gerani in expected_rows
         ]
     }
     self.assertEqual(rules, expected_rules)
 def test_tree_parsing_and_get_rules_all(self):
     """Check extraction with ``only_hierarchical_relations=False``.

     The extractor uses ``weight_type=['gerani']`` and is run over EDU ids
     513-517 for document id 1; the result must equal a dict keyed by that
     document id holding the ten relations listed below.
     """
     extractor = EDUTreeRulesExtractor(weight_type=['gerani'],
                                       only_hierarchical_relations=False)
     edu_ids = [513, 514, 515, 516, 517]
     rules = extractor.extract(self.discourse_tree, edu_ids, 1)

     # One row per expected relation: (edu1, edu2, relation_type, gerani).
     # Row order matters because the relations are compared as a list.
     expected_rows = [
         (514, 513, 'Elaboration', 0.8),
         (515, 513, 'same-unit', 0.42),
         (516, 513, 'same-unit', 0.33),
         (517, 513, 'same-unit', 0.25),
         (515, 514, 'same-unit', 0.5),
         (516, 514, 'same-unit', 0.42),
         (517, 514, 'same-unit', 0.33),
         (516, 515, 'Elaboration', 0.6),
         (517, 515, 'Elaboration', 0.52),
         (517, 516, 'Joint', 0.7),
     ]
     expected_rules = {
         1: [
             EDURelation(edu1=edu1,
                         edu2=edu2,
                         relation_type=relation_type,
                         gerani=gerani)
             for edu1, edu2, relation_type, gerani in expected_rows
         ]
     }
     self.assertEqual(rules, expected_rules)
 def test_bfs_for_several_aspects_in_one_edu(self):
     """Check extraction with ``only_hierarchical_relations=True``.

     The extractor uses ``weight_type=['gerani']`` and is run over EDU ids
     513-517 for document id 1; only the three relations listed below are
     expected in the result.
     """
     extractor = EDUTreeRulesExtractor(weight_type=['gerani'],
                                       only_hierarchical_relations=True)
     edu_ids = [513, 514, 515, 516, 517]
     rules = extractor.extract(self.discourse_tree, edu_ids, 1)

     # One row per expected relation: (edu1, edu2, relation_type, gerani).
     expected_rows = [
         (514, 513, 'Elaboration', 0.8),
         (516, 515, 'Elaboration', 0.6),
         (517, 515, 'Elaboration', 0.52),
     ]
     expected_rules = {
         1: [
             EDURelation(edu1=edu1,
                         edu2=edu2,
                         relation_type=relation_type,
                         gerani=gerani)
             for edu1, edu2, relation_type, gerani in expected_rows
         ]
     }
     self.assertEqual(rules, expected_rules)
def _with_rules():
    """Return a fixed list of six sample ``EDURelation`` rules.

    The sample repeats the same (edu pair, relation type) combination with
    different weights, alongside one rule for a different edu pair.
    """
    # (first edu, second edu, weight, relation type) for each sample rule.
    specs = (
        (0, 1, 10, "Elaboration"),
        (0, 1, 20, "Elaboration"),
        (0, 1, 11, "Contrast"),
        (0, 1, 22, "Contrast"),
        (2, 1, 10, "Elaboration"),
        (0, 1, 30, "Elaboration"),
    )
    return [
        EDURelation(first, second, weight=weight, relation_type=kind)
        for first, second, weight, kind in specs
    ]
def filter_top_n_rules(rules: List[EDURelation],
                       aggregation_fn: Callable = None,
                       top_n: int = 1) -> List[EDURelation]:
    """
    Aggregate repeated rules and keep the ``top_n`` heaviest of them.

    Rules that share their first three fields (``relation[:3]``) are
    collapsed into a single rule whose weight is ``aggregation_fn`` applied
    to the list of weights of that group.

    Parameters
    ----------
    rules : list
        Relation tuples; each must support slicing and expose ``weight``.
    aggregation_fn : callable, optional
        Receives the list of weights of one group and returns the weight
        of the aggregated rule. Defaults to ``max``.
    top_n : int or None
        Number of aggregated rules to return, heaviest first. ``None``
        returns every aggregated rule, ordered by the grouping key.

    Returns
    -------
    list
        Aggregated rules, at most ``top_n`` of them when ``top_n`` is set.
    """
    if aggregation_fn is None:
        aggregation_fn = max

    def _group_key(relation):
        # Everything but the trailing weight identifies a rule.
        return relation[:3]

    # groupby only merges adjacent items, so sort by the same key first.
    rules_sorted = sorted(rules, key=_group_key)

    rules_filtered = [
        EDURelation(*(group + (aggregation_fn([r.weight
                                               for r in relations]), )))
        for group, relations in groupby(rules_sorted, key=_group_key)
    ]

    if top_n is None:
        return rules_filtered

    # Rank the aggregated rules, not the raw input: ranking the raw input
    # would ignore aggregation_fn and could return the same group twice.
    return sorted(rules_filtered, key=attrgetter("weight"),
                  reverse=True)[:top_n]
def filter_rules_gerani(rules: List[EDURelation],
                        aggregation_fn: Callable = None) -> List[EDURelation]:
    """
    Collapse repeated rules into one rule per group with an aggregated weight.

    The input may contain many rules, including repeated ones. Rules that
    share their first three fields (``relation[:3]``) form one group; each
    group yields a single rule whose weight is ``aggregation_fn`` applied
    to the list of weights of that group.

    Parameters
    ----------
    rules : list
        Relation tuples; each must support slicing and expose ``weight``.
    aggregation_fn : callable, optional
        Receives the list of weights of one group and returns the weight
        of the aggregated rule. Defaults to ``max``.

    Returns
    -------
    rules_filtered : list
        One rule per distinct group, ordered by the grouping key.

        Examples
        [
            EDURelation(edu1=u'screen', edu2=u'phone',
                relation_type='Elaboration', weight=1.38),
            EDURelation(edu1=u'speaker', edu2=u'sound',
                relation_type='Elaboration', weight=0.29)
        ]
    """
    if aggregation_fn is None:
        aggregation_fn = max

    def _group_key(relation):
        # Everything but the trailing weight identifies a rule.
        return relation[:3]

    # groupby only merges adjacent items, so sort by the same key first.
    rules_sorted = sorted(rules, key=_group_key)
    return [
        EDURelation(*(group + (aggregation_fn([r.weight
                                               for r in relations]), )))
        for group, relations in groupby(rules_sorted, key=_group_key)
    ]


@pytest.mark.parametrize(
    "rules, aggregation_fn",
    [
        (
            [
                EDURelation(0, 1, weight=30, relation_type="Elaboration"),
                EDURelation(0, 1, weight=22, relation_type="Contrast"),
                EDURelation(2, 1, weight=10, relation_type="Elaboration"),
            ],
            partial(max),
        ),
        (
            [
                EDURelation(0, 1, weight=10, relation_type="Elaboration"),
                EDURelation(0, 1, weight=11, relation_type="Contrast"),
                EDURelation(2, 1, weight=10, relation_type="Elaboration"),
            ],
            partial(min),
        ),
    ],
)