def _output_type_setter(value, owning_component):
    """Validate a new ``output_type`` for *owning_component* and return it.

    Arguments
    ---------
    value : FunctionOutputType
        the output type being assigned.
    owning_component : Function
        the Function whose ``output_type`` Parameter is being set.

    Returns
    -------
    *value*, unchanged (it is what gets stored as the new output_type).

    Raises
    ------
    FunctionError
        if *value* is ``FunctionOutputType.RAW_NUMBER`` but the component's default
        variable has more than one element.
    """
    # Can't convert from arrays of length > 1 to number.
    # Check the value being assigned: this setter returns *value* to be stored, so
    # owning_component.output_type presumably still holds the previous value here.
    if (owning_component.defaults.variable is not None
            and safe_len(owning_component.defaults.variable) > 1
            and value is FunctionOutputType.RAW_NUMBER):
        raise FunctionError(
            f"{owning_component.__class__.__name__} can't be set to return a "
            "single number since its variable has more than one number.")

    # warn if user overrides the 2D setting for mechanism functions
    # may be removed when
    # https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved
    # properly (meaning Mechanism values may be something other than 2D np array)
    try:
        if (isinstance(owning_component.owner, Mechanism)
                and value in (FunctionOutputType.RAW_NUMBER, FunctionOutputType.NP_1D_ARRAY)):
            warnings.warn(
                f'Functions that are owned by a Mechanism but do not return a '
                '2D numpy array may cause unexpected behavior if llvm '
                'compilation is enabled.')
    except (AttributeError, ImportError):
        # Best effort only: e.g. a component without an ``owner`` attribute is not warned about.
        pass

    return value
def output_type(self, value):
    """Validate and store a new ``output_type``.

    Arguments
    ---------
    value : FunctionOutputType or None
        the output type being assigned.

    Raises
    ------
    FunctionError
        if *value* is neither None nor a `FunctionOutputType`, or if it is
        ``FunctionOutputType.RAW_NUMBER`` while the default variable has more than one element.
    """
    # Bad outputType specification
    if value is not None and not isinstance(value, FunctionOutputType):
        # Report the rejected value, not the (still unchanged) current output_type.
        raise FunctionError(f"value ({value}) of output_type attribute "
                            f"must be FunctionOutputType for {self.__class__.__name__}.")

    # Can't convert from arrays of length > 1 to number.
    # Check the value being assigned: self.output_type is only updated at the end of this setter.
    if (
        self.defaults.variable is not None
        and safe_len(self.defaults.variable) > 1
        and value is FunctionOutputType.RAW_NUMBER
    ):
        raise FunctionError(f"{self.__class__.__name__} can't be set to return a single number "
                            f"since its variable has more than one number.")

    # warn if user overrides the 2D setting for mechanism functions
    # may be removed when https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved properly
    # (meaning Mechanism values may be something other than 2D np array)
    try:
        # import here because if this package is not installed, we can assume the user is probably
        # not dealing with compilation, so there is no need to warn unnecessarily
        import llvmlite  # noqa: F401
        if (isinstance(self.owner, Mechanism)
                and value in (FunctionOutputType.RAW_NUMBER, FunctionOutputType.NP_1D_ARRAY)):
            warnings.warn(f'Functions that are owned by a Mechanism but do not return a 2D numpy array '
                          f'may cause unexpected behavior if llvm compilation is enabled.')
    except (AttributeError, ImportError):
        # Best effort: missing llvmlite or a missing ``owner`` attribute just skips the warning.
        pass

    self._output_type = value
def _update_default_variable(self, new_default_variable, context):
    """Re-size the matrix attributes for a new default variable, then defer to the parent class.

    The size used is the length of *new_default_variable* after squeezing any
    multi-dimensional form (this mirrors the transformation in ``_function``; it is
    a hack, and a general solution should be found).

    The ``matrix`` Parameter is resolved for *context* as follows:

    - a MappingProjection is replaced by its MATRIX ParameterPort;
    - a ParameterPort is kept as-is;
    - anything else is rebuilt from ``self.defaults.matrix`` with ``get_matrix`` at the new size.

    ``self._hollow_matrix`` is then rebuilt as a HOLLOW_MATRIX of the new size.
    """
    from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
    from psyneulink.core.components.ports.parameterport import ParameterPort

    flattened = np.array(new_default_variable)
    if flattened.ndim > 1:
        flattened = np.squeeze(flattened)
    n = safe_len(flattened)

    current = self.parameters.matrix._get(context)
    if isinstance(current, MappingProjection):
        resolved = current._parameter_ports[MATRIX]
    elif isinstance(current, ParameterPort):
        resolved = current
    else:
        # Rebuild from the default spec rather than the current value, so the result tracks the new size.
        resolved = get_matrix(self.defaults.matrix, n, n)
    self.parameters.matrix._set(resolved, context)

    self._hollow_matrix = get_matrix(HOLLOW_MATRIX, n, n)
    super()._update_default_variable(new_default_variable, context)
