Ejemplo n.º 1
0
class ClusterObserver(ClusterObservable):
    """Observable that projects the clusters exposed by a `TensorProvider` for display.

    `get_data` assembles a `ClusterObserverData` snapshot (cluster centers,
    projected datapoints, spring lines, spline arrows, sequences, canvas size).
    `get_properties` returns the UI controls that toggle and configure those
    layers.
    """
    _sequences_builder: SequencesBuilder
    _show_cluster_centers: bool
    _show_cluster_datapoints: bool
    _show_spring_lines: bool
    _show_spline_arrows: bool
    _projection_type: ClusterObserverProjection
    _prop_builder: ObserverPropertiesBuilder
    _n_cluster_centers: int
    _n_sequences: int
    _sequence_length: int

    # Default canvas size; changed through the 'Width'/'Height' properties.
    _width: int = 640
    _height: int = 480
    # Gates the temporal features: spring lines, spline arrows, the projection
    # selector and the force-simulation settings are only enabled when the
    # provider has a temporal pooler.
    _has_temporal_pooler: bool

    def __init__(self, tensor_provider: TensorProvider):
        """Create the per-layer data builders, all fed by `tensor_provider`.

        Starts in 2D PCA mode with cluster centers and datapoints shown.
        Spring lines and spline arrows start shown only if the provider has
        a temporal pooler.
        """
        self._has_temporal_pooler = tensor_provider.has_temporal_pooler()

        self._n_cluster_centers = tensor_provider.n_cluster_centers()
        self._n_sequences = tensor_provider.n_sequences()
        self._sequence_length = tensor_provider.sequence_length()

        self.cluster_centers = ClusterCentersDataBuilder(tensor_provider)
        self.fdsim = FDsimDataBuilder(tensor_provider)
        self.n_dims = 2
        self.pca = PcaDataBuilder(tensor_provider)
        self.spring_lines = SpringLinesBuilder(tensor_provider)
        self.spline_arrows = SplineArrowsBuilder(tensor_provider)
        self._prop_builder = ObserverPropertiesBuilder()
        self._sequences_builder = SequencesBuilder(tensor_provider)
        self._show_cluster_centers = True
        self._show_cluster_datapoints = True
        self._show_spring_lines = self._has_temporal_pooler
        self._show_spline_arrows = self._has_temporal_pooler
        self._projection_type = ClusterObserverProjection.PCA

    def get_data(self) -> ClusterObserverData:
        """Return a snapshot of everything the cluster observer renders.

        Layers that are switched off in the properties are passed as None.
        PCA data is computed only in PCA mode. Force-simulation data is
        always computed.
        """
        return ClusterObserverData(
            cluster_centers=self.cluster_centers.get_data()
            if self._show_cluster_centers else None,
            fdsim=self.fdsim.get_data(),
            n_dims=self.n_dims,
            n_cluster_centers=self._n_cluster_centers,
            n_sequences=self._n_sequences,
            sequence_length=self._sequence_length,
            pca=self.pca.get_data(self.n_dims, self._show_cluster_datapoints)
            if self._projection_type == ClusterObserverProjection.PCA else
            None,
            projection_type="PCA" if self._projection_type
            == ClusterObserverProjection.PCA else "FDsim",
            width=self._width,
            height=self._height,
            spring_lines=self.spring_lines.get_data()
            if self._show_spring_lines else None,
            sequences=self._sequences_builder.get_data(),
            spline_arrows=self.spline_arrows.get_data()
            if self._show_spline_arrows else None,
        )

    def get_properties(self) -> List[ObserverPropertiesItem]:
        """Return the UI property items for this observer.

        Each nested callback writes the user's new value into this
        observer's state. Builder-specific properties (cluster centers,
        spline arrows, force simulation) are appended after the observer's
        own controls.
        """
        def update_projection_dim(value):
            # Select index 0 -> 2D; any other index -> 3D.
            if int(value) == 0:
                self.n_dims = 2
            else:
                self.n_dims = 3
            return value

        def update_show_cluster_centers(value: bool) -> bool:
            self._show_cluster_centers = value
            return value

        def update_show_cluster_datapoints(value: bool) -> bool:
            self._show_cluster_datapoints = value
            return value

        def update_show_spring_lines(value: bool) -> bool:
            self._show_spring_lines = value
            return value

        def update_show_spline_arrows(value: bool) -> bool:
            self._show_spline_arrows = value
            return value

        def format_projection_type(value: ClusterObserverProjection) -> int:
            # Map a projection to its index in the 'Projection' select list.
            if value == ClusterObserverProjection.PCA:
                return 0
            elif value == ClusterObserverProjection.FD_SIM:
                return 1
            else:
                raise IllegalArgumentException(
                    f'Unrecognized projection {value}')

        def update_projection_type(value):
            # Inverse of format_projection_type: select index -> projection.
            old_type = self._projection_type
            if int(value) == 0:
                self._projection_type = ClusterObserverProjection.PCA
            elif int(value) == 1:
                self._projection_type = ClusterObserverProjection.FD_SIM
            else:
                raise IllegalArgumentException(
                    f'Unrecognized projection {value}')

            # Switching into PCA from another projection resets the PCA builder.
            if self._projection_type == ClusterObserverProjection.PCA and old_type != ClusterObserverProjection.PCA:
                self.pca.reset()

            return value

        def reset_projection(value):
            # Reset whichever projection is currently active. `value` comes
            # from the 'Reset Projection' button item, not from the projection
            # selector, so the error reports the projection type that failed
            # to match.
            if self._projection_type == ClusterObserverProjection.PCA:
                self.pca.reset()
            elif self._projection_type == ClusterObserverProjection.FD_SIM:
                self.fdsim.reset()
            else:
                raise IllegalArgumentException(
                    f'Unrecognized projection {self._projection_type}')

        def update_width(value):
            self._width = int(value)
            return value

        def update_height(value):
            self._height = int(value)
            return value

        def yield_props():
            # The projection can only be changed when a temporal pooler exists.
            yield ObserverPropertiesItem(
                'Projection',
                'select',
                format_projection_type(self._projection_type),
                update_projection_type,
                select_values=[
                    ObserverPropertiesItemSelectValueItem('PCA'),
                    ObserverPropertiesItemSelectValueItem('Force simulation')
                ],
                state=ObserverPropertiesItemState.ENABLED
                if self._has_temporal_pooler else
                ObserverPropertiesItemState.READ_ONLY)

            yield ObserverPropertiesItem(
                'Projection dimensionality',
                'select',
                0 if self.n_dims == 2 else 1,
                update_projection_dim,
                select_values=[
                    ObserverPropertiesItemSelectValueItem('2D'),
                    ObserverPropertiesItemSelectValueItem('3D')
                ])

            yield ObserverPropertiesItem('Reset Projection', 'button', "Reset",
                                         reset_projection)

            # Enablers
            yield self._prop_builder.checkbox('Show Cluster Centers',
                                              self._show_cluster_centers,
                                              update_show_cluster_centers)
            # Datapoints are only available from the PCA projection; outside
            # PCA mode the checkbox is shown unchecked and disabled.
            yield self._prop_builder.checkbox(
                'Show Cluster Datapoints',
                self._show_cluster_datapoints if self._projection_type
                == ClusterObserverProjection.PCA else False,
                update_show_cluster_datapoints,
                state=ObserverPropertiesItemState.ENABLED
                if self._projection_type == ClusterObserverProjection.PCA else
                ObserverPropertiesItemState.DISABLED)
            yield self._prop_builder.checkbox(
                'Show Spring Lines',
                self._show_spring_lines
                if self._has_temporal_pooler else False,
                update_show_spring_lines,
                state=ObserverPropertiesItemState.ENABLED
                if self._has_temporal_pooler else
                ObserverPropertiesItemState.DISABLED)
            yield self._prop_builder.checkbox(
                'Show Spline Arrows',
                self._show_spline_arrows
                if self._has_temporal_pooler else False,
                update_show_spline_arrows,
                state=ObserverPropertiesItemState.ENABLED
                if self._has_temporal_pooler else
                ObserverPropertiesItemState.DISABLED)

            # Cluster Centers
            yield self._prop_builder.collapsible_header(
                'Cluster Centers', default_is_expanded=True)
            yield from self.cluster_centers.get_properties(
                enabled=self._show_cluster_centers)

            # Spline Arrows
            yield self._prop_builder.collapsible_header(
                'Spline Arrows', default_is_expanded=True)
            yield from self.spline_arrows.get_properties(
                enabled=self._show_spline_arrows)

            # Canvas
            yield self._prop_builder.collapsible_header(
                'Canvas', default_is_expanded=True)
            yield ObserverPropertiesItem('Width', 'number', self._width,
                                         update_width)
            yield ObserverPropertiesItem('Height', 'number', self._height,
                                         update_height)

            # Force Simulation
            if self._has_temporal_pooler:
                # The header's callback ignores its argument and always returns "True".
                yield ObserverPropertiesItem('Force simulation',
                                             'collapsible_header', True,
                                             lambda _: "True")
                yield from self.fdsim.get_properties()

        return list(yield_props())
