Esempio n. 1
0
    def subset(
        self, plates_index: Dict[Plate, List[int]]
    ) -> "EPMeanFieldSubset":
        """Build the EPMeanFieldSubset restricted to ``plates_index``.

        For every ``(factor, mean_field)`` pair in ``self.factor_mean_field``
        the factor and its mean field are each subset with ``plates_index``,
        and a per-variable rescale value is recorded.

        Parameters
        ----------
        plates_index
            Mapping from each Plate to the indexes to keep along it.

        Returns
        -------
        EPMeanFieldSubset
            Constructed from a FactorGraph over the subset factors, the
            subset mean fields and rescale values (both keyed by subset
            factor), the original-factor -> subset-factor mapping, ``self``
            and ``plates_index``.
        """
        original_to_subset = {}
        subset_mean_fields = {}
        subset_rescales = {}

        for factor, full_mean_field in self.factor_mean_field.items():
            full_plate_sizes = VariableData.plate_sizes(full_mean_field)

            reduced_factor = factor.subset(plates_index, full_plate_sizes)
            original_to_subset[factor] = reduced_factor

            reduced_mean_field = full_mean_field.subset(
                plates_index=plates_index)
            subset_mean_fields[reduced_factor] = reduced_mean_field

            # Ratio of the plate-size products: subset over full mean field.
            full_size = VariableData.prod(full_plate_sizes)
            reduced_plate_sizes = VariableData.plate_sizes(reduced_mean_field)
            reduced_size = VariableData.prod(reduced_plate_sizes)
            ratio = reduced_size / full_size

            # Per variable: ratio * (full message size) / (subset message size).
            rescale = {}
            for variable, reduced_message in reduced_mean_field.items():
                rescale[variable] = (
                    ratio * full_mean_field[variable].size
                    / reduced_message.size
                )
            subset_rescales[reduced_factor] = rescale

        graph = FactorGraph(original_to_subset.values())
        return EPMeanFieldSubset(
            graph,
            subset_mean_fields,
            subset_rescales,
            original_to_subset,
            self,
            plates_index,
        )
Esempio n. 2
0
    def update_mean_field(self, mean_field, plates_index=None):
        """Update this mapping in place from ``mean_field`` and return ``self``.

        When ``plates_index`` is falsy (``None`` or empty) the entries of
        ``mean_field`` are applied with ``self.update``. Otherwise each new
        message is assigned into the existing entry ``self[v]`` at the index
        produced by ``v.make_indexes(plates_index, plate_sizes)``.
        """
        if not plates_index:
            # No index selection requested: plain mapping update.
            self.update(mean_field)
            return self

        sizes = VariableData.plate_sizes(self)
        for variable, message in mean_field.items():
            selection = variable.make_indexes(plates_index, sizes)
            # Item assignment on the existing entry, not a replacement of it.
            self[variable][selection] = message

        return self
Esempio n. 3
0
    def merge(self, index, mean_field):
        """Return a new MeanField combining ``self`` with ``mean_field``.

        Works on ``dict(self)``, a shallow copy of this mapping's entries.
        When ``index`` is falsy the copy is updated directly from
        ``mean_field``. Otherwise, for each variable ``v`` in ``mean_field``,
        the copied entry is replaced by the result of calling its own
        ``merge`` with ``v.make_indexes(index, plate_sizes)`` and the new
        message.
        """
        merged = dict(self)

        if not index:
            # No index selection: incoming entries overwrite whole entries.
            merged.update(mean_field)
            return MeanField(merged)

        sizes = VariableData.plate_sizes(self)
        for variable, message in mean_field.items():
            selection = variable.make_indexes(index, sizes)
            current = merged[variable]
            merged[variable] = current.merge(selection, message)

        return MeanField(merged)
Esempio n. 4
0
    def subset(self, variables=None, plates_index=None):
        """Return a mean field restricted to ``variables`` and/or ``plates_index``.

        The result is built with ``type(self)`` when ``self`` is a MeanField
        instance and with ``MeanField`` otherwise. A falsy ``variables``
        falls back to ``self.variables``.

        With a truthy ``plates_index`` every requested variable ``v`` is
        looked up in ``self`` and indexed with
        ``v.make_indexes(plates_index, plate_sizes)``. Without it, whole
        entries are taken, and only for variables present in ``self.keys()``.
        """
        if isinstance(self, MeanField):
            result_cls = type(self)
        else:
            result_cls = MeanField

        if not variables:
            variables = self.variables

        if not plates_index:
            # Whole entries; variables missing from self are skipped.
            return result_cls(
                (v, self[v]) for v in variables if v in self.keys()
            )

        sizes = VariableData.plate_sizes(self)
        selected = {}
        for variable in variables:
            selection = variable.make_indexes(plates_index, sizes)
            # NOTE(review): unlike the branch above, no membership check here.
            selected[variable] = self[variable][selection]

        return result_cls(selected)