def _instantiate_attributes_before_function(self, function=None, context=None):
    """Instantiate the matrix, the hollow-matrix mask, and the metric function.

    The ``matrix`` Parameter is resolved for *context* as follows:

    - a MappingProjection is replaced by its MATRIX ParameterPort;
    - a ParameterPort is kept as-is (so its current value can be read at run time);
    - anything else is passed through ``get_matrix`` with the size of the default
      variable, after squeezing any multi-dimensional form (this mirrors the
      transformation in ``_function``; it is a hack, and a general solution should be found).

    A HOLLOW_MATRIX of the same size is stored in ``self._hollow_matrix``; it is used
    with the specified matrix to eliminate the diagonal (self-connections) from the calculation.

    ``metric_fct`` is a `Distance` Function built for two copies of the default
    variable. If ENTROPY is specified as the metric, it is converted to CROSS_ENTROPY
    for use with the Distance Function; otherwise the metric must be one of DISTANCE_METRICS.

    Arguments
    ---------
    function :
        not used by this method.
    context :
        the context in which the ``matrix`` Parameter is read and written.

    Raises
    ------
    FunctionError
        if ``metric`` is neither ENTROPY nor a member of DISTANCE_METRICS.
    """
    from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
    from psyneulink.core.components.ports.parameterport import ParameterPort

    # this mirrors the transformation in _function
    # it is a hack, and a general solution should be found
    squeezed = np.array(self.defaults.variable)
    if squeezed.ndim > 1:
        squeezed = np.squeeze(squeezed)

    size = safe_len(squeezed)

    matrix = self.parameters.matrix._get(context)

    if isinstance(matrix, MappingProjection):
        matrix = matrix._parameter_ports[MATRIX]
    elif isinstance(matrix, ParameterPort):
        pass
    else:
        matrix = get_matrix(matrix, size, size)

    self.parameters.matrix._set(matrix, context)

    self._hollow_matrix = get_matrix(HOLLOW_MATRIX, size, size)

    default_variable = [self.defaults.variable, self.defaults.variable]

    if self.metric == ENTROPY:
        self.metric_fct = Distance(default_variable=default_variable, metric=CROSS_ENTROPY,
                                   normalize=self.normalize)
    elif self.metric in DISTANCE_METRICS._set():
        self.metric_fct = Distance(default_variable=default_variable, metric=self.metric,
                                   normalize=self.normalize)
    else:
        # Raise rather than assert: asserts are stripped under ``python -O``, which would
        # leave metric_fct unset and fail later with a confusing AttributeError.
        raise FunctionError(f"Unknown metric ({self.metric}) specified for {self.name}.")

    #FIXME: This is a hack to make sure metric-fct param is set
    self.parameters.metric_fct.set(self.metric_fct)
