def test_frequencies_to_kplus_reaches(self):
        # Each case pairs per-frequency counts (index i = reached exactly i+1
        # times) with the expected k+ reaches (index k = reached at least k+1
        # times), e.g. [2, 1] -> 2 + 1 reached 1+ times, 1 reached 2+ times.
        cases = (
            ([], []),
            ([1], [1]),
            ([2, 1], [3, 1]),
        )
        for frequencies, expected_kplus in cases:
            self.assertEqual(
                ReachPoint.frequencies_to_kplus_reaches(frequencies),
                expected_kplus,
            )
    def _form_venn_diagram_regions(
        self, spends: List[float], max_frequency: int = 1
    ) -> Dict[int, List]:
        """Form primitive Venn diagram regions that contain k+ reaches.

        For each non-empty subset in the powerset of publishers with nonzero
        spend, computes k+ reaches for the users who are reached by exactly
        the publishers in that subset.

        Args:
            spends:  The hypothetical spend vector, equal in length to the
              number of publishers.  spends[i] is the amount that is spent with
              publisher i. Note that the publishers with 0 spends, i.e. inactive
              publishers, will not be included in the Venn diagram regions.
            max_frequency:  The maximum frequency for which to report reach.
              Users reached more often are counted at max_frequency.
        Returns:
            regions:  A dictionary in which each key is the binary
              representation of each primitive region of the Venn diagram, and
              each value is a list of the k+ reaches in the corresponding
              region.
              Note that the binary representation of a key represents the
              formation of publisher IDs in that primitive region. For example,
              regions[key] with key = 5 = bin('101') is the region which
              belongs to pub_id-0 and id-2.
              The k+ reaches for a given region is given as a list r[] where
              r[k] is the number of people who were reached AT LEAST k+1 times.
              Regions that contain no users map to a zero vector.
        Raises:
            ValueError: If more than MAX_ACTIVE_PUBLISHERS publishers have
              nonzero spend.
        """
        # Determine the active publishers (nonzero spend) and enforce the
        # limit BEFORE doing any per-publisher work: the number of regions
        # grows as 2^len(active_pubs), and querying user counts for an input
        # that is going to be rejected anyway is wasted effort.
        active_pubs = [pub_id for pub_id, spend in enumerate(spends) if spend]
        if len(active_pubs) > MAX_ACTIVE_PUBLISHERS:
            raise ValueError(
                f"There are {len(active_pubs)} publishers for the Venn "
                f"diagram algorithm. The maximum limit is {MAX_ACTIVE_PUBLISHERS}."
            )

        # Get user counts by spend for each active publisher.
        user_counts_by_pub_id = {
            pub_id: self._publishers[pub_id]._publisher_data.user_counts_by_spend(
                spends[pub_id]
            )
            for pub_id in active_pubs
        }

        # Generate the representations of all primitive regions from the
        # powerset of the active publishers, excluding the empty set.
        # Ex: if active_pubs = [0, 2] among all publishers [0, 1, 2], then
        # the non-empty subsets are [0], [2], [0, 2], and the corresponding
        # regions are [2^0, 2^2, 2^0 + 2^2].
        active_pub_powerset = chain.from_iterable(
            combinations(active_pubs, r) for r in range(1, len(active_pubs) + 1)
        )
        regions_with_active_pubs = [
            sum(1 << pub_id for pub_id in pub_ids) for pub_ids in active_pub_powerset
        ]

        # Locate each user's region, represented by the binary representation
        # of a number, and sum the user's impressions across publishers.
        user_region = defaultdict(int)
        user_impressions = defaultdict(int)

        for pub_id, user_counts in user_counts_by_pub_id.items():
            for user_id, impressions in user_counts.items():
                # To update the user's located region, we use bit operation here.
                # Ex: For a user reached by publisher id-0 and id-2, it's located
                # at the region with the binary representation = bin('101') = 5.
                # If the user is also reached by publisher id-1, then the updated
                # representation will be bin('111') = 7.
                user_region[user_id] |= 1 << pub_id
                user_impressions[user_id] += impressions

        # Compute the frequencies in the occupied primitive Venn diagram regions
        # with the user counts capped by max_frequency. Every user's region is
        # a non-empty OR of active-publisher bits, so it is always one of the
        # keys created here.
        frequencies_by_region = {
            r: [0] * (max_frequency + 1) for r in regions_with_active_pubs
        }
        for user_id, region in user_region.items():
            impressions = min(max_frequency, user_impressions[user_id])
            frequencies_by_region[region][impressions] += 1

        # Compute k+ reaches in the primitive regions. Ignore 0-frequency.
        # Unoccupied regions were never incremented, so their slice is already
        # the zero vector and needs no conversion.
        occupied_regions = set(user_region.values())
        regions = {
            r: ReachPoint.frequencies_to_kplus_reaches(freq[1:])
            if r in occupied_regions
            else freq[1:]  # i.e. zero vector
            for r, freq in frequencies_by_region.items()
        }

        return regions