Esempio n. 1
0
    def tp_sequence_significance(self, source: SequenceSignificanceSource):
        """Return this expert's slice of the sequence-significance tensor selected by ``source``.

        Args:
            source: Which significance measure to read. ``MODEL_FREQUENCY`` is delegated to
                ``self.tp_frequent_seq_occurrences()``; every other handled member reads a tensor
                attribute of the TP flock's ``trained_forward_process``.

        Returns:
            ``None`` when the flock node has no unit, when there is no trained forward process, or
            when the observed tensor is rejected by ``is_valid_tensor``. Otherwise the tensor indexed
            by ``self.expert_id`` (or the result of ``tp_frequent_seq_occurrences`` for
            ``MODEL_FREQUENCY``).

        Raises:
            IllegalArgumentException: If ``source`` is not one of the handled members.
        """
        unit = self.flock_node._unit
        if unit is None:
            return None
        process = unit.flock.tp_flock.trained_forward_process
        if process is None:
            return None

        if source == SequenceSignificanceSource.MODEL_FREQUENCY:
            return self.tp_frequent_seq_occurrences()

        # Ordered (source, attribute-on-process) pairs, matched with ``==`` like an if/elif chain.
        # NOTE(review): GOAL_DIRECTED maps to ``seq_rewards_goal_directed`` (rewards rather than a
        # ``seq_likelihoods_*`` attribute) — confirm this is intended.
        attribute_table = (
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_ACTIVE, 'seq_likelihoods_active'),
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_CLUSTERS, 'seq_likelihoods_clusters'),
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_EXPLORATION, 'seq_likelihoods_exploration'),
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_GOAL_DIRECTED, 'seq_rewards_goal_directed'),
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_PRIORS_CLUSTERS,
             'seq_likelihoods_priors_clusters'),
            (SequenceSignificanceSource.SEQ_LIKELIHOODS_PRIORS_CLUSTERS_CONTEXT,
             'seq_likelihoods_priors_clusters_context'),
        )
        for candidate, attribute_name in attribute_table:
            if source == candidate:
                break
        else:
            raise IllegalArgumentException(
                f'Unrecognized sequence significance source {source}')

        def tensor_getter(_):
            # Read lazily so the observable always sees the process's current tensor.
            return getattr(process, attribute_name)

        observable = FlockProcessObservable(self.flock_node.params.flock_size,
                                            lambda: process,
                                            tensor_getter)
        tensor = observable.get_tensor()
        if not is_valid_tensor(tensor):
            return None
        return tensor[self.expert_id]
Esempio n. 2
0
    def __init__(
            self,
            name: str,
            type_name: str,
            value: Any,
            callback: Callable[[str], Optional[str]],
            state: ObserverPropertiesItemState = ObserverPropertiesItemState.ENABLED,
            optional: bool = False,
            select_values: List[ObserverPropertiesItemSelectValueItem] = (),
            source_type: ObserverPropertiesItemSourceType = ObserverPropertiesItemSourceType.OBSERVER,
            hint: str = ''):
        """Create an observer property item.

        Args:
            name: Name of the property.
            type_name: Item type name; must be one of ``self.TYPE_MAPPING.values()``. It is converted
                to ``self.type`` through ``ObserverPropertiesItem._reverse_type_dict``.
            value: Current value of the property.
            callback: Callable taking a ``str`` and returning ``Optional[str]`` (per the type hint).
            state: Item state, ``ObserverPropertiesItemState.ENABLED`` by default.
            optional: Stored as ``self.optional``.
            select_values: Stored as ``self.select_values``; defaults to an empty tuple (immutable, so
                it is safe as a default).
            source_type: Stored as ``self.source_type``; defaults to ``OBSERVER``.
            hint: Stored as ``self.hint``.

        Raises:
            IllegalArgumentException: If ``type_name`` is not a known type name. Further validation is
                done by ``self._check_arguments()`` at the end.
        """
        if type_name not in self.TYPE_MAPPING.values():
            # Report the rejected argument: ``self.type`` is not assigned yet at this point, so
            # formatting it here raised AttributeError instead of this exception.
            raise IllegalArgumentException(f"Unrecognized type '{type_name}'")

        self.type = ObserverPropertiesItem._reverse_type_dict[type_name]
        self.name = name
        self.value = value
        self.callback = callback
        self.state = state
        self.select_values = select_values
        self.optional = optional
        self.source_type = source_type
        self.hint = hint
        self._check_arguments()