class ComparatorMechanism(ObjectiveMechanism):
    """
    ComparatorMechanism(                                \
        sample,                                         \
        target,                                         \
        input_ports=[SAMPLE,TARGET],                    \
        function=LinearCombination(weights=[[-1],[1]]), \
        output_ports=OUTCOME)

    Subclass of `ObjectiveMechanism` that compares the values of two `OutputPorts <OutputPort>`.
    See `ObjectiveMechanism <ObjectiveMechanism_Class_Reference>` for additional arguments and attributes.

    Arguments
    ---------

    sample : OutputPort, Mechanism, value, or string
        specifies the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputPort, Mechanism, value, or string
        specifies the value with which the `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_ports : List[InputPort, value, str or dict] or Dict[] : default [SAMPLE, TARGET]
        specifies the names and/or formats to use for the values of the sample and target InputPorts;
        by default they are named *SAMPLE* and *TARGET*, and their formats match the values of the OutputPorts
        specified in the **sample** and **target** arguments, respectively (see `ComparatorMechanism_Structure`
        for additional details).

    function : Function, function or method : default LinearCombination(weights=[[-1],[1]])
        specifies the `function <ComparatorMechanism.function>` used to compare the `sample` with the `target`.

    Attributes
    ----------

    COMMENT:
    default_variable : Optional[List[array] or 2d np.array]
    COMMENT

    sample : OutputPort
        determines the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputPort
        determines the value with which `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_ports : ContentAddressableList[InputPort, InputPort]
        contains the two InputPorts named, by default, *SAMPLE* and *TARGET*, each of which receives a
        `MappingProjection` from the OutputPorts referenced by the `sample` and `target` attributes
        (see `ComparatorMechanism_Structure` for additional details).

    function : CombinationFunction, function or method
        used to compare the `sample` with the `target`.  It can be any PsyNeuLink `CombinationFunction`,
        or a python function that takes a 2d array with two items and returns a 1d array of the same length
        as the two input items.

    output_port : OutputPort
        contains the `primary <OutputPort_Primary>` OutputPort of the ComparatorMechanism; the default is
        its *OUTCOME* OutputPort, the value of which is equal to the `value <ComparatorMechanism.value>`
        attribute of the ComparatorMechanism.

    output_ports : ContentAddressableList[OutputPort]
        contains, by default, only the *OUTCOME* (primary) OutputPort of the ComparatorMechanism.

    output_values : 2d np.array
        contains one item that is the value of the *OUTCOME* OutputPort.

    standard_output_ports : list[str]
        list of `Standard OutputPorts <OutputPort_Standard>` that includes the following in addition to
        the `standard_output_ports <ObjectiveMechanism.standard_output_ports>` of an `ObjectiveMechanism`:

        .. _COMPARATOR_MECHANISM_SSE:

        *SSE*
            the value of the sum squared error of the Mechanism's function

        .. _COMPARATOR_MECHANISM_MSE:

        *MSE*
            the value of the mean squared error of the Mechanism's function

    """
    componentType = COMPARATOR_MECHANISM

    classPreferenceLevel = PreferenceLevel.SUBTYPE
    # These will override those specified in TYPE_DEFAULT_PREFERENCES
    classPreferences = {
        PREFERENCE_SET_NAME: 'ComparatorCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE)}

    class Parameters(ObjectiveMechanism.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <Mechanism_Base.variable>`

                    :default value: numpy.array([[0], [0]])
                    :type: numpy.ndarray
                    :read only: True

                function
                    see `function <ComparatorMechanism.function>`

                    :default value: `LinearCombination`(weights=numpy.array([[-1], [ 1]]))
                    :type: `Function`

                sample
                    see `sample <ComparatorMechanism.sample>`

                    :default value: None
                    :type:

                target
                    see `target <ComparatorMechanism.target>`

                    :default value: None
                    :type:

        """
        # By default, ComparatorMechanism compares two 1D np.array input_ports
        variable = Parameter(np.array([[0], [0]]), read_only=True, pnl_internal=True,
                             constructor_argument='default_variable')
        function = Parameter(LinearCombination(weights=[[-1], [1]]), stateful=False, loggable=False)
        sample = None
        target = None

        output_ports = Parameter(
            [OUTCOME],
            stateful=False,
            loggable=False,
            read_only=True,
            structural=True,
        )

    # ComparatorMechanism parameter and control signal assignments):
    paramClassDefaults = Mechanism_Base.paramClassDefaults.copy()

    standard_output_ports = ObjectiveMechanism.standard_output_ports.copy()
    standard_output_ports.extend([{NAME: SSE,
                                   FUNCTION: lambda x: np.sum(x * x)},
                                  {NAME: MSE,
                                   FUNCTION: lambda x: np.sum(x * x) / safe_len(x)}])
    standard_output_port_names = ObjectiveMechanism.standard_output_port_names.copy()
    standard_output_port_names.extend([SSE, MSE])

    @tc.typecheck
    def __init__(self,
                 default_variable=None,
                 sample: tc.optional(tc.any(OutputPort, Mechanism_Base, dict, is_numeric, str))=None,
                 target: tc.optional(tc.any(OutputPort, Mechanism_Base, dict, is_numeric, str))=None,
                 function=LinearCombination(weights=[[-1], [1]]),
                 output_ports:tc.optional(tc.any(str, Iterable)) = None,
                 params=None,
                 name=None,
                 prefs:is_pref_set=None,
                 **kwargs
                 ):

        input_ports = kwargs.pop(INPUT_PORTS, {})
        if input_ports:
            input_ports = {INPUT_PORTS: input_ports}

        input_ports = self._merge_legacy_constructor_args(sample, target, default_variable, input_ports)

        # Normalize output_ports to a fresh list (to avoid the "gotcha" associated with mutable
        # default arguments; see: bit.ly/2uID3s3 and http://docs.python-guide.org/en/latest/writing/gotchas/).
        # A bare str names a single OutputPort, so wrap it rather than splitting it into characters.
        if isinstance(output_ports, str):
            output_ports = [output_ports]
        elif isinstance(output_ports, tuple):
            output_ports = list(output_ports)

        # IMPLEMENTATION NOTE: The following prevents the default from being updated by subsequent assignment
        #                     (in this case, to [OUTCOME, {NAME= MSE}]), but fails to expose default in IDE
        # output_ports = output_ports or [OUTCOME, MSE]

        super().__init__(monitor=input_ports,
                         function=function,
                         output_ports=output_ports, # prevent default from getting overwritten by later assign
                         params=params,
                         name=name,
                         prefs=prefs,
                         **kwargs
                         )

        # Require Projection to TARGET InputPort (already required for SAMPLE as primary InputPort)
        self.input_ports[1].parameters.require_projection_in_composition._set(True, Context())

    def _validate_params(self, request_set, target_set=None, context=None):
        """If sample and target values are specified, validate that they are compatible.

        If INPUT_PORTS is in *request_set*, it must contain exactly two dict specifications
        whose variables have the same length.  Otherwise, if both SAMPLE and TARGET are in
        *request_set*, their resolved values (not the raw specifications) must be compatible
        in length and numeric type.

        Raises
        ------
        ComparatorMechanismError
            if any of the conditions above is violated.
        """

        if INPUT_PORTS in request_set and request_set[INPUT_PORTS] is not None:
            input_ports = request_set[INPUT_PORTS]

            # Validate that there are exactly two input_ports (for sample and target)
            num_input_ports = len(input_ports)
            if num_input_ports != 2:
                raise ComparatorMechanismError(f"{INPUT_PORTS} arg is specified for {self.__class__.__name__} "
                                               f"({len(input_ports)}), so it must have exactly 2 items, "
                                               f"one each for {SAMPLE} and {TARGET}.")

            # Validate that input_ports are specified as dicts
            if not all(isinstance(input_port,dict) for input_port in input_ports):
                raise ComparatorMechanismError("PROGRAM ERROR: all items in input_port args must be converted to dicts"
                                               " by calling Port._parse_port_spec() before calling super().__init__")

            # Validate length of variable for sample = target
            if VARIABLE in input_ports[0]:
                # input_ports arg specified in standard port specification dict format
                lengths = [len(input_port[VARIABLE]) if input_port[VARIABLE] is not None else 0
                           for input_port in input_ports]
            else:
                # input_ports arg specified in {<Port_Name>:<PORT SPECIFICATION DICT>} format
                lengths = [len(list(input_port_dict.values())[0][VARIABLE]) for input_port_dict in input_ports]

            if lengths[0] != lengths[1]:
                raise ComparatorMechanismError(f"Length of value specified for {SAMPLE} InputPort "
                                               f"of {self.__class__.__name__} ({lengths[0]}) must be "
                                               f"same as length of value specified for {TARGET} ({lengths[1]}).")

        elif SAMPLE in request_set and TARGET in request_set:

            # Resolve each specification to a value; specs that cannot be resolved (e.g. names) give None.
            sample = request_set[SAMPLE]
            if isinstance(sample, (InputPort, OutputPort)):
                sample_value = sample.value
            elif isinstance(sample, Mechanism):
                sample_value = sample.input_value[0]
            elif is_value_spec(sample):
                sample_value = sample
            else:
                sample_value = None

            target = request_set[TARGET]
            if isinstance(target, (InputPort, OutputPort)):
                target_value = target.value
            elif isinstance(target, Mechanism):
                target_value = target.input_value[0]
            elif is_value_spec(target):
                target_value = target
            else:
                target_value = None

            # Compare the resolved values, not the raw specs: a Port, Mechanism or name has no meaningful length.
            if sample_value is not None and target_value is not None:
                if not iscompatible(sample_value, target_value, **{kwCompatibilityLength: True,
                                                                   kwCompatibilityNumeric: True}):
                    raise ComparatorMechanismError(f"The length of the sample ({safe_len(sample_value)}) "
                                                   f"must be the same as for the target ({safe_len(target_value)}) "
                                                   f"for {self.__class__.__name__} {self.name}.")

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

    def _merge_legacy_constructor_args(self, sample, target, default_variable=None, input_ports=None):
        """Build the SAMPLE and TARGET InputPort specification dicts.

        Each dict is parsed from *sample* / *target*, then updated with the corresponding item of
        the INPUT_PORTS list in *input_ports* (or, if that is absent, of *default_variable*).

        Returns
        -------
        [sample_dict, target_dict]

        Raises
        ------
        ComparatorMechanismError
            if the INPUT_PORTS entry is not a list, or the list used does not have exactly two items.
        """

        # USE sample and target TO CREATE AN InputPort specification dictionary for each;
        # DO SAME FOR InputPorts argument, USE TO OVERWRITE ANY SPECIFICATIONS IN sample AND target DICTS
        # TRY tuple format AS WAY OF PROVIDED CONSOLIDATED variable AND OutputPort specifications

        sample_dict = _parse_port_spec(owner=self,
                                       port_type=InputPort,
                                       port_spec=sample,
                                       name=SAMPLE)

        target_dict = _parse_port_spec(owner=self,
                                       port_type=InputPort,
                                       port_spec=target,
                                       name=TARGET)

        # If either the default_variable arg or the input_ports arg is provided:
        #    - validate that there are exactly two items in default_variable or input_ports list
        #    - if there is an input_ports list, parse it and use it to update sample and target dicts
        if input_ports:
            input_ports = input_ports[INPUT_PORTS]
            if not isinstance(input_ports, list):
                raise ComparatorMechanismError(f"If an '{INPUT_PORTS}' argument is included in the constructor "
                                               f"for a {ComparatorMechanism.__name__} it must be a list with "
                                               f"two {InputPort.__name__} specifications.")

        input_ports = input_ports or default_variable

        if input_ports is not None:
            if len(input_ports)!=2:
                raise ComparatorMechanismError(f"If an \'input_ports\' arg is included in the constructor for a "
                                               f"{ComparatorMechanism.__name__}, it must be a list with exactly "
                                               f"two items (not {len(input_ports)}).")

            sample_input_port_dict = _parse_port_spec(owner=self,
                                                      port_type=InputPort,
                                                      port_spec=input_ports[0],
                                                      name=SAMPLE,
                                                      value=None)

            target_input_port_dict = _parse_port_spec(owner=self,
                                                      port_type=InputPort,
                                                      port_spec=input_ports[1],
                                                      name=TARGET,
                                                      value=None)

            sample_dict = recursive_update(sample_dict, sample_input_port_dict)
            target_dict = recursive_update(target_dict, target_input_port_dict)

        return [sample_dict, target_dict]
class ComparatorMechanism(ObjectiveMechanism):
    """
    ComparatorMechanism(                                \
        sample,                                         \
        target,                                         \
        input_states=[SAMPLE,TARGET],                   \
        function=LinearCombination(weights=[[-1],[1]]), \
        output_states=OUTCOME,                          \
        params=None,                                    \
        name=None,                                      \
        prefs=None)

    Subclass of `ObjectiveMechanism` that compares the values of two `OutputStates <OutputState>`.

    COMMENT:
        Description:
            ComparatorMechanism is a subtype of the ObjectiveMechanism Subtype of the ProcessingMechanism Type
            of the Mechanism Category of the Component class.
            By default, its function uses the LinearCombination Function to compare two input variables.
            COMPARISON_OPERATION (functionParams) determines whether the comparison is subtractive or divisive
            The function returns an array with the Hadamard (element-wise) difference/quotient of target vs. sample,
                as well as the mean, sum, sum of squares, and mean sum of squares of the comparison array

        Class attributes:
            + componentType (str): ComparatorMechanism
            + classPreference (PreferenceSet): Comparator_PreferenceSet, instantiated in __init__()
            + classPreferenceLevel (PreferenceLevel): PreferenceLevel.SUBTYPE
            + class_defaults.variable (value):  Comparator_DEFAULT_STARTING_POINT // QUESTION: What to change here
            + paramClassDefaults (dict): {FUNCTION_PARAMS:{COMPARISON_OPERATION: SUBTRACTION}}

        Class methods:
            None

        MechanismRegistry:
            All instances of ComparatorMechanism are registered in MechanismRegistry, which maintains an
              entry for the subclass, a count for all instances of it, and a dictionary of those instances
    COMMENT

    Arguments
    ---------

    sample : OutputState, Mechanism, value, or string
        specifies the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputState, Mechanism, value, or string
        specifies the value with which the `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_states : List[InputState, value, str or dict] or Dict[] : default [SAMPLE, TARGET]
        specifies the names and/or formats to use for the values of the sample and target InputStates;
        by default they are named *SAMPLE* and *TARGET*, and their formats match the values of the OutputStates
        specified in the **sample** and **target** arguments, respectively (see `ComparatorMechanism_Structure`
        for additional details).

    function : Function, function or method : default LinearCombination(weights=[[-1],[1]])
        specifies the `function <ComparatorMechanism.function>` used to compare the `sample` with the `target`.

    output_states : List[OutputState, value, str or dict] or Dict[] : default [OUTCOME]
        specifies the OutputStates for the Mechanism.

    params :  Optional[Dict[param keyword: param value]]
        a `parameter dictionary <ParameterState_Specification>` that can be used to specify the parameters for
        the Mechanism, its function, and/or a custom function and its parameters. Values specified for parameters
        in the dictionary override any assigned to those parameters in arguments of the constructor.

    name : str : default see `name <ComparatorMechanism.name>`
        specifies the name of the ComparatorMechanism.

    prefs : PreferenceSet or specification dict : default Mechanism.classPreferences
        specifies the `PreferenceSet` for the ComparatorMechanism; see `prefs <ComparatorMechanism.prefs>` for details.

    Attributes
    ----------

    COMMENT:
    default_variable : Optional[List[array] or 2d np.array]
    COMMENT

    sample : OutputState
        determines the value to compare with the `target` by the `function <ComparatorMechanism.function>`.

    target : OutputState
        determines the value with which `sample` is compared by the `function <ComparatorMechanism.function>`.

    input_states : ContentAddressableList[InputState, InputState]
        contains the two InputStates named, by default, *SAMPLE* and *TARGET*, each of which receives a
        `MappingProjection` from the OutputStates referenced by the `sample` and `target` attributes
        (see `ComparatorMechanism_Structure` for additional details).

    function : CombinationFunction, function or method
        used to compare the `sample` with the `target`.  It can be any PsyNeuLink `CombinationFunction`,
        or a python function that takes a 2d array with two items and returns a 1d array of the same length
        as the two input items.

    value : 1d np.array
        the result of the comparison carried out by the `function <ComparatorMechanism.function>`.

    output_state : OutputState
        contains the `primary <OutputState_Primary>` OutputState of the ComparatorMechanism; the default is
        its *OUTCOME* OutputState, the value of which is equal to the `value <ComparatorMechanism.value>`
        attribute of the ComparatorMechanism.

    output_states : ContentAddressableList[OutputState]
        contains, by default, only the *OUTCOME* (primary) OutputState of the ComparatorMechanism.

    output_values : 2d np.array
        contains one item that is the value of the *OUTCOME* OutputState.

    name : str
        the name of the ComparatorMechanism; if it is not specified in the **name** argument of the constructor, a
        default is assigned by MechanismRegistry (see `Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict
        the `PreferenceSet` for the ComparatorMechanism; if it is not specified in the **prefs** argument of the
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see :doc:`PreferenceSet
        <LINK>` for details).

    """
    componentType = COMPARATOR_MECHANISM

    classPreferenceLevel = PreferenceLevel.SUBTYPE
    # These will override those specified in TypeDefaultPreferences
    classPreferences = {
        kwPreferenceSetName: 'ComparatorCustomClassPreferences',
        kpReportOutputPref: PreferenceEntry(False, PreferenceLevel.INSTANCE)
    }

    class Parameters(ObjectiveMechanism.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <ComparatorMechanism.variable>`

                    :default value: numpy.array([[0], [0]])
                    :type: numpy.ndarray
                    :read only: True

                function
                    see `function <ComparatorMechanism.function>`

                    :default value: `LinearCombination`(offset=0.0, operation=sum, scale=1.0, weights=numpy.array([[-1], [ 1]]))
                    :type: `Function`

                sample
                    see `sample <ComparatorMechanism.sample>`

                    :default value: None
                    :type:

                target
                    see `target <ComparatorMechanism.target>`

                    :default value: None
                    :type:

        """
        # By default, ComparatorMechanism compares two 1D np.array input_states
        variable = Parameter(np.array([[0], [0]]), read_only=True)
        function = Parameter(LinearCombination(weights=[[-1], [1]]), stateful=False, loggable=False)
        sample = None
        target = None

    # ComparatorMechanism parameter and control signal assignments):
    paramClassDefaults = Mechanism_Base.paramClassDefaults.copy()

    standard_output_states = ObjectiveMechanism.standard_output_states.copy()
    standard_output_states.extend([{
        NAME: SSE,
        FUNCTION: lambda x: np.sum(x * x)
    }, {
        NAME: MSE,
        FUNCTION: lambda x: np.sum(x * x) / safe_len(x)
    }])

    @tc.typecheck
    def __init__(
            self,
            default_variable=None,
            sample: tc.optional(
                tc.any(OutputState, Mechanism_Base, dict, is_numeric, str)) = None,
            target: tc.optional(
                tc.any(OutputState, Mechanism_Base, dict, is_numeric, str)) = None,
            function=LinearCombination(weights=[[-1], [1]]),
            output_states: tc.optional(tc.any(str, Iterable)) = (OUTCOME, ),
            params=None,
            name=None,
            prefs: is_pref_set = None,
            **input_states  # IMPLEMENTATION NOTE: this is for backward compatibility
    ):

        input_states = self._merge_legacy_constructor_args(
            sample, target, default_variable, input_states)

        # Default output_states is specified in constructor as a tuple rather than a list
        # to avoid "gotcha" associated with mutable default arguments
        # (see: bit.ly/2uID3s3 and http://docs.python-guide.org/en/latest/writing/gotchas/)
        # A bare str names a single OutputState, so wrap it rather than splitting it into characters.
        if isinstance(output_states, str):
            output_states = [output_states]
        elif isinstance(output_states, tuple):
            output_states = list(output_states)

        # IMPLEMENTATION NOTE: The following prevents the default from being updated by subsequent assignment
        #                     (in this case, to [OUTCOME, {NAME= MSE}]), but fails to expose default in IDE
        # output_states = output_states or [OUTCOME, MSE]

        # Create a StandardOutputStates object from the list of stand_output_states specified for the class
        if not isinstance(self.standard_output_states, StandardOutputStates):
            self.standard_output_states = StandardOutputStates(
                self, self.standard_output_states, indices=PRIMARY)

        super().__init__(
            monitor=input_states,
            function=function,
            # copy to prevent the default from getting overwritten by later assignment;
            # output_states may be None (the typecheck allows it), which has no .copy()
            output_states=output_states.copy() if output_states is not None else None,
            params=params,
            name=name,
            prefs=prefs,
            context=ContextFlags.CONSTRUCTOR)

    def _validate_params(self, request_set, target_set=None, context=None):
        """If sample and target values are specified, validate that they are compatible.

        If INPUT_STATES is in *request_set*, it must contain exactly two dict specifications
        whose variables have the same length.  Otherwise, if both SAMPLE and TARGET are in
        *request_set*, their resolved values (not the raw specifications) must be compatible
        in length and numeric type.

        Raises
        ------
        ComparatorMechanismError
            if any of the conditions above is violated.
        """

        if INPUT_STATES in request_set and request_set[
                INPUT_STATES] is not None:
            input_states = request_set[INPUT_STATES]

            # Validate that there are exactly two input_states (for sample and target)
            num_input_states = len(input_states)
            if num_input_states != 2:
                raise ComparatorMechanismError(
                    "{} arg is specified for {} ({}), so it must have exactly 2 items, "
                    "one each for {} and {}".format(INPUT_STATES,
                                                    self.__class__.__name__,
                                                    len(input_states), SAMPLE,
                                                    TARGET))

            # Validate that input_states are specified as dicts
            if not all(
                    isinstance(input_state, dict)
                    for input_state in input_states):
                raise ComparatorMechanismError(
                    "PROGRAM ERROR: all items in input_state args must be converted to dicts"
                    " by calling State._parse_state_spec() before calling super().__init__"
                )

            # Validate length of variable for sample = target
            if VARIABLE in input_states[0]:
                # input_states arg specified in standard state specification dict format
                # (a None variable counts as length 0 rather than raising a TypeError)
                lengths = [
                    len(input_state[VARIABLE]) if input_state[VARIABLE] is not None else 0
                    for input_state in input_states
                ]
            else:
                # input_states arg specified in {<STATE_NAME>:<STATE SPECIFICATION DICT>} format
                lengths = [
                    len(list(input_state_dict.values())[0][VARIABLE])
                    for input_state_dict in input_states
                ]

            if lengths[0] != lengths[1]:
                raise ComparatorMechanismError(
                    "Length of value specified for {} InputState of {} ({}) must be "
                    "same as length of value specified for {} ({})".format(
                        SAMPLE, self.__class__.__name__, lengths[0], TARGET,
                        lengths[1]))

        elif SAMPLE in request_set and TARGET in request_set:

            # Resolve each specification to a value; specs that cannot be resolved (e.g. names) give None.
            sample = request_set[SAMPLE]
            if isinstance(sample, (InputState, OutputState)):
                sample_value = sample.value
            elif isinstance(sample, Mechanism):
                sample_value = sample.input_value[0]
            elif is_value_spec(sample):
                sample_value = sample
            else:
                sample_value = None

            target = request_set[TARGET]
            if isinstance(target, (InputState, OutputState)):
                target_value = target.value
            elif isinstance(target, Mechanism):
                target_value = target.input_value[0]
            elif is_value_spec(target):
                target_value = target
            else:
                target_value = None

            # Compare the resolved values, not the raw specs: a State, Mechanism or name has no meaningful length.
            if sample_value is not None and target_value is not None:
                if not iscompatible(
                        sample_value, target_value, **{
                            kwCompatibilityLength: True,
                            kwCompatibilityNumeric: True
                        }):
                    raise ComparatorMechanismError(
                        "The length of the sample ({}) must be the same as for the target ({}) "
                        "for {} {}".format(safe_len(sample_value), safe_len(target_value),
                                           self.__class__.__name__,
                                           self.name))

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

    def _merge_legacy_constructor_args(self,
                                       sample,
                                       target,
                                       default_variable=None,
                                       input_states=None):
        """Build the SAMPLE and TARGET InputState specification dicts.

        Each dict is parsed from *sample* / *target*, then updated with the corresponding item of
        the INPUT_STATES list in *input_states* (or, if that is absent, of *default_variable*).

        Returns
        -------
        [sample_dict, target_dict]

        Raises
        ------
        ComparatorMechanismError
            if the INPUT_STATES entry is not a list, or the list used does not have exactly two items.
        """

        # USE sample and target TO CREATE AN InputState specification dictionary for each;
        # DO SAME FOR InputStates argument, USE TO OVERWRITE ANY SPECIFICATIONS IN sample AND target DICTS
        # TRY tuple format AS WAY OF PROVIDED CONSOLIDATED variable AND OutputState specifications

        sample_dict = _parse_state_spec(owner=self,
                                        state_type=InputState,
                                        state_spec=sample,
                                        name=SAMPLE)

        target_dict = _parse_state_spec(owner=self,
                                        state_type=InputState,
                                        state_spec=target,
                                        name=TARGET)

        # If either the default_variable arg or the input_states arg is provided:
        #    - validate that there are exactly two items in default_variable or input_states list
        #    - if there is an input_states list, parse it and use it to update sample and target dicts
        if input_states:
            input_states = input_states[INPUT_STATES]
            if not isinstance(input_states, list):
                raise ComparatorMechanismError(
                    "If an \'{}\' argument is included in the constructor for a {} "
                    "it must be a list with two {} specifications.".format(
                        INPUT_STATES, ComparatorMechanism.__name__,
                        InputState.__name__))

        input_states = input_states or default_variable

        if input_states is not None:
            if len(input_states) != 2:
                raise ComparatorMechanismError(
                    "If an \'input_states\' arg is "
                    "included in the constructor for "
                    "a {}, it must be a list with "
                    "exactly two items (not {})".format(
                        ComparatorMechanism.__name__, len(input_states)))

            sample_input_state_dict = _parse_state_spec(
                owner=self,
                state_type=InputState,
                state_spec=input_states[0],
                name=SAMPLE,
                value=None)

            target_input_state_dict = _parse_state_spec(
                owner=self,
                state_type=InputState,
                state_spec=input_states[1],
                name=TARGET,
                value=None)

            sample_dict = recursive_update(sample_dict,
                                           sample_input_state_dict)
            target_dict = recursive_update(target_dict,
                                           target_input_state_dict)

        return [sample_dict, target_dict]
def _validate_params(self, variable, request_set, target_set=None, context=None):
    """Validate the matrix param.

    `matrix <Stability.matrix>` argument must be one of the following
        - 2d list, np.ndarray or np.matrix
        - ParameterPort for one of the above
        - MappingProjection with a parameterPorts[MATRIX] for one of the above

    Parse matrix specification to insure it resolves to a square matrix whose size matches
    the (squeezed) default variable, but leave it in the form in which it was specified so
    that, if it is a ParameterPort or MappingProjection, its current value can be accessed
    at runtime (i.e., it can be used as a "pointer").  str specifications are not checked,
    since they can be automatically transformed to the shape of the variable.

    Raises
    ------
    FunctionError
        if the matrix does not resolve to a 2d array, does not match the size of the
        default variable, or is not square.
    """

    # Validate matrix specification
    # (str can be automatically transformed to variable shape)
    if target_set is not None and MATRIX in target_set and not isinstance(target_set[MATRIX], str):

        from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
        from psyneulink.core.components.ports.parameterport import ParameterPort

        matrix_spec = target_set[MATRIX]

        if isinstance(matrix_spec, MappingProjection):
            try:
                matrix = matrix_spec._parameter_ports[MATRIX].value
                param_type_string = "MappingProjection's ParameterPort"
            except KeyError:
                # Identify the projection by name; it presumably has no ``shape`` attribute,
                # so using ``.shape`` here would mask this error with an AttributeError.
                raise FunctionError(
                    "The MappingProjection specified for the {} arg of {} ({}) must have a {} "
                    "ParameterPort that has been assigned a 2d array or matrix"
                    .format(MATRIX, self.name, matrix_spec.name, MATRIX))

        elif isinstance(matrix_spec, ParameterPort):
            try:
                matrix = matrix_spec.value
                param_type_string = "ParameterPort"
            except KeyError:
                raise FunctionError(
                    "The value of the {} parameterPort specified for the {} arg of {} ({}) "
                    "must be a 2d array or matrix".format(
                        MATRIX, MATRIX, self.name, matrix_spec.name))

        else:
            matrix = matrix_spec
            param_type_string = "array or matrix"

        # Normalize every form to an ndarray so the ndim/shape checks below apply uniformly.
        matrix = np.array(matrix)

        if matrix.ndim != 2:
            raise FunctionError(
                "The value of the {} specified for the {} arg of {} ({}) "
                "must be a 2d array or matrix".format(
                    param_type_string, MATRIX, self.name, matrix))

        rows = matrix.shape[0]
        cols = matrix.shape[1]

        # this mirrors the transformation in _function
        # it is a hack, and a general solution should be found
        squeezed = np.array(self.defaults.variable)
        if squeezed.ndim > 1:
            squeezed = np.squeeze(squeezed)

        size = safe_len(squeezed)

        if rows != size:
            raise FunctionError(
                "The value of the {} specified for the {} arg of {} is the wrong size; "
                "it is {}x{}, but must be square matrix of size {}".format(
                    param_type_string, MATRIX, self.name, rows, cols, size))

        if rows != cols:
            raise FunctionError(
                "The value of the {} specified for the {} arg of {} ({}) "
                "must be a square matrix".format(param_type_string, MATRIX, self.name, matrix))

    super()._validate_params(request_set=request_set, target_set=target_set, context=context)