Ejemplo n.º 2
0
class ObserverView(PropertiesObservable):
    """A node that encompasses all the model's observables and passes them on to the observer system.

    The view registers itself as an observer. Its properties are 'All'/'None'
    buttons, plus one checkbox per observable that registers or unregisters
    that observable. The checkboxes are grouped under collapsible headers by
    the first dotted component of the (prefix-stripped) observable name.
    """
    # Prefix removed from observable names before they are used as labels.
    _strip_observer_name_prefix: str

    # Managed observables, keyed by their full (unstripped) observer name.
    _observables: Dict[str, Observable]
    # True until the view has registered itself with the observer system.
    _first_show: bool = True

    def __init__(self,
                 name: str,
                 observer_system: ObserverSystem,
                 strip_observer_name_prefix: str = ''):
        self._strip_observer_name_prefix = strip_observer_name_prefix
        self.name = name
        self._observer_system = observer_system
        self._observables = {}
        observer_system.signals.window_closed.connect(self.on_window_closed)
        self._prop_builder = ObserverPropertiesBuilder(self)

    def _persist(self):
        """Ask the observer system to persist this view's values under `self.name`."""
        self._observer_system.persist_observer_values(self.name, self)

    def on_window_closed(self, observer_name: str):
        """Handle the observer system's `window_closed` signal.

        If the closed window belongs to one of this view's observables, that
        observable is unregistered and the view's state is persisted. Windows
        of other observers are ignored.
        """
        if observer_name in self._observables:
            self._observer_system.unregister_observer(observer_name, False)
            self._persist()

    def close(self):
        """Unregister every managed observable and then the view itself."""
        self._unregister_observers()
        self._observer_system.unregister_observer(self.name, True)

    def set_observables(self, observables: Dict[str, Observable]):
        """Replace the managed observables.

        The previous observables are unregistered first. The new ones are not
        registered, so no observer is visible by default. The view registers
        itself with the observer system on the first call only.
        """
        self._unregister_observers()
        self._observables = observables
        if self._first_show:
            self._observer_system.register_observer(self.name, self)
            self._first_show = False

    def _register_observers(self):
        """Register every managed observable under its full name."""
        for name, observable in self._observables.items():
            self._observer_system.register_observer(name, observable)

    def _unregister_observers(self):
        """Unregister every managed observable under its full name."""
        for name in self._observables:
            self._observer_system.unregister_observer(name, True)

    def get_properties(self) -> List[ObserverPropertiesItem]:
        """Build the view's properties.

        Returns the 'All'/'None' buttons followed by, for each observable, an
        optional group header and a checkbox reflecting whether it is
        currently registered.
        """
        def enable_observers_handler(prop_name: str, value: bool):
            # Log `prop_name` (bound per checkbox via `partial` below), not the
            # loop variable `name`. A closure over `name` is late-bound and
            # would report the last observable the loop visited instead of
            # the one being toggled.
            if value:
                logger.debug("Register observer %s", prop_name)
                self._observer_system.register_observer(
                    prop_name, self._observables[prop_name])
            else:
                logger.debug("Unregister observer %s", prop_name)
                self._observer_system.unregister_observer(prop_name, True)

        def remove_prefix(text: str, prefix: str):
            if text.startswith(prefix):
                return text[len(prefix):]
            else:
                return text

        observers = []
        last_header = ''
        for name, observable in self._observables.items():
            observer_name = remove_prefix(name,
                                          self._strip_observer_name_prefix)
            # Group by the first dotted component; the label drops that part.
            header = observer_name.split('.')[0]
            observer_name = remove_prefix(observer_name, f'{header}.')
            # Emit a header only when the group changes, so observables are
            # grouped only if they are adjacent in iteration order.
            if last_header != header:
                last_header = header
                observers.append(
                    self._prop_builder.collapsible_header(header, False))

            observers.append(
                self._prop_builder.checkbox(
                    observer_name,
                    self._observer_system.is_observer_registered(name),
                    partial(enable_observers_handler, name)))

        def set_all():
            self._register_observers()
            self._persist()

        def set_none():
            self._unregister_observers()
            self._persist()

        return [
            self._prop_builder.button('All', set_all),
            self._prop_builder.button('None', set_none),
        ] + observers