Esempio n. 3
0
 def parse(cls, packet: Dict[str, Any]) -> EventData:
     """Build the event dataclass registered for ``packet['event_type']``.

     Args:
         packet: Mapping that must contain an ``'event_type'`` key. The whole mapping is handed to
             ``dacite.from_dict`` to populate the selected dataclass.

     Returns:
         An instance of ``cls.EVENT_TYPES[packet['event_type']]``.

     Raises:
         KeyError: If ``packet`` has no ``'event_type'`` key.
         IllegalArgumentException: If the event type is not registered in ``cls.EVENT_TYPES``.
     """
     event_type = packet['event_type']
     if event_type not in cls.EVENT_TYPES:
         # The quoted event type was previously missing its closing quote.
         raise IllegalArgumentException(
             f'Unrecognized event type: "{event_type}"')
     data_class = cls.EVENT_TYPES[event_type]
     # noinspection PyTypeChecker
     return dacite.from_dict(data_class=data_class, data=packet)
Esempio n. 4
0
 def format_projection_type(value: ClusterObserverProjection) -> int:
     """Encode a projection as its integer index: ``PCA`` -> 0, ``FD_SIM`` -> 1.

     Raises:
         IllegalArgumentException: For any other projection value.
     """
     known_projections = (ClusterObserverProjection.PCA, ClusterObserverProjection.FD_SIM)
     for index, projection in enumerate(known_projections):
         if value == projection:
             return index
     raise IllegalArgumentException(
         f'Unrecognized projection {value}')
Esempio n. 5
0
 def reset_projection(value):
     """Reset the projection currently selected by ``self._projection_type``.

     ``self`` comes from the enclosing scope. ``value`` is accepted but does not influence which
     projection is reset (presumably it exists to satisfy a callback signature — confirm).

     Raises:
         IllegalArgumentException: If ``self._projection_type`` is neither PCA nor FD_SIM.
     """
     projection_type = self._projection_type
     if projection_type == ClusterObserverProjection.PCA:
         self.pca.reset()
     elif projection_type == ClusterObserverProjection.FD_SIM:
         self.fdsim.reset()
     else:
         # Report the projection type that was dispatched on, not the unrelated callback argument.
         raise IllegalArgumentException(
             f'Unrecognized projection {projection_type}')
Esempio n. 6
0
def signal(*args):
    """Construct ``Signal<N>(*args)`` where ``N`` is the number of positional arguments.

    Only arities 0 to 6 have a class (``Signal0`` .. ``Signal6``).

    Raises:
        IllegalArgumentException: If more than 6 arguments are passed.
    """
    signal_classes = (Signal0, Signal1, Signal2, Signal3, Signal4, Signal5, Signal6)
    arity = len(args)
    if arity < len(signal_classes):
        return signal_classes[arity](*args)
    raise IllegalArgumentException(
        "Signals up to 6 parameters are supported. This can be extended by "
        "implementing Signal7, ... classes")
Esempio n. 7
0
    def inverse_projection(self,
                           data: torch.Tensor,
                           n_top_sequences: int = 1) -> torch.Tensor:
        """Calculates the inverse projection for the given output tensor.

        The output projection is computed for every frequent sequence. The best-matching sequences
        (scaled by ``frequent_seq_likelihoods_priors_clusters_context``) are selected, and their
        cluster occurrences are summed and normalized in SP output space.

        Args:
            data: Tensor matching the shape of projection_outputs (flock_size, n_cluster_centers).
            n_top_sequences: Number of best-matching sequences to aggregate. Values larger than the
                number of frequent sequences are clamped to that number.

        Returns:
            Tensor of shape [flock_size, n_cluster_centers]: per expert, the summed one-hot cluster
            occurrences of the selected sequences, normalized along dim 1 by ``normalize_probs_``.

        Raises:
            IllegalArgumentException: If ``data`` does not match the shape of ``projection_outputs``,
                or if ``n_top_sequences`` is smaller than 1.
        """
        if data.shape != self.projection_outputs.shape:
            raise IllegalArgumentException(
                f"The provided tensor {list(data.shape)} doesn't match "
                f"the shape of projection_outputs {list(self.projection_outputs.shape)}"
            )
        if n_top_sequences < 1:
            # Zero selected sequences would otherwise fail later in view(flock_size, -1).
            raise IllegalArgumentException(
                f'n_top_sequences must be at least 1, got {n_top_sequences}')

        # Compute output projections for each sequence from frequent_seqs
        # [flock_size, n_cluster_centers]
        projection_outputs = torch.empty_like(self.projection_outputs)
        tp_output_projection = TPOutputProjection(
            self.flock_size, self.n_frequent_seqs, self.n_cluster_centers,
            self.seq_length, self.seq_lookahead, self._device)
        tp_output_projection.compute_output_projection_per_sequence(
            self.frequent_seqs, projection_outputs)

        # Compute similarities with input data
        # [flock_size, n_frequent_seqs]
        similarities = tp_output_projection.compute_similarity(
            data, projection_outputs)

        # Scale similarities by seq likelihood
        similarities.mul_(
            self.frequent_seq_likelihoods_priors_clusters_context)

        # Slicing can return fewer columns than requested when there are fewer frequent sequences;
        # expand() below needs the actual count, so clamp it.
        n_top = min(n_top_sequences, similarities.shape[1])

        # Take just top n_top best matching sequences
        # [flock_size, n_top]
        sorted_idxs = similarities.sort(dim=1, descending=True)[1][:, :n_top]
        # [flock_size, n_top, seq_length]
        indices = sorted_idxs.unsqueeze(-1).expand(
            (self.flock_size, n_top, self.seq_length))
        # [flock_size, n_top, seq_length]
        matched_sequences = torch.gather(self.frequent_seqs, 1, indices)

        # Convert sequences to SP output space - one_hot representation
        # [flock_size, n_top * seq_length, n_cluster_centers]
        one_hots_per_flock = safe_id_to_one_hot(
            matched_sequences.view((self.flock_size, -1)),
            self.n_cluster_centers)

        # Final aggregation of sequences - just sum and normalize
        # [flock_size, n_cluster_centers]
        summed = one_hots_per_flock.sum(dim=1)
        normalize_probs_(summed, dim=1)
        return summed
