Пример #1
0
def _output_type_setter(value, owning_component):
    """Validate a new ``output_type`` for *owning_component* and return it unchanged.

    Setter-style hook for a Function's ``output_type``: the returned *value* is what
    gets stored.

    Arguments
    ---------
    value : FunctionOutputType or None
        the output type being assigned.
    owning_component : Function
        the component whose ``output_type`` is being set; its ``defaults.variable``
        and (if present) ``owner`` are inspected.

    Returns
    -------
    *value*, unchanged.

    Raises
    ------
    FunctionError
        if *value* is ``FunctionOutputType.RAW_NUMBER`` but the component's default
        variable has more than one element (it can't be returned as a single number).
    """
    # Can't convert from arrays of length > 1 to number.
    # The *new* value must be checked: testing the current ``output_type`` let an invalid
    # RAW_NUMBER assignment through and instead rejected any move away from RAW_NUMBER.
    if (value is FunctionOutputType.RAW_NUMBER
            and owning_component.defaults.variable is not None
            and safe_len(owning_component.defaults.variable) > 1):
        raise FunctionError(
            f"{owning_component.__class__.__name__} can't be set to return a "
            "single number since its variable has more than one number.")

    # warn if user overrides the 2D setting for mechanism functions
    # may be removed when
    # https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved
    # properly(meaning Mechanism values may be something other than 2D np array)
    try:
        if (isinstance(owning_component.owner, Mechanism)
                and value in (FunctionOutputType.RAW_NUMBER,
                              FunctionOutputType.NP_1D_ARRAY)):
            warnings.warn(
                f'Functions that are owned by a Mechanism but do not return a '
                '2D numpy array may cause unexpected behavior if llvm '
                'compilation is enabled.')
    except (AttributeError, ImportError):
        # Best-effort warning only: e.g. a component without an ``owner`` attribute
        # raises AttributeError here, which is deliberately ignored.
        pass

    return value
Пример #2
0
    def output_type(self, value):
        """Validate *value* and store it as this Function's ``output_type`` (in ``self._output_type``).

        Raises
        ------
        FunctionError
            if *value* is neither None nor a `FunctionOutputType`, or if it is
            ``FunctionOutputType.RAW_NUMBER`` while the default variable has more than
            one element.
        """
        # Bad outputType specification (report the rejected value, not the current setting)
        if value is not None and not isinstance(value, FunctionOutputType):
            raise FunctionError(f"value ({value}) of output_type attribute "
                                f"must be FunctionOutputType for {self.__class__.__name__}.")

        # Can't convert from arrays of length > 1 to number.
        # Check the *new* value: testing self.output_type (the old setting) let an invalid
        # RAW_NUMBER assignment through and blocked any move away from RAW_NUMBER.
        if (
            value is FunctionOutputType.RAW_NUMBER
            and self.defaults.variable is not None
            and safe_len(self.defaults.variable) > 1
        ):
            raise FunctionError(f"{self.__class__.__name__} can't be set to return a single number "
                                f"since its variable has more than one number.")

        # warn if user overrides the 2D setting for mechanism functions
        # may be removed when https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved properly
        # (meaning Mechanism values may be something other than 2D np array)
        try:
            # import here because if this package is not installed, we can assume the user is probably not dealing with compilation
            # so no need to warn unnecessarily
            import llvmlite  # noqa: F401 -- imported only to detect availability
            if (isinstance(self.owner, Mechanism)
                    and value in (FunctionOutputType.RAW_NUMBER, FunctionOutputType.NP_1D_ARRAY)):
                warnings.warn(f'Functions that are owned by a Mechanism but do not return a 2D numpy array '
                              f'may cause unexpected behavior if llvm compilation is enabled.')
        except (AttributeError, ImportError):
            # Best-effort warning: no llvmlite, or no ``owner`` attribute yet.
            pass

        self._output_type = value
Пример #3
0
    def _update_default_variable(self, new_default_variable, context):
        """Resize the matrix parameter and the hollow matrix for *new_default_variable*.

        A matrix given as a `MappingProjection` is replaced by that projection's MATRIX
        `ParameterPort`; a `ParameterPort` is kept unchanged; anything else is rebuilt with
        `get_matrix` from the *default* matrix specification at the new size.  The parent
        class's ``_update_default_variable`` is called last.
        """
        from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
        from psyneulink.core.components.ports.parameterport import ParameterPort

        # Mirrors the reshaping performed in _function (a hack pending a general solution):
        # singleton axes of a multi-dimensional variable are squeezed out before measuring it.
        variable_array = np.array(new_default_variable)
        flattened = np.squeeze(variable_array) if variable_array.ndim > 1 else variable_array
        n = safe_len(flattened)

        matrix_spec = self.parameters.matrix._get(context)
        if isinstance(matrix_spec, MappingProjection):
            resized = matrix_spec._parameter_ports[MATRIX]
        elif isinstance(matrix_spec, ParameterPort):
            resized = matrix_spec
        else:
            # Rebuilt from the default spec rather than the current value
            # (presumably because the current value is sized for the old variable).
            resized = get_matrix(self.defaults.matrix, n, n)
        self.parameters.matrix._set(resized, context)

        self._hollow_matrix = get_matrix(HOLLOW_MATRIX, n, n)

        super()._update_default_variable(new_default_variable, context)