Esempio n. 8
0
 def get(self, key: TKey):
     """Return the value for ``key``, computing and caching it on first access.

     When a value is already stored it is returned and the data provider is not called. Otherwise
     the provider registered for ``key`` is called, and its result is stored and returned.

     Raises:
         IllegalArgumentException: If nothing is stored for ``key`` and no data provider exists.
     """
     if key in self._stored_values:
         return self._stored_values[key]

     if key not in self._data_providers:
         raise IllegalArgumentException(
             f'Data provider for key {key} not found')

     provider = self._data_providers[key]
     computed = provider()
     self._stored_values[key] = computed
     return computed
Esempio n. 9
0
 def _resolve_state(state: Optional[ObserverPropertiesItemState],
                    enabled: Optional[bool]) -> ObserverPropertiesItemState:
     """Merge the alternative ``state`` / ``enabled`` arguments into a single state.

     Returns ``state`` when it is given. Otherwise returns ENABLED or DISABLED according to the
     truthiness of ``enabled``. Returns ENABLED when both are ``None``.

     Raises:
         IllegalArgumentException: If both ``state`` and ``enabled`` are not ``None``.
     """
     if state is not None:
         if enabled is not None:
             raise IllegalArgumentException(
                 'Both arguments state and enabled cannot be set simultaneously. Set either '
                 'one to None.')
         return state
     if enabled is None:
         return ObserverPropertiesItemState.ENABLED
     if enabled:
         return ObserverPropertiesItemState.ENABLED
     return ObserverPropertiesItemState.DISABLED
Esempio n. 10
0
def discover_child_classes(module_name: str,
                           base_class,
                           skip_classes: typing.Optional[typing.List] = None) -> typing.List:
    """Discover classes deriving from ``base_class`` that are located in ``module_name``.

    Every ``*.py`` file (except ``__init__.py``) in the directory containing the module's file is
    imported as ``<module_name>.<file stem>`` so its classes are defined. Files are imported in
    sorted order to make this deterministic. Concrete subclasses are then collected with
    ``get_subclasses_recursive`` and ``remove_abstract_classes``.

    Args:
        module_name: Dotted name of the package to scan.
        base_class: Class whose (recursive) subclasses are returned.
        skip_classes: Optional list of classes removed from the result via
            ``remove_skipped_classes``.

    Returns:
        The filtered list of subclasses.

    Raises:
        IllegalArgumentException: If ``module_name`` cannot be imported (the ImportError is chained
            as the cause) or its file location is unknown.
    """
    try:
        module = import_module(module_name)
    except ImportError as err:
        # import_module signals failure by raising, never by returning None.
        raise IllegalArgumentException(f'Cannot import module "{module_name}"') from err

    # getattr also covers modules that have no __file__ attribute at all (e.g. built-ins).
    module_file = getattr(module, '__file__', None)
    if module_file is None:
        raise IllegalArgumentException(f'Cannot locate "{module_name}"')

    submodule_paths = sorted(
        path for path in glob.glob(dirname(module_file) + "/*.py")
        if isfile(path) and not path.endswith('__init__.py'))
    for path in submodule_paths:
        # Import only for the side effect of defining the submodule's classes.
        import_module(f"{module_name}.{basename(path)[:-3]}")

    filtered = remove_abstract_classes(get_subclasses_recursive(base_class))
    if skip_classes is not None:
        filtered = remove_skipped_classes(filtered, skip_classes)
    return filtered
Esempio n. 11
0
        def update_projection_type(value):
            """Select the projection by integer index and return ``value`` unchanged.

            Index 0 selects ``ClusterObserverProjection.PCA`` and 1 selects ``FD_SIM``.
            ``self.pca`` is reset only when switching to PCA from a different projection.

            Raises:
                IllegalArgumentException: If ``int(value)`` is neither 0 nor 1; the current
                    projection type is left unchanged in that case.
            """
            previous_type = self._projection_type
            projection_by_index = {
                0: ClusterObserverProjection.PCA,
                1: ClusterObserverProjection.FD_SIM,
            }
            selected = projection_by_index.get(int(value))
            if selected is None:
                raise IllegalArgumentException(
                    f'Unrecognized projection {value}')
            self._projection_type = selected

            switched_to_pca = (self._projection_type == ClusterObserverProjection.PCA
                               and previous_type != ClusterObserverProjection.PCA)
            if switched_to_pca:
                self.pca.reset()

            return value
Esempio n. 12
0
def check_type(expected_type: type, instance):
    """Raise ``IllegalArgumentException`` unless ``instance`` is an instance of ``expected_type``.

    ``expected_type`` is passed directly to ``isinstance``, so a tuple of types also works.
    Returns ``None`` when the check passes.
    """
    if isinstance(instance, expected_type):
        return
    received_type = type(instance)
    raise IllegalArgumentException(
        f'Expected instance of {expected_type} but {received_type} received.')
Esempio n. 13
0
    def auto(self,
             name: str,
             prop: property,
             state: Optional[ObserverPropertiesItemState] = None,
             enabled: Optional[bool] = None,
             edit_strategy: Optional[EditStrategy] = None,
             hint: str = ''):
        """Create an observer property item for ``prop``, deriving parsing/formatting from its type hint.

        The return annotation of ``prop.fget`` selects the item kind:
        ``int``/``float``/``Optional[int]`` -> NUMBER; ``str``, ``List[int]``, ``List[float]``,
        ``Optional[List[int]]`` and ``Tuple`` of 1 to 6 ``int``s or ``float``s -> TEXT;
        ``bool`` -> CHECKBOX; ``Enum`` subclasses -> SELECT with the member names as options.

        Args:
            name: Item name passed to ``self.prop``.
            prop: Property whose getter must carry a return type hint.
            state: Explicit item state; combined with ``enabled`` by ``self._resolve_state``.
            enabled: Alternative way to give the state; see ``state``.
            edit_strategy: Passed with the resolved state to ``self._resolve_state_strategy``.
            hint: Hint text passed to ``self.prop``.

        Returns:
            Whatever ``self.prop`` returns for the chosen configuration.

        Raises:
            IllegalArgumentException: If the getter has no return type hint or its type is not
                supported (the helpers called first may raise it as well).
        """

        def prop_tuple(item_type: type, count: int):
            """Create a TEXT item holding a tuple of exactly ``count`` values of ``item_type``."""
            type_label = f"Tuple[{','.join([item_type.__name__] * count)}]"

            def parse_tuple(v: str) -> tuple:
                try:
                    parsed_list = self._parse_list(v, item_type)
                    if len(parsed_list) != count:
                        raise ValueError(
                            f'Expected exactly {count} items, but {len(parsed_list)} received'
                        )
                except ValueError as e:
                    # Chain the cause so the underlying parse error stays in the traceback.
                    raise FailedValidationException(
                        f"Expected {type_label}, syntax error: {e}") from e
                return tuple(parsed_list)

            # ``state`` is read when this runs, i.e. after it was resolved below.
            return self.prop(name,
                             prop,
                             parse_tuple,
                             self._format_list,
                             ObserverPropertiesItemType.TEXT,
                             state,
                             hint=hint)

        # Supported fixed-size tuple hints: Tuple[int] .. Tuple[int x 6] and the same for float.
        tuple_factories = [
            (Tuple[(item_type,) * count], partial(prop_tuple, item_type, count))
            for item_type in (int, float)
            for count in range(1, 7)
        ]

        self._check_instance_is_set()
        # NOTE(review): the reason returned by _resolve_state_strategy is not used here; confirm
        # that self.prop does not need it.
        state, _state_reason = self._resolve_state_strategy(
            self._resolve_state(state, enabled), edit_strategy)

        hints = get_type_hints(prop.fget)
        if 'return' not in hints:
            raise IllegalArgumentException(
                'Property getter must be annotated by a return value type hint'
            )
        prop_type = hints['return']

        # Plain classes are compared by identity. typing generics (Optional/List/Tuple) are
        # compared with ==: typing caches subscriptions in a bounded LRU cache, so two
        # occurrences of e.g. List[int] are not guaranteed to be the same object.
        if prop_type is int:
            return self.prop(name,
                             prop,
                             int,
                             None,
                             ObserverPropertiesItemType.NUMBER,
                             state,
                             hint=hint)
        if prop_type is str:
            return self.prop(name,
                             prop,
                             None,
                             None,
                             ObserverPropertiesItemType.TEXT,
                             state,
                             hint=hint)
        if prop_type == Optional[int]:
            return self.prop(name,
                             prop,
                             lambda v: None if v is None else int(v),
                             None,
                             ObserverPropertiesItemType.NUMBER,
                             state,
                             optional=True,
                             hint=hint)
        if prop_type is float:
            return self.prop(name,
                             prop,
                             float,
                             None,
                             ObserverPropertiesItemType.NUMBER,
                             state,
                             hint=hint)
        if prop_type is bool:
            return self.prop(name,
                             prop,
                             parse_bool,
                             None,
                             ObserverPropertiesItemType.CHECKBOX,
                             state,
                             hint=hint)
        if prop_type == List[int]:
            return self.prop(name,
                             prop,
                             self._parse_list_int,
                             self._format_list,
                             ObserverPropertiesItemType.TEXT,
                             state,
                             hint=hint)
        if prop_type == List[float]:
            return self.prop(name,
                             prop,
                             self._parse_list_float,
                             self._format_list,
                             ObserverPropertiesItemType.TEXT,
                             state,
                             hint=hint)
        if prop_type == Optional[List[int]]:
            return self.prop(name,
                             prop,
                             lambda v: None
                             if v is None else self._parse_list_int(v),
                             lambda v: None
                             if v is None else self._format_list(v),
                             ObserverPropertiesItemType.TEXT,
                             state,
                             optional=True,
                             hint=hint)
        if isinstance(prop_type, EnumMeta):
            select_items = list(prop_type)
            # The UI value is the member's index in definition order.
            return self.prop(name,
                             prop,
                             lambda v: select_items[int(v)],
                             lambda v: str(select_items.index(v)),
                             ObserverPropertiesItemType.SELECT,
                             state,
                             select_values=[
                                 ObserverPropertiesItemSelectValueItem(e.name)
                                 for e in select_items
                             ],
                             hint=hint)
        for tuple_type, factory in tuple_factories:
            if prop_type == tuple_type:
                return factory()
        raise IllegalArgumentException(
            f'Unrecognized prop type {prop_type}')