Пример #4
0
    def _instantiate_attributes_before_function(self,
                                                function=None,
                                                context=None):
        """Instantiate the matrix, the hollow matrix and the metric function.

        The matrix parameter is sized to the (squeezed) length of the default variable:
        a `MappingProjection` is replaced by its MATRIX `ParameterPort`, a `ParameterPort`
        is kept as is, and any other specification is resolved with `get_matrix`.  A
        HOLLOW_MATRIX of the same size is stored in ``_hollow_matrix`` (presumably used in
        ``_function`` to eliminate the diagonal, i.e. self-connections, from the
        calculation -- confirm there).

        The metric is computed with the `Distance` Function; if ENTROPY is specified, it is
        converted to CROSS_ENTROPY for use with the Distance Function.

        Arguments
        ---------
        function : not used here (presumably kept to match the base-class signature).
        context : Context used to get and set the ``matrix`` parameter.

        Raises
        ------
        ValueError
            if ``metric`` is neither ENTROPY nor one of DISTANCE_METRICS.
        """
        from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
        from psyneulink.core.components.ports.parameterport import ParameterPort

        # this mirrors the transformation in _function
        # it is a hack, and a general solution should be found
        squeezed = np.array(self.defaults.variable)
        if squeezed.ndim > 1:
            squeezed = np.squeeze(squeezed)

        size = safe_len(squeezed)

        matrix = self.parameters.matrix._get(context)

        if isinstance(matrix, MappingProjection):
            matrix = matrix._parameter_ports[MATRIX]
        elif isinstance(matrix, ParameterPort):
            pass
        else:
            matrix = get_matrix(matrix, size, size)

        self.parameters.matrix._set(matrix, context)

        self._hollow_matrix = get_matrix(HOLLOW_MATRIX, size, size)

        # Resolve the metric passed to Distance (ENTROPY is converted to CROSS_ENTROPY).
        if self.metric == ENTROPY:
            distance_metric = CROSS_ENTROPY
        elif self.metric in DISTANCE_METRICS._set():
            distance_metric = self.metric
        else:
            # Explicit raise instead of ``assert False``: asserts are stripped under
            # ``python -O``, which would silently leave metric_fct stale or unset below.
            raise ValueError(f"Unknown metric: {self.metric!r}")

        default_variable = [self.defaults.variable, self.defaults.variable]
        self.metric_fct = Distance(default_variable=default_variable,
                                   metric=distance_metric,
                                   normalize=self.normalize)
        #FIXME: This is a hack to make sure metric-fct param is set
        self.parameters.metric_fct.set(self.metric_fct)