Esempio n. 14
0
 def make_tuple_prop(prop_type):
     """Call the factory registered in ``tuple_mapping`` for the key that *is* ``prop_type``.

     ``tuple_mapping`` comes from the enclosing scope; keys are matched by identity.

     Raises:
         IllegalArgumentException: If no key is identical to ``prop_type``.
     """
     for registered_type, factory in tuple_mapping.items():
         if registered_type is prop_type:
             break
     else:
         raise IllegalArgumentException(f'Unrecognized type {prop_type}')
     return factory()
Esempio n. 15
0
 def _check_instance_is_set(self):
     if self._instance is None:
         raise IllegalArgumentException(
             f'Instance not set. Pass `instance` param to __init__().')
Esempio n. 16
0
 def raise_exception():
     """Raise ``IllegalArgumentException`` describing an input shape that can't be split across experts.

     ``whole_input_shape`` and ``flock_size`` are read from the enclosing scope.
     """
     message = (
         f"Don't know how to distribute the data {whole_input_shape} among {flock_size} experts, "
         "the product of first n dimensions of the data should be equal to the flock_size."
     )
     raise IllegalArgumentException(message)
Esempio n. 17
0
 def _check_arguments(self):
     """Validate arguments that depend on the item type.

     Raises:
         IllegalArgumentException: If the item is of the SELECT type but ``self.select_values`` is
             empty.
     """
     # NOTE(review): ``self.type`` is compared against a *value* of TYPE_MAPPING. If ``self.type``
     # holds a TYPE_MAPPING key instead (e.g. obtained via a reverse lookup), this comparison may
     # never be true — confirm.
     current_type = self.type
     select_type = self.TYPE_MAPPING[ObserverPropertiesItemType.SELECT]
     is_select = current_type == select_type
     if is_select and len(self.select_values) == 0:
         raise IllegalArgumentException(
             "Select type must have select_values defined")