class ComparatorMechanism(ObjectiveMechanism):
    """
    ComparatorMechanism(                                \
        sample,                                         \
        target,                                         \
        input_ports=[SAMPLE,TARGET],                    \
        function=LinearCombination(weights=[[-1],[1]]), \
        output_ports=OUTCOME)

    Subclass of `ObjectiveMechanism` that compares the values of two `OutputPorts <OutputPort>`.
    See `ObjectiveMechanism <ObjectiveMechanism_Class_Reference>` for additional arguments and attributes.


    Arguments
    ---------

    sample : OutputPort, Mechanism, value, or string
        specifies the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target :  OutputPort, Mechanism, value, or string
        specifies the value with which the `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_ports :  List[InputPort, value, str or dict] or Dict[] : default [SAMPLE, TARGET]
        specifies the names and/or formats to use for the values of the sample and target InputPorts;
        by default they are named *SAMPLE* and *TARGET*, and their formats match the value of the OutputPorts
        specified in the **sample** and **target** arguments, respectively (see `ComparatorMechanism_Structure`
        for additional details).  If specified, it must be a list with exactly two items.

    function :  Function, function or method : default LinearCombination(weights=[[-1],[1]])
        specifies the `function <ComparatorMechanism.function>` used to compare the `sample` with the `target`.

    output_ports :  str or Iterable : default None
        specifies the OutputPorts of the Mechanism; a single string is treated as the name of one OutputPort.
        If it is not specified, the default of the `output_ports` Parameter (*OUTCOME*) is presumably used.


    Attributes
    ----------

    COMMENT:
    default_variable : Optional[List[array] or 2d np.array]
    COMMENT

    sample : OutputPort
        determines the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputPort
        determines the value with which `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_ports : ContentAddressableList[InputPort, InputPort]
        contains the two InputPorts named, by default, *SAMPLE* and *TARGET*, each of which receives a
        `MappingProjection` from the OutputPorts referenced by the `sample` and `target` attributes
        (see `ComparatorMechanism_Structure` for additional details).

    function : CombinationFunction, function or method
        used to compare the `sample` with the `target`.  It can be any PsyNeuLink `CombinationFunction`,
        or a python function that takes a 2d array with two items and returns a 1d array of the same length
        as the two input items.

    output_port : OutputPort
        contains the `primary <OutputPort_Primary>` OutputPort of the ComparatorMechanism; the default is
        its *OUTCOME* OutputPort, the value of which is equal to the `value <ComparatorMechanism.value>`
        attribute of the ComparatorMechanism.

    output_ports : ContentAddressableList[OutputPort]
        contains, by default, only the *OUTCOME* (primary) OutputPort of the ComparatorMechanism.

    output_values : 2d np.array
        contains one item that is the value of the *OUTCOME* OutputPort.

    standard_output_ports : list[str]
        list of `Standard OutputPorts <OutputPort_Standard>` that includes the following in addition to the
        `standard_output_ports <ObjectiveMechanism.standard_output_ports>` of an `ObjectiveMechanism`:

        .. _COMPARATOR_MECHANISM_SSE:

        *SSE*
            the value of the sum squared error of the Mechanism's function

        .. _COMPARATOR_MECHANISM_MSE:

        *MSE*
            the value of the mean squared error of the Mechanism's function

    """
    componentType = COMPARATOR_MECHANISM

    classPreferenceLevel = PreferenceLevel.SUBTYPE
    # These will override those specified in TYPE_DEFAULT_PREFERENCES
    classPreferences = {
        PREFERENCE_SET_NAME: 'ComparatorCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE)}

    class Parameters(ObjectiveMechanism.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <Mechanism_Base.variable>`

                    :default value: numpy.array([[0], [0]])
                    :type: numpy.ndarray
                    :read only: True

                function
                    see `function <ComparatorMechanism.function>`

                    :default value: `LinearCombination`(weights=numpy.array([[-1], [ 1]]))
                    :type: `Function`

                sample
                    see `sample <ComparatorMechanism.sample>`

                    :default value: None
                    :type:

                target
                    see `target <ComparatorMechanism.target>`

                    :default value: None
                    :type:

        """
        # By default, ComparatorMechanism compares two 1D np.array input_ports
        variable = Parameter(np.array([[0], [0]]), read_only=True, pnl_internal=True, constructor_argument='default_variable')
        function = Parameter(LinearCombination(weights=[[-1], [1]]), stateful=False, loggable=False)
        sample = None
        target = None

        output_ports = Parameter(
            [OUTCOME],
            stateful=False,
            loggable=False,
            read_only=True,
            structural=True,
        )

    # ComparatorMechanism parameter and control signal assignments):
    paramClassDefaults = Mechanism_Base.paramClassDefaults.copy()

    # Standard OutputPorts: those of ObjectiveMechanism plus sum- and mean-squared error of the value
    standard_output_ports = ObjectiveMechanism.standard_output_ports.copy()
    standard_output_ports.extend([{NAME: SSE,
                                   FUNCTION: lambda x: np.sum(x * x)},
                                  {NAME: MSE,
                                   FUNCTION: lambda x: np.sum(x * x) / safe_len(x)}])
    standard_output_port_names = ObjectiveMechanism.standard_output_port_names.copy()
    standard_output_port_names.extend([SSE, MSE])

    @tc.typecheck
    def __init__(self,
                 default_variable=None,
                 sample: tc.optional(tc.any(OutputPort, Mechanism_Base, dict, is_numeric, str))=None,
                 target: tc.optional(tc.any(OutputPort, Mechanism_Base, dict, is_numeric, str))=None,
                 function=LinearCombination(weights=[[-1], [1]]),
                 output_ports:tc.optional(tc.any(str, Iterable)) = None,
                 params=None,
                 name=None,
                 prefs:is_pref_set=None,
                 **kwargs
                 ):

        # An input_ports specification (if any) arrives via kwargs; wrap it the way
        # _merge_legacy_constructor_args expects ({INPUT_PORTS: <list>}).
        input_ports = kwargs.pop(INPUT_PORTS, {})
        if input_ports:
            input_ports = {INPUT_PORTS: input_ports}

        input_ports = self._merge_legacy_constructor_args(sample, target, default_variable, input_ports)

        # Normalize output_ports to a list.  A single string names ONE OutputPort and must be wrapped
        # (list(str) would split it into its characters); a tuple is converted to a list.
        if isinstance(output_ports, str):
            output_ports = [output_ports]
        elif isinstance(output_ports, tuple):
            output_ports = list(output_ports)

        # IMPLEMENTATION NOTE: The following prevents the default from being updated by subsequent assignment
        #                     (in this case, to [OUTCOME, {NAME= MSE}]), but fails to expose default in IDE
        # output_ports = output_ports or [OUTCOME, MSE]

        super().__init__(monitor=input_ports,
                         function=function,
                         output_ports=output_ports, # prevent default from getting overwritten by later assign
                         params=params,
                         name=name,
                         prefs=prefs,

                         **kwargs
                         )

        # Require Projection to TARGET InputPort (already required for SAMPLE as primary InputPort)
        self.input_ports[1].parameters.require_projection_in_composition._set(True, Context())

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate the sample/target specifications in *request_set*, then defer to the parent class.

        If INPUT_PORTS is specified, it must have exactly two items (one each for SAMPLE and TARGET),
        each a dict, with VARIABLEs of equal length.  Otherwise, if both SAMPLE and TARGET are specified
        and their values can be resolved, those values must be compatible (same length, numeric).

        Raises
        ------
        ComparatorMechanismError
            if any of the above conditions is violated.
        """

        if INPUT_PORTS in request_set and request_set[INPUT_PORTS] is not None:
            input_ports = request_set[INPUT_PORTS]

            # Validate that there are exactly two input_ports (for sample and target)
            num_input_ports = len(input_ports)
            if num_input_ports != 2:
                raise ComparatorMechanismError(f"{INPUT_PORTS} arg is specified for {self.__class__.__name__} "
                                               f"({len(input_ports)}), so it must have exactly 2 items, "
                                               f"one each for {SAMPLE} and {TARGET}.")

            # Validate that input_ports are specified as dicts
            if not all(isinstance(input_port,dict) for input_port in input_ports):
                raise ComparatorMechanismError("PROGRAM ERROR: all items in input_port args must be converted to dicts"
                                               " by calling Port._parse_port_spec() before calling super().__init__")

            # Validate length of variable for sample = target
            if VARIABLE in input_ports[0]:
                # input_ports arg specified in standard port specification dict format
                lengths = [len(input_port[VARIABLE]) if input_port[VARIABLE] is not None else 0
                           for input_port in input_ports]
            else:
                # input_ports arg specified in {<Port_Name>:<PORT SPECIFICATION DICT>} format
                lengths = [len(list(input_port_dict.values())[0][VARIABLE]) for input_port_dict in input_ports]

            if lengths[0] != lengths[1]:
                raise ComparatorMechanismError(f"Length of value specified for {SAMPLE} InputPort "
                                               f"of {self.__class__.__name__} ({lengths[0]}) must be "
                                               f"same as length of value specified for {TARGET} ({lengths[1]}).")

        elif SAMPLE in request_set and TARGET in request_set:

            sample = request_set[SAMPLE]
            if isinstance(sample, InputPort):
                sample_value = sample.value
            elif isinstance(sample, Mechanism):
                sample_value = sample.input_value[0]
            elif is_value_spec(sample):
                sample_value = sample
            else:
                sample_value = None

            target = request_set[TARGET]
            if isinstance(target, InputPort):
                target_value = target.value
            elif isinstance(target, Mechanism):
                target_value = target.input_value[0]
            elif is_value_spec(target):
                target_value = target
            else:
                target_value = None

            # Compare the values resolved above, not the raw specifications (which may be Ports,
            # Mechanisms or strings); skip the check when either value could not be resolved.
            if sample_value is not None and target_value is not None:
                if not iscompatible(sample_value, target_value, **{kwCompatibilityLength: True,
                                                                   kwCompatibilityNumeric: True}):
                    raise ComparatorMechanismError(f"The length of the sample ({safe_len(sample_value)}) "
                                                   f"must be the same as for the target ({safe_len(target_value)}) "
                                                   f"for {self.__class__.__name__} {self.name}.")

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

    def _merge_legacy_constructor_args(self, sample, target, default_variable=None, input_ports=None):
        """Combine the sample, target, default_variable and input_ports args into two InputPort spec dicts.

        Returns
        -------
        [sample_dict, target_dict] : list of two dicts, parsed by `_parse_port_spec`, in which any
        specifications from **input_ports** (or, if that is empty, **default_variable**) override
        those derived from **sample** and **target**.

        Raises
        ------
        ComparatorMechanismError
            if **input_ports** is given but is not a list, or if the specification used does not
            have exactly two items.
        """

        # USE sample and target TO CREATE AN InputPort specification dictionary for each;
        # DO SAME FOR InputPorts argument, USE TO OVERWRITE ANY SPECIFICATIONS IN sample AND target DICTS
        # TRY tuple format AS WAY OF PROVIDED CONSOLIDATED variable AND OutputPort specifications

        sample_dict = _parse_port_spec(owner=self,
                                        port_type=InputPort,
                                        port_spec=sample,
                                        name=SAMPLE)

        target_dict = _parse_port_spec(owner=self,
                                        port_type=InputPort,
                                        port_spec=target,
                                        name=TARGET)

        # If either the default_variable arg or the input_ports arg is provided:
        #    - validate that there are exactly two items in default_variable or input_ports list
        #    - if there is an input_ports list, parse it and use it to update sample and target dicts
        if input_ports:
            input_ports = input_ports[INPUT_PORTS]
            if not isinstance(input_ports, list):
                raise ComparatorMechanismError(f"If an '{INPUT_PORTS}' argument is included in the constructor "
                                               f"for a {ComparatorMechanism.__name__} it must be a list with "
                                               f"two {InputPort.__name__} specifications.")

        input_ports = input_ports or default_variable

        if input_ports is not None:
            if len(input_ports)!=2:
                raise ComparatorMechanismError(f"If an \'input_ports\' arg is included in the constructor for a "
                                               f"{ComparatorMechanism.__name__}, it must be a list with exactly "
                                               f"two items (not {len(input_ports)}).")

            sample_input_port_dict = _parse_port_spec(owner=self,
                                                        port_type=InputPort,
                                                        port_spec=input_ports[0],
                                                        name=SAMPLE,
                                                        value=None)

            target_input_port_dict = _parse_port_spec(owner=self,
                                                        port_type=InputPort,
                                                        port_spec=input_ports[1],
                                                        name=TARGET,
                                                        value=None)

            sample_dict = recursive_update(sample_dict, sample_input_port_dict)
            target_dict = recursive_update(target_dict, target_input_port_dict)

        return [sample_dict, target_dict]
class ComparatorMechanism(ObjectiveMechanism):
    """
    ComparatorMechanism(                                \
        sample,                                         \
        target,                                         \
        input_states=[SAMPLE,TARGET]                    \
        function=LinearCombination(weights=[[-1],[1]],  \
        output_states=OUTCOME                           \
        params=None,                                    \
        name=None,                                      \
        prefs=None)

    Subclass of `ObjectiveMechanism` that compares the values of two `OutputStates <OutputState>`.

    COMMENT:
        Description:
            ComparatorMechanism is a subtype of the ObjectiveMechanism Subtype of the ProcssingMechanism Type
            of the Mechanism Category of the Component class.
            By default, it's function uses the LinearCombination Function to compare two input variables.
            COMPARISON_OPERATION (functionParams) determines whether the comparison is subtractive or divisive
            The function returns an array with the Hadamard (element-wise) differece/quotient of target vs. sample,
                as well as the mean, sum, sum of squares, and mean sum of squares of the comparison array

        Class attributes:
            + componentType (str): ComparatorMechanism
            + classPreference (PreferenceSet): Comparator_PreferenceSet, instantiated in __init__()
            + classPreferenceLevel (PreferenceLevel): PreferenceLevel.SUBTYPE
            + class_defaults.variable (value):  Comparator_DEFAULT_STARTING_POINT // QUESTION: What to change here
            + paramClassDefaults (dict): {FUNCTION_PARAMS:{COMPARISON_OPERATION: SUBTRACTION}}

        Class methods:
            None

        MechanismRegistry:
            All instances of ComparatorMechanism are registered in MechanismRegistry, which maintains an
              entry for the subclass, a count for all instances of it, and a dictionary of those instances
    COMMENT

    Arguments
    ---------

    sample : OutputState, Mechanism, value, or string
        specifies the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target :  OutputState, Mechanism, value, or string
        specifies the value with which the `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_states :  List[InputState, value, str or dict] or Dict[] : default [SAMPLE, TARGET]
        specifies the names and/or formats to use for the values of the sample and target InputStates;
        by default they are named *SAMPLE* and *TARGET*, and their formats are match the value of the OutputStates
        specified in the **sample** and **target** arguments, respectively (see `ComparatorMechanism_Structure`
        for additional details).

    function :  Function, function or method : default Distance(metric=DIFFERENCE)
        specifies the `function <Comparator.function>` used to compare the `sample` with the `target`.

    output_states :  List[OutputState, value, str or dict] or Dict[] : default [OUTCOME]
        specifies the OutputStates for the Mechanism;

    params :  Optional[Dict[param keyword: param value]]
        a `parameter dictionary <ParameterState_Specification>` that can be used to specify the parameters for
        the Mechanism, its function, and/or a custom function and its parameters. Values specified for parameters in
        the dictionary override any assigned to those parameters in arguments of the
        constructor.

    name : str : default see `name <ComparatorMechanism.name>`
        specifies the name of the ComparatorMechanism.

    prefs : PreferenceSet or specification dict : default Mechanism.classPreferences
        specifies the `PreferenceSet` for the ComparatorMechanism; see `prefs <ComparatorMechanism.prefs>` for details.


    Attributes
    ----------

    COMMENT:
    default_variable : Optional[List[array] or 2d np.array]
    COMMENT

    sample : OutputState
        determines the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputState
        determines the value with which `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_states : ContentAddressableList[InputState, InputState]
        contains the two InputStates named, by default, *SAMPLE* and *TARGET*, each of which receives a
        `MappingProjection` from the OutputStates referenced by the `sample` and `target` attributes
        (see `ComparatorMechanism_Structure` for additional details).

    function : CombinationFunction, function or method
        used to compare the `sample` with the `target`.  It can be any PsyNeuLink `CombinationFunction`,
        or a python function that takes a 2d array with two items and returns a 1d array of the same length
        as the two input items.

    value : 1d np.array
        the result of the comparison carried out by the `function <ComparatorMechanism.function>`.

    output_state : OutputState
        contains the `primary <OutputState_Primary>` OutputState of the ComparatorMechanism; the default is
        its *OUTCOME* OutputState, the value of which is equal to the `value <ComparatorMechanism.value>`
        attribute of the ComparatorMechanism.

    output_states : ContentAddressableList[OutputState]
        contains, by default, only the *OUTCOME* (primary) OutputState of the ComparatorMechanism.

    output_values : 2d np.array
        contains one item that is the value of the *OUTCOME* OutputState.

    name : str
        the name of the ComparatorMechanism; if it is not specified in the **name** argument of the constructor, a
        default is assigned by MechanismRegistry (see `Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict
        the `PreferenceSet` for the ComparatorMechanism; if it is not specified in the **prefs** argument of the
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see :doc:`PreferenceSet
        <LINK>` for details).


    """
    componentType = COMPARATOR_MECHANISM

    classPreferenceLevel = PreferenceLevel.SUBTYPE
    # These will override those specified in TypeDefaultPreferences
    classPreferences = {
        kwPreferenceSetName: 'ComparatorCustomClassPreferences',
        kpReportOutputPref: PreferenceEntry(False, PreferenceLevel.INSTANCE)
    }

    class Parameters(ObjectiveMechanism.Parameters):
        """
            Parameters of a ComparatorMechanism, extending those of `ObjectiveMechanism`.

            Attributes
            ----------

                variable
                    see `variable <ComparatorMechanism.variable>`

                    :default value: numpy.array([[0], [0]])
                    :type: numpy.ndarray
                    :read only: True

                function
                    see `function <ComparatorMechanism.function>`

                    :default value: `LinearCombination`(offset=0.0, operation=sum, scale=1.0, weights=numpy.array([[-1], [ 1]]))
                    :type: `Function`

                sample
                    see `sample <ComparatorMechanism.sample>`

                    :default value: None
                    :type:

                target
                    see `target <ComparatorMechanism.target>`

                    :default value: None
                    :type:

        """
        # By default, ComparatorMechanism compares two 1D np.array input_states
        # (two rows: presumably one for SAMPLE and one for TARGET, in that order).
        variable = Parameter(np.array([[0], [0]]), read_only=True)
        # Weights -1 and +1 on the two rows; presumably this yields TARGET - SAMPLE — confirm
        # against LinearCombination's weighting semantics.
        function = Parameter(LinearCombination(weights=[[-1], [1]]),
                             stateful=False,
                             loggable=False)
        # Plain (non-Parameter) defaults for the sample and target specifications.
        sample = None
        target = None

    # ComparatorMechanism parameter and control signal assignments):
    paramClassDefaults = Mechanism_Base.paramClassDefaults.copy()

    standard_output_states = ObjectiveMechanism.standard_output_states.copy()

    standard_output_states.extend([{
        NAME: SSE,
        FUNCTION: lambda x: np.sum(x * x)
    }, {
        NAME:
        MSE,
        FUNCTION:
        lambda x: np.sum(x * x) / safe_len(x)
    }])

    @tc.typecheck
    def __init__(
        self,
        default_variable=None,
        sample: tc.optional(
            tc.any(OutputState, Mechanism_Base, dict, is_numeric, str)) = None,
        target: tc.optional(
            tc.any(OutputState, Mechanism_Base, dict, is_numeric, str)) = None,
        function=LinearCombination(weights=[[-1], [1]]),
        output_states: tc.optional(tc.any(str, Iterable)) = (OUTCOME, ),
        params=None,
        name=None,
        prefs: is_pref_set = None,
        **
        input_states  # IMPLEMENTATION NOTE: this is for backward compatibility
    ):
        """Build the sample/target InputState specifications and construct the ObjectiveMechanism.

        Any extra keyword arguments are collected in ``input_states`` and passed, together with
        **sample**, **target** and **default_variable**, to ``_merge_legacy_constructor_args``;
        its result is used as the **monitor** argument of the parent constructor.
        """
        input_states = self._merge_legacy_constructor_args(
            sample, target, default_variable, input_states)

        # Normalize output_states to a list.  A single string names ONE OutputState and must be
        # wrapped (list(str) would split it into its characters); a tuple (such as the default)
        # is converted to a list.
        if isinstance(output_states, str):
            output_states = [output_states]
        elif isinstance(output_states, tuple):
            output_states = list(output_states)

        # IMPLEMENTATION NOTE: The following prevents the default from being updated by subsequent assignment
        #                     (in this case, to [OUTCOME, {NAME= MSE}]), but fails to expose default in IDE
        # output_states = output_states or [OUTCOME, MSE]

        # Create a StandardOutputStates object from the list of stand_output_states specified for the class
        if not isinstance(self.standard_output_states, StandardOutputStates):
            self.standard_output_states = StandardOutputStates(
                self, self.standard_output_states, indices=PRIMARY)

        super().__init__(
            monitor=input_states,
            function=function,
            # Copy to prevent the caller's list from getting overwritten by later assignment;
            # None (allowed by the type check) is passed through instead of crashing on .copy().
            output_states=output_states.copy() if output_states is not None else None,
            params=params,
            name=name,
            prefs=prefs,
            context=ContextFlags.CONSTRUCTOR)

    def _validate_params(self, request_set, target_set=None, context=None):
        """If sample and target values are specified, validate that they are compatible

        Two forms of specification are checked:

        * If INPUT_STATES is in **request_set** (and is not None), it must have exactly two items,
          each already converted to a dict, and the lengths of the variables they specify must match.
        * Otherwise, if both SAMPLE and TARGET are in **request_set**, each is resolved to a value
          (the value of an InputState, the first item of a Mechanism's input_value, or the value
          specification itself). When both resolve, they must be compatible in length and be numeric.

        Arguments
        ---------
        request_set : dict
            parameter specifications to validate.
        target_set : dict
            passed unchanged to the parent class's _validate_params.
        context
            passed unchanged to the parent class's _validate_params.

        Raises
        ------
        ComparatorMechanismError
            if any of the checks above fails.
        """

        def _resolve_value(spec):
            # Resolve a sample/target specification to the value it provides,
            # or None if it is not one of the recognized forms.
            if isinstance(spec, InputState):
                return spec.value
            if isinstance(spec, Mechanism):
                return spec.input_value[0]
            if is_value_spec(spec):
                return spec
            return None

        if INPUT_STATES in request_set and request_set[INPUT_STATES] is not None:
            input_states = request_set[INPUT_STATES]

            # Validate that there are exactly two input_states (for sample and target)
            num_input_states = len(input_states)
            if num_input_states != 2:
                raise ComparatorMechanismError(
                    "{} arg is specified for {} ({}), so it must have exactly 2 items, "
                    "one each for {} and {}".format(INPUT_STATES,
                                                    self.__class__.__name__,
                                                    len(input_states), SAMPLE,
                                                    TARGET))

            # Validate that input_states are specified as dicts
            if not all(isinstance(input_state, dict) for input_state in input_states):
                raise ComparatorMechanismError(
                    "PROGRAM ERROR: all items in input_state args must be converted to dicts"
                    " by calling State._parse_state_spec() before calling super().__init__"
                )

            # Validate length of variable for sample = target
            if VARIABLE in input_states[0]:
                # input_states arg specified in standard state specification dict format
                lengths = [len(input_state[VARIABLE]) for input_state in input_states]
            else:
                # input_states arg specified in {<STATE_NAME>:<STATE SPECIFICATION DICT>} format
                lengths = [
                    len(list(input_state_dict.values())[0][VARIABLE])
                    for input_state_dict in input_states
                ]

            if lengths[0] != lengths[1]:
                raise ComparatorMechanismError(
                    "Length of value specified for {} InputState of {} ({}) must be "
                    "same as length of value specified for {} ({})".format(
                        SAMPLE, self.__class__.__name__, lengths[0], TARGET,
                        lengths[1]))

        elif SAMPLE in request_set and TARGET in request_set:

            sample_value = _resolve_value(request_set[SAMPLE])
            target_value = _resolve_value(request_set[TARGET])

            # Compare the resolved values, not the raw specifications: those may be
            # InputStates or Mechanisms, which have no meaningful length of their own.
            if sample_value is not None and target_value is not None:
                if not iscompatible(
                        sample_value, target_value, **{
                            kwCompatibilityLength: True,
                            kwCompatibilityNumeric: True
                        }):
                    raise ComparatorMechanismError(
                        "The length of the sample ({}) must be the same as for the target ({}) "
                        "for {} {}".format(len(sample_value), len(target_value),
                                           self.__class__.__name__, self.name))

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

    def _merge_legacy_constructor_args(self, sample, target, default_variable=None, input_states=None):
        """Build the SAMPLE and TARGET InputState specification dicts from the legacy constructor args.

        **sample** and **target** are each parsed with `_parse_state_spec` into an InputState
        specification dict, named SAMPLE and TARGET respectively.

        Override specifications are taken from the INPUT_STATES entry of **input_states** when
        **input_states** is truthy; if that yields a falsy value, **default_variable** is used instead.
        When the result is not None, it must have exactly two items. Each item is parsed into an
        InputState specification dict and merged (with `recursive_update`) over the corresponding
        sample/target dict, so its specifications overwrite those from **sample**/**target**.

        Returns
        -------
        list
            [sample specification dict, target specification dict]

        Raises
        ------
        ComparatorMechanismError
            if the INPUT_STATES entry of **input_states** is not a list, or if the override
            specification does not have exactly two items.
        """
        # Parse the legacy sample and target arguments (sample first, then target).
        sample_dict, target_dict = [
            _parse_state_spec(owner=self, state_type=InputState, state_spec=spec, name=label)
            for spec, label in ((sample, SAMPLE), (target, TARGET))
        ]

        # Unwrap the list of InputState specifications carried by the input_states argument.
        if input_states:
            specs = input_states[INPUT_STATES]
            if not isinstance(specs, list):
                raise ComparatorMechanismError(
                    "If an \'{}\' argument is included in the constructor for a {} "
                    "it must be a list with two {} specifications.".format(
                        INPUT_STATES, ComparatorMechanism.__name__,
                        InputState.__name__))
        else:
            specs = input_states

        overrides = specs or default_variable

        # Nothing to merge: return the dicts parsed from sample and target as-is.
        if overrides is None:
            return [sample_dict, target_dict]

        if len(overrides) != 2:
            raise ComparatorMechanismError(
                "If an \'input_states\' arg is "
                "included in the constructor for "
                "a {}, it must be a list with "
                "exactly two items (not {})".format(
                    ComparatorMechanism.__name__, len(overrides)))

        sample_override, target_override = [
            _parse_state_spec(owner=self,
                              state_type=InputState,
                              state_spec=overrides[position],
                              name=label,
                              value=None)
            for position, label in enumerate((SAMPLE, TARGET))
        ]

        return [recursive_update(sample_dict, sample_override),
                recursive_update(target_dict, target_override)]
Пример #7
0
    def _validate_params(self,
                         variable,
                         request_set,
                         target_set=None,
                         context=None):
        """Validate matrix param

        `matrix <Stability.matrix>` argument must be one of the following
            - 2d list, np.ndarray or np.matrix
            - ParameterPort for one of the above
            - MappingProjection with a parameterPorts[MATRIX] for one of the above

        Parse the matrix specification to ensure it resolves to a square matrix whose number of rows
        equals the length of the (squeezed) default variable. The specification is left in the form
        in which it was given, so that if it is a ParameterPort or MappingProjection its current value
        can be accessed at runtime (i.e., it can be used as a "pointer").

        A str specification is not checked here (it can be transformed to the variable's shape).

        Raises
        ------
        FunctionError
            if the matrix cannot be obtained from a MappingProjection/ParameterPort, is not 2d,
            has the wrong number of rows, or is not square.
        """

        # Validate matrix specification
        # (str can be automatically transformed to variable shape)
        # target_set defaults to None, so guard before the membership test.
        if (target_set is not None
                and MATRIX in target_set
                and not isinstance(target_set[MATRIX], str)):

            from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
            from psyneulink.core.components.ports.parameterport import ParameterPort

            spec = target_set[MATRIX]

            if isinstance(spec, MappingProjection):
                try:
                    matrix = spec._parameter_ports[MATRIX].value
                except KeyError as err:
                    # Report the projection itself: it has no .shape to describe.
                    raise FunctionError(
                        "The MappingProjection specified for the {} arg of {} ({}) must have a {} "
                        "ParameterPort that has been assigned a 2d array or matrix"
                        .format(MATRIX, self.name, spec, MATRIX)) from err
                param_type_string = "MappingProjection's ParameterPort"

            elif isinstance(spec, ParameterPort):
                try:
                    matrix = spec.value
                except KeyError as err:
                    # Report the port itself: it has no .shape to describe.
                    raise FunctionError(
                        "The value of the {} parameterPort specified for the {} arg of {} ({}) "
                        "must be a 2d array or matrix".format(
                            MATRIX, MATRIX, self.name, spec)) from err
                param_type_string = "ParameterPort"

            else:
                matrix = spec
                param_type_string = "array or matrix"

            matrix = np.array(matrix)
            if matrix.ndim != 2:
                raise FunctionError(
                    "The value of the {} specified for the {} arg of {} ({}) "
                    "must be a 2d array or matrix".format(
                        param_type_string, MATRIX, self.name, matrix))
            rows, cols = matrix.shape

            # this mirrors the transformation in _function
            # it is a hack, and a general solution should be found
            squeezed = np.array(self.defaults.variable)
            if squeezed.ndim > 1:
                squeezed = np.squeeze(squeezed)

            size = safe_len(squeezed)

            if rows != size:
                raise FunctionError(
                    "The value of the {} specified for the {} arg of {} is the wrong size; "
                    "it is {}x{}, but must be square matrix of size {}".format(
                        param_type_string, MATRIX, self.name, rows, cols,
                        size))

            if rows != cols:
                raise FunctionError(
                    "The value of the {} specified for the {} arg of {} ({}) "
                    "must be a square matrix".format(param_type_string, MATRIX,
                                                     self.name, matrix))

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)