class PredictionVector():
    """Maintain a `vector <PredictionVector.vector>` of terms for a regression model specified by a list of
    `specified_terms <PredictionVector.specified_terms>`.

    Terms are maintained in lists indexed by the `PV` Enum and, in "flattened" form within fields of a 1d
    array in `vector <PredictionVector.vector>` indexed by slices listed in the `idx <PredictionVector.idx>`
    attribute.

    Arguments
    ---------

    feature_values : 2d nparray
        arrays of features to assign as the `PV.F` term of `terms <PredictionVector.terms>`.

    control_signals : List[ControlSignal]
        list containing the `ControlSignals <ControlSignal>` of an `OptimizationControlMechanism`;
        the `variable <Projection_Base.variable>` of each is assigned as the `PV.C` term of
        `terms <PredictionVector.terms>`.

    specified_terms : List[PV]
        terms to include in `vector <PredictionVector.vector>`; entries must be members of the `PV` Enum.

    Attributes
    ----------

    specified_terms : List[PV]
        terms included as predictors, specified using members of the `PV` Enum.

    terms : List[ndarray]
        current value of ndarray terms, some of which are used to compute other terms. Only entries for
        terms in `specified_terms <specified_terms>` are assigned values; others are assigned `None`.

    num : List[int]
        number of arrays in outer dimension (axis 0) of each ndarray in `terms <PredictionVector.terms>`.
        Only entries for terms in `specified_terms <PredictionVector.specified_terms>` are assigned values;
        others are assigned `None`.

    num_elems : List[int]
        number of elements in flattened array for each ndarray in `terms <PredictionVector.terms>`.
        Only entries for terms in `specified_terms <PredictionVector.specified_terms>` are assigned values;
        others are assigned `None`.

    labels : List[str]
        label of each item in `terms <PredictionVector.terms>`. Only entries for terms in
        `specified_terms <PredictionVector.specified_terms>` are assigned values; others are assigned `None`.

    vector : ndarray
        contains the flattened array for all ndarrays in `terms <PredictionVector.terms>`. Contains only the
        terms specified in `specified_terms <PredictionVector.specified_terms>`. Indices for the fields
        corresponding to each term are listed in `idx <PredictionVector.idx>`.

    idx : List[slice]
        indices of `vector <PredictionVector.vector>` for the flattened version of each nd term in
        `terms <PredictionVector.terms>`. Only entries for terms in
        `specified_terms <PredictionVector.specified_terms>` are assigned values; others are assigned `None`.
    """

    _deepcopy_shared_keys = ['control_signal_functions', '_compute_costs']

    def __init__(self, feature_values, control_signals, specified_terms):

        # Get variable for control_signals specified in constructor
        control_allocation = []
        for ctl_sig in control_signals:
            if isinstance(ctl_sig, ControlSignal):
                try:
                    v = ctl_sig.variable
                except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
                    v = ctl_sig.defaults.variable
            elif isinstance(ctl_sig, type):
                if issubclass(ctl_sig, ControlSignal):
                    v = ctl_sig.class_defaults.variable
                else:
                    # If a class other than ControlSignal was specified, typecheck should have found it
                    assert False, "PROGRAM ERROR: unrecognized specification for {} arg of {}: {}".\
                        format(repr(CONTROL_SIGNALS), self.name, ctl_sig)
            else:
                port_spec_dict = _parse_port_spec(port_type=ControlSignal, owner=self, port_spec=ctl_sig)
                v = port_spec_dict[VARIABLE]
                v = v or ControlSignal.defaults.variable
            control_allocation.append(v)

        # Get primary function and compute_costs function for each ControlSignal (called in compute_terms)
        self.control_signal_functions = [c.function for c in control_signals]
        self._compute_costs = [c.compute_costs for c in control_signals]

        def get_intrxn_labels(x):
            # Labels for all interaction terms: every subset of x of size >= 2
            return list([s for s in powerset(x) if len(s) > 1])

        def error_for_too_few_terms(term):
            spec_type = {'FF': 'feature_values', 'CC': 'control_signals'}
            # BUG FIX: spec_type is a dict; original called it (spec_type(term)) -> TypeError
            raise RegressionCFAError("Specification of {} for {} arg of {} "
                                     "requires at least two {} be specified".
                                     format('PV.' + term, repr(PREDICTION_TERMS), self.name, spec_type[term]))

        F = PV.F.value
        C = PV.C.value
        FF = PV.FF.value
        CC = PV.CC.value
        FC = PV.FC.value
        FFC = PV.FFC.value
        FCC = PV.FCC.value
        FFCC = PV.FFCC.value
        COST = PV.COST.value

        # RENAME THIS AS SPECIFIED_TERMS
        self.specified_terms = specified_terms
        self.terms = [None] * len(PV)
        self.idx = [None] * len(PV)
        self.num = [None] * len(PV)
        self.num_elems = [None] * len(PV)
        self.labels = [None] * len(PV)

        # MAIN EFFECT TERMS (unflattened)

        # Feature_values
        self.terms[F] = f = feature_values
        self.num[F] = len(f)  # feature_values are arrays
        self.num_elems[F] = len(f.reshape(-1))  # num of total elements assigned to vector
        self.labels[F] = ['f' + str(i) for i in range(0, len(f))]

        # Placemarker until control_signals are instantiated
        self.terms[C] = c = np.array([[0]] * len(control_allocation))
        self.num[C] = len(c)
        self.num_elems[C] = len(c.reshape(-1))
        self.labels[C] = ['c' + str(i) for i in range(0, len(control_allocation))]

        # Costs
        # Placemarker until control_signals are instantiated
        self.terms[COST] = cst = np.array([[0]] * len(control_allocation))
        self.num[COST] = self.num[C]
        self.num_elems[COST] = len(cst.reshape(-1))
        self.labels[COST] = ['cst' + str(i) for i in range(0, self.num[COST])]

        # INTERACTION TERMS (unflattened)

        # Interactions among feature vectors
        if any(term in specified_terms for term in [PV.FF, PV.FFC, PV.FFCC]):
            if len(f) < 2:
                # BUG FIX: error_for_too_few_terms is a local closure, not a method;
                # original called self.error_for_too_few_terms -> AttributeError
                error_for_too_few_terms('FF')
            self.terms[FF] = ff = np.array(tensor_power(f, levels=range(2, len(f) + 1)))
            self.num[FF] = len(ff)
            self.num_elems[FF] = len(ff.reshape(-1))
            self.labels[FF] = get_intrxn_labels(self.labels[F])

        # Interactions among values of control_signals
        if any(term in specified_terms for term in [PV.CC, PV.FCC, PV.FFCC]):
            if len(c) < 2:
                error_for_too_few_terms('CC')
            self.terms[CC] = cc = np.array(tensor_power(c, levels=range(2, len(c) + 1)))
            self.num[CC] = len(cc)
            self.num_elems[CC] = len(cc.reshape(-1))
            self.labels[CC] = get_intrxn_labels(self.labels[C])

        # feature-control interactions
        if any(term in specified_terms for term in [PV.FC, PV.FCC, PV.FFCC]):
            self.terms[FC] = fc = np.tensordot(f, c, axes=0)
            # NOTE(review): unlike F/C/FF/CC above, num here counts flattened elements (matches
            # upstream behavior); presumably intentional for interaction terms — confirm
            self.num[FC] = len(fc.reshape(-1))
            self.num_elems[FC] = len(fc.reshape(-1))
            self.labels[FC] = list(product(self.labels[F], self.labels[C]))

        # feature-feature-control interactions
        if any(term in specified_terms for term in [PV.FFC, PV.FFCC]):
            if len(f) < 2:
                error_for_too_few_terms('FF')
            # ff is guaranteed bound: PV.FFC/PV.FFCC also trigger the FF branch above
            self.terms[FFC] = ffc = np.tensordot(ff, c, axes=0)
            self.num[FFC] = len(ffc.reshape(-1))
            self.num_elems[FFC] = len(ffc.reshape(-1))
            self.labels[FFC] = list(product(self.labels[FF], self.labels[C]))

        # feature-control-control interactions
        if any(term in specified_terms for term in [PV.FCC, PV.FFCC]):
            if len(c) < 2:
                error_for_too_few_terms('CC')
            # cc is guaranteed bound: PV.FCC/PV.FFCC also trigger the CC branch above
            self.terms[FCC] = fcc = np.tensordot(f, cc, axes=0)
            self.num[FCC] = len(fcc.reshape(-1))
            self.num_elems[FCC] = len(fcc.reshape(-1))
            self.labels[FCC] = list(product(self.labels[F], self.labels[CC]))

        # feature-feature-control-control interactions
        if PV.FFCC in specified_terms:
            if len(f) < 2:
                error_for_too_few_terms('FF')
            if len(c) < 2:
                error_for_too_few_terms('CC')
            self.terms[FFCC] = ffcc = np.tensordot(ff, cc, axes=0)
            self.num[FFCC] = len(ffcc.reshape(-1))
            self.num_elems[FFCC] = len(ffcc.reshape(-1))
            self.labels[FFCC] = list(product(self.labels[FF], self.labels[CC]))

        # Construct "flattened" vector based on specified terms, and assign indices (as slices)
        i = 0
        # (renamed inner comprehension variable; original shadowed the loop variable t)
        specified_values = [term.value for term in specified_terms]
        for t in range(len(PV)):
            if t in specified_values:
                self.idx[t] = slice(i, i + self.num_elems[t])
                i += self.num_elems[t]
        self.vector = np.zeros(i)

    def __call__(self, terms: tc.any(PV, list)) -> tc.any(PV, tuple):
        """Return subvector(s) for specified term(s)"""
        if not isinstance(terms, list):
            return self.idx[terms.value]
        else:
            return tuple([self.idx[pv_member.value] for pv_member in terms])

    __deepcopy__ = get_deepcopy_with_shared(shared_keys=_deepcopy_shared_keys)

    # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
    def update_vector(self, variable, feature_values=None, context=None):
        """Update vector with flattened versions of values returned from the `compute_terms
        <PredictionVector.compute_terms>` method of the `prediction_vector
        <RegressorCFA.prediction_vector>`.

        Updates `vector <PredictionVector.vector>` with current values of variable and, optionally,
        feature_values.
        """
        # # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
        # if reference_variable is not None:
        #     self.reference_variable = reference_variable

        if feature_values is not None:
            self.terms[PV.F.value] = np.array(feature_values)
        # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
        computed_terms = self.compute_terms(np.array(variable), context=context)

        # Assign flattened versions of specified terms to vector
        for k, v in computed_terms.items():
            if k in self.specified_terms:
                self.vector[self.idx[k.value]] = v.reshape(-1)

    def compute_terms(self, control_allocation, context=None):
        """Calculate interaction terms.

        Results are returned in a dict; entries are keyed using names of terms listed in the `PV` Enum.
        Values of entries are nd arrays.
        """
        # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
        # ref_variables = ref_variables or self.reference_variable
        # self.reference_variable = ref_variables

        terms = self.specified_terms
        computed_terms = {}

        # No need to calculate features, so just get values
        computed_terms[PV.F] = f = self.terms[PV.F.value]

        # Compute value of each control_signal from its variable
        c = [None] * len(control_allocation)
        for i, var in enumerate(control_allocation):
            c[i] = self.control_signal_functions[i](var, context=context)
        computed_terms[PV.C] = c = np.array(c)

        # Compute costs for new control_signal values
        if PV.COST in terms:
            # computed_terms[PV.COST] = -(np.exp(0.25*c-3))
            # computed_terms[PV.COST] = -(np.exp(0.25*c-3) + (np.exp(0.25*np.abs(c-self.control_signal_change)-3)))
            costs = [None] * len(c)
            for i, val in enumerate(c):
                # MODIFIED 11/9/18 OLD:
                costs[i] = -(self._compute_costs[i](val, context=context))
                # # MODIFIED 11/9/18 NEW: [JDC]
                # costs[i] = -(self._compute_costs[i](val, ref_variables[i]))
                # MODIFIED 11/9/18 END
            computed_terms[PV.COST] = np.array(costs)

        # Compute interaction terms that are used
        if any(term in terms for term in [PV.FF, PV.FFC, PV.FFCC]):
            computed_terms[PV.FF] = ff = np.array(tensor_power(f, range(2, self.num[PV.F.value] + 1)))
        if any(term in terms for term in [PV.CC, PV.FCC, PV.FFCC]):
            computed_terms[PV.CC] = cc = np.array(tensor_power(c, range(2, self.num[PV.C.value] + 1)))
        if any(term in terms for term in [PV.FC, PV.FCC, PV.FFCC]):
            computed_terms[PV.FC] = np.tensordot(f, c, axes=0)
        if any(term in terms for term in [PV.FFC, PV.FFCC]):
            computed_terms[PV.FFC] = np.tensordot(ff, c, axes=0)
        if any(term in terms for term in [PV.FCC, PV.FFCC]):
            computed_terms[PV.FCC] = np.tensordot(f, cc, axes=0)
        if PV.FFCC in terms:
            computed_terms[PV.FFCC] = np.tensordot(ff, cc, axes=0)

        return computed_terms
class Context():
    """Used to indicate the state of initialization and phase of execution of a Component, as well as the
    source of call of a method; also used to specify and identify `conditions <Log_Conditions>` for
    `logging <Log>`.

    Attributes
    ----------

    owner : Component
        Component to which the Context belongs.

    flags : binary vector
        represents the current operating context of the `owner <Context.owner>`; contains three fields
        `initialization_status <Context.initialization_status>`, `execution_phase
        <Context.execution_phase>`, and `source <Context.source>` (described below).

    flags_string : str
        contains the names of the flags currently set in each of the fields of the `flags <Context.flags>`
        attribute; note that this is *not* the same as the `string <Context.string>` attribute
        (see `note <Context_String_Note>`).

    initialization_status : field of flags attribute
        indicates the state of initialization of the Component; one and only one of the following flags
        is always set:

        * `DEFERRED_INIT <ContextFlags.DEFERRED_INIT>`
        * `INITIALIZING <ContextFlags.INITIALIZING>`
        * `VALIDATING <ContextFlags.VALIDATING>`
        * `INITIALIZED <ContextFlags.INITIALIZED>`
        * `REINITIALIZED <ContextFlags.REINITIALIZED>`

    execution_phase : field of flags attribute
        indicates the phase of execution of the Component; one or more of the following flags can be set:

        * `PROCESSING <ContextFlags.PROCESSING>`
        * `LEARNING <ContextFlags.LEARNING>`
        * `CONTROL <ContextFlags.CONTROL>`
        * `SIMULATION <ContextFlags.SIMULATION>`

        If no flags are set, the Component is not being executed at the current time, and `flags_string
        <Context.flags_string>` will include *IDLE* in the string. In some circumstances all of the
        `execution_phase <Context.execution_phase>` flags may be set, in which case `flags_string
        <Context.flags_string>` will include *EXECUTING* in the string.

    source : field of the flags attribute
        indicates the source of a call to a method belonging to or referencing the Component;
        one of the following flags is always set:

        * `CONSTRUCTOR <ContextFlags.CONSTRUCTOR>`
        * `COMMAND_LINE <ContextFlags.COMMAND_LINE>`
        * `COMPONENT <ContextFlags.COMPONENT>`
        * `COMPOSITION <ContextFlags.COMPOSITION>`

    composition : Composition
        the `Composition <Composition>` in which the `owner <Context.owner>` is currently being executed.

    execution_id : UUID
        the execution_id assigned to the Component by the Composition in which it is currently being
        executed.

    execution_time : TimeScale
        current time of the `Scheduler` running the Composition within which the Component is currently
        being executed.

    string : str
        contains message(s) relevant to a method of the Component currently invoked or that is referencing
        the Component. In general, this contains a copy of the **context** argument passed to method of the
        Component or one that references it, but it is possible that future uses will involve other
        messages. Note that this is *not* the same as the `flags_string <Context.flags_string>` attribute
        (see `note <Context_String_Note>`).
    """

    __name__ = 'Context'
    _deepcopy_shared_keys = {'owner', 'composition', '_composition'}

    def __init__(self,
                 owner=None,
                 composition=None,
                 flags=None,
                 initialization_status=ContextFlags.UNINITIALIZED,
                 execution_phase=None,
                 # source=ContextFlags.COMPONENT,
                 source=ContextFlags.NONE,
                 execution_id: UUID = None,
                 string: str = '',
                 time=None):
        self.owner = owner
        self.composition = composition
        self.initialization_status = initialization_status
        self.execution_phase = execution_phase
        self.source = source
        if flags:
            # If flags is given explicitly, it must agree with each individually-specified field
            if (initialization_status != ContextFlags.UNINITIALIZED
                    and not (flags & ContextFlags.INITIALIZATION_MASK & initialization_status)):
                raise ContextError("Conflict in assignment to flags ({}) and status ({}) arguments of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.INITIALIZATION_MASK),
                                          ContextFlags._get_context_string(flags, INITIALIZATION_STATUS),
                                          self.owner.name))
            if (execution_phase
                    and not (flags & ContextFlags.EXECUTION_PHASE_MASK & execution_phase)):
                raise ContextError("Conflict in assignment to flags ({}) and execution_phase ({}) arguments "
                                   "of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.EXECUTION_PHASE_MASK),
                                          ContextFlags._get_context_string(flags, EXECUTION_PHASE),
                                          self.owner.name))
            if (source != ContextFlags.COMPONENT) and not (flags & ContextFlags.SOURCE_MASK & source):
                raise ContextError("Conflict in assignment to flags ({}) and source ({}) arguments of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.SOURCE_MASK),
                                          ContextFlags._get_context_string(flags, SOURCE),
                                          self.owner.name))
        self.execution_id = execution_id
        self.execution_time = None
        self.string = string

    __deepcopy__ = get_deepcopy_with_shared(_deepcopy_shared_keys)

    @property
    def composition(self):
        # Lazily default to None on first access
        try:
            return self._composition
        except AttributeError:
            self._composition = None
            # BUG FIX: original fell through here and returned None only implicitly
            return self._composition

    @composition.setter
    def composition(self, composition):
        # from psyneulink.core.compositions.composition import Composition
        # if isinstance(composition, Composition):
        if (composition is None
                or composition.__class__.__name__ in {'Composition', 'SystemComposition', 'PathwayComposition',
                                                      'AutodiffComposition', 'System', 'Process'}):
            self._composition = composition
        else:
            raise ContextError("Assignment to context.composition for {} ({}) "
                               "must be a Composition (or \'None\').".format(self.owner.name, composition))

    @property
    def flags(self):
        try:
            return self._flags
        except AttributeError:  # narrowed from bare except
            self._flags = ContextFlags.UNINITIALIZED | ContextFlags.COMPONENT
            return self._flags

    @flags.setter
    def flags(self, flags):
        if isinstance(flags, (ContextFlags, int)):
            self._flags = flags
        else:
            raise ContextError("\'{}\'{} argument in call to {} must be a {} or an int".
                               format(FLAGS, flags, self.__name__, ContextFlags.__name__))

    @property
    def flags_string(self):
        return ContextFlags._get_context_string(self.flags)

    @property
    def initialization_status(self):
        return self.flags & ContextFlags.INITIALIZATION_MASK

    @initialization_status.setter
    def initialization_status(self, flag):
        """Check that a flag is one and only one status flag"""
        flag &= ContextFlags.INITIALIZATION_MASK
        if flag in INITIALIZATION_STATUS_FLAGS:
            # Clear the initialization field, then set the new flag
            self.flags &= ContextFlags.UNINITIALIZED
            self.flags |= flag
        elif not flag or flag is ContextFlags.UNINITIALIZED:
            self.flags &= ContextFlags.UNINITIALIZED
        elif not (flag & ContextFlags.INITIALIZATION_MASK):
            raise ContextError("Attempt to assign a flag ({}) to {}.context.flags "
                               "that is not an initialization status flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.initialization_status".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def execution_phase(self):
        v = self.flags & ContextFlags.EXECUTION_PHASE_MASK
        if v == 0:
            return ContextFlags.IDLE
        else:
            return v

    @execution_phase.setter
    def execution_phase(self, flag):
        """Check that a flag is one and only one execution_phase flag"""
        if flag in EXECUTION_PHASE_FLAGS:
            # self.flags |= flag
            # Clear the execution_phase field, then set the new flag
            self.flags &= ContextFlags.IDLE
            self.flags |= flag
        elif not flag or flag is ContextFlags.IDLE:
            self.flags &= ContextFlags.IDLE
        elif flag is ContextFlags.EXECUTING:
            self.flags |= flag
        elif not (flag & ContextFlags.EXECUTION_PHASE_MASK):
            raise ContextError("Attempt to assign a flag ({}) to {}.context.execution_phase "
                               "that is not an execution phase flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.execution_phase".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def source(self):
        return self.flags & ContextFlags.SOURCE_MASK

    @source.setter
    def source(self, flag):
        """Check that a flag is one and only one source flag"""
        if flag in SOURCE_FLAGS:
            self.flags &= ContextFlags.NONE
            self.flags |= flag
        elif not flag or flag is ContextFlags.NONE:
            self.flags &= ContextFlags.NONE
        elif not flag & ContextFlags.SOURCE_MASK:
            raise ContextError("Attempt to assign a flag ({}) to {}.context.source that is not a source flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.source".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def execution_time(self):
        try:
            return self._execution_time
        except AttributeError:  # narrowed from bare except
            return None

    @execution_time.setter
    def execution_time(self, time):
        self._execution_time = time

    def update_execution_time(self):
        # BUG FIX: original referenced nonexistent attributes self.execution and self.context;
        # both should consult this Context's own flags
        if self.flags & ContextFlags.EXECUTING:
            self.execution_time = _get_time(self.owner, self.flags)
        else:
            raise ContextError("PROGRAM ERROR: attempt to call update_execution_time for {} "
                               "when 'EXECUTING' was not in its context".format(self.owner.name))

    def add_to_string(self, string):
        if self.string is None:
            self.string = string
        else:
            self.string = '{0} {1} {2}'.format(self.string, SEPARATOR_BAR, string)
class Context():
    """Used to indicate the state of initialization and phase of execution of a Component, as well as the
    source of call of a method; also used to specify and identify `conditions <Log_Conditions>` for
    `logging <Log>`.

    Attributes
    ----------

    owner : Component
        Component to which the Context belongs.

    flags : binary vector
        represents the current operating context of the `owner <Context.owner>`; contains two fields
        `execution_phase <Context.execution_phase>`, and `source <Context.source>` (described below).

    flags_string : str
        contains the names of the flags currently set in each of the fields of the `flags <Context.flags>`
        attribute; note that this is *not* the same as the `string <Context.string>` attribute
        (see `note <Context_String_Note>`).

    execution_phase : field of flags attribute
        indicates the phase of execution of the Component; one or more of the following flags can be set:

        * `PREPARING <ContextFlags.PREPARING>`
        * `PROCESSING <ContextFlags.PROCESSING>`
        * `LEARNING <ContextFlags.LEARNING>`
        * `CONTROL <ContextFlags.CONTROL>`
        * `IDLE <ContextFlags.IDLE>`

        If `IDLE` is set, the Component is not being executed at the current time, and `flags_string
        <Context.flags_string>` will include *IDLE* in the string. In some circumstances all of the
        `execution_phase <Context.execution_phase>` flags may be set (other than *IDLE* and *PREPARING*),
        in which case `flags_string <Context.flags_string>` will include *EXECUTING* in the string.

    source : field of the flags attribute
        indicates the source of a call to a method belonging to or referencing the Component;
        one of the following flags is always set:

        * `CONSTRUCTOR <ContextFlags.CONSTRUCTOR>`
        * `COMMAND_LINE <ContextFlags.COMMAND_LINE>`
        * `COMPONENT <ContextFlags.COMPONENT>`
        * `COMPOSITION <ContextFlags.COMPOSITION>`

    composition : Composition
        the `Composition <Composition>` in which the `owner <Context.owner>` is currently being executed.

    execution_id : str
        the execution_id assigned to the Component by the Composition in which it is currently being
        executed.

    execution_time : TimeScale
        current time of the `Scheduler` running the Composition within which the Component is currently
        being executed.

    string : str
        contains message(s) relevant to a method of the Component currently invoked or that is referencing
        the Component. In general, this contains a copy of the **context** argument passed to method of the
        Component or one that references it, but it is possible that future uses will involve other
        messages. Note that this is *not* the same as the `flags_string <Context.flags_string>` attribute
        (see `note <Context_String_Note>`).

    rpc_pipeline : Queue
        queue to populate with messages for external environment in cases where execution was triggered
        via RPC call (e.g. through PsyNeuLinkView).
    """

    __name__ = 'Context'
    _deepcopy_shared_keys = {'owner', 'composition', '_composition'}

    def __init__(self,
                 owner=None,
                 composition=None,
                 flags=None,
                 execution_phase=ContextFlags.IDLE,
                 # source=ContextFlags.COMPONENT,
                 source=ContextFlags.NONE,
                 runmode=ContextFlags.DEFAULT_MODE,
                 execution_id=None,
                 string: str = '',
                 time=None,
                 rpc_pipeline: Queue = None):
        self.owner = owner
        self.composition = composition
        self._execution_phase = execution_phase
        self._source = source
        self._runmode = runmode
        if flags:
            # If flags is given explicitly, it must agree with each individually-specified field
            if (execution_phase
                    and not (flags & ContextFlags.EXECUTION_PHASE_MASK & execution_phase)):
                raise ContextError(
                    "Conflict in assignment to flags ({}) and execution_phase ({}) arguments "
                    "of Context for {}".format(
                        ContextFlags._get_context_string(flags & ContextFlags.EXECUTION_PHASE_MASK),
                        ContextFlags._get_context_string(flags, EXECUTION_PHASE),
                        self.owner.name))
            if (source != ContextFlags.COMPONENT) and not (flags & ContextFlags.SOURCE_MASK & source):
                raise ContextError(
                    "Conflict in assignment to flags ({}) and source ({}) arguments of Context for {}".format(
                        ContextFlags._get_context_string(flags & ContextFlags.SOURCE_MASK),
                        ContextFlags._get_context_string(flags, SOURCE),
                        self.owner.name))
        self.execution_id = execution_id
        self.execution_time = None
        self.string = string
        self.rpc_pipeline = rpc_pipeline

    __deepcopy__ = get_deepcopy_with_shared(_deepcopy_shared_keys)

    @property
    def composition(self):
        # Lazily default to None on first access
        try:
            return self._composition
        except AttributeError:
            self._composition = None
            # BUG FIX: original fell through here and returned None only implicitly
            return self._composition

    @composition.setter
    def composition(self, composition):
        # from psyneulink.core.compositions.composition import Composition
        # if isinstance(composition, Composition):
        if (composition is None
                or composition.__class__.__name__ in {'Composition', 'AutodiffComposition'}):
            self._composition = composition
        else:
            # BUG FIX: message contained {self.owner.name}/{composition} placeholders but was
            # not an f-string, so it rendered the braces literally
            raise ContextError(
                f"Assignment to context.composition for {self.owner.name} ({composition}) "
                f"must be a Composition (or \'None\').")

    @property
    def flags(self):
        # flags is derived from the two independently stored fields
        return self.execution_phase | self.source

    @flags.setter
    def flags(self, flags: ContextFlags):
        if isinstance(flags, (ContextFlags, int)):
            self.execution_phase = flags & ContextFlags.EXECUTION_PHASE_MASK
            self.source = flags & ContextFlags.SOURCE_MASK
        else:
            raise ContextError(
                "\'{}\'{} argument in call to {} must be a {} or an int".
                format(FLAGS, flags, self.__name__, ContextFlags.__name__))

    @property
    def flags_string(self):
        return ContextFlags._get_context_string(self.flags)

    @property
    def execution_phase(self):
        return self._execution_phase

    @execution_phase.setter
    def execution_phase(self, flag):
        """Check that flag is a valid execution_phase flag assignment"""
        if not flag:
            self._execution_phase = ContextFlags.IDLE
        elif flag not in EXECUTION_PHASE_FLAGS:
            raise ContextError(
                f"Attempt to assign more than one non-SIMULATION flag ({str(flag)}) to execution_phase"
            )
        elif (flag & ~ContextFlags.EXECUTION_PHASE_MASK):
            raise ContextError(
                "Attempt to assign a flag ({}) to execution_phase "
                "that is not an execution phase flag".format(str(flag)))
        else:
            self._execution_phase = flag

    @property
    def source(self):
        return self._source

    @source.setter
    def source(self, flag):
        """Check that a flag is one and only one source flag"""
        if flag in SOURCE_FLAGS:
            self._source = flag
        elif not flag:
            self._source = ContextFlags.NONE
        elif not flag & ContextFlags.SOURCE_MASK:
            raise ContextError(
                "Attempt to assign a flag ({}) to source that is not a source flag".format(str(flag)))
        else:
            raise ContextError(
                "Attempt to assign more than one flag ({}) to source".format(str(flag)))

    @property
    def runmode(self):
        return self._runmode

    @runmode.setter
    def runmode(self, flag):
        """Check that a flag is one and only one run mode flag"""
        # A run mode flag may additionally carry the SIMULATION_MODE bit
        if (flag in RUN_MODE_FLAGS
                or (flag & ~ContextFlags.SIMULATION_MODE) in RUN_MODE_FLAGS):
            self._runmode = flag
        elif not flag:
            self._runmode = ContextFlags.DEFAULT_MODE
        elif not flag & ContextFlags.RUN_MODE_MASK:
            raise ContextError(
                "Attempt to assign a flag ({}) to run mode that is not a run mode flag".format(str(flag)))
        else:
            raise ContextError(
                "Attempt to assign more than one non-SIMULATION flag ({}) to run mode".format(str(flag)))

    @property
    def execution_time(self):
        try:
            return self._execution_time
        except AttributeError:
            return None

    @execution_time.setter
    def execution_time(self, time):
        self._execution_time = time

    def update_execution_time(self):
        # BUG FIX: original referenced nonexistent attributes self.execution and
        # self.most_recent_context (the latter lives on Component, not Context);
        # both should consult this Context's own flags
        if self.flags & ContextFlags.EXECUTING:
            self.execution_time = _get_time(self.owner, self.flags)
        else:
            raise ContextError(
                "PROGRAM ERROR: attempt to call update_execution_time for {} "
                "when 'EXECUTING' was not in its context".format(self.owner.name))

    def add_to_string(self, string):
        if self.string is None:
            self.string = string
        else:
            self.string = '{0} {1} {2}'.format(self.string, SEPARATOR_BAR, string)

    def _change_flags(self, *flags,
                      operation=lambda attr, blank_flag, *flags: NotImplemented):
        # split by flag type to avoid extra costly binary operations on enum flags
        if all([flag in EXECUTION_PHASE_FLAGS for flag in flags]):
            self.execution_phase = operation(self.execution_phase, ContextFlags.IDLE, *flags)
        elif all([flag in SOURCE_FLAGS for flag in flags]):
            self.source = operation(self.source, ContextFlags.NONE, *flags)
        elif all([flag in RUN_MODE_FLAGS for flag in flags]):
            self.runmode = operation(self.runmode, ContextFlags.DEFAULT_MODE, *flags)
        else:
            raise ContextError(
                f'Flags must all correspond to one of: execution_phase, source, run mode'
            )

    def add_flag(self, flag: ContextFlags):
        def add(attr, blank_flag, flag):
            return (attr & ~blank_flag) | flag

        self._change_flags(flag, operation=add)

    def remove_flag(self, flag: ContextFlags):
        def remove(attr, blank_flag, flag):
            if attr & flag:
                res = (attr | flag) ^ flag
                # If removal cleared the field entirely, restore the field's blank value
                if res is ContextFlags.UNSET:
                    res = blank_flag
                return res
            else:
                return attr

        self._change_flags(flag, operation=remove)

    def replace_flag(self, old: ContextFlags, new: ContextFlags):
        def replace(attr, blank_flag, old, new):
            return (attr & ~old) | new

        self._change_flags(old, new, operation=replace)
class PytorchModelCreator(torch.nn.Module):
    """Pytorch representation of a PsyNeuLink Composition.

    Wraps each (non-learning) node of the Composition in a
    PytorchMechanismWrapper and each projection between wrapped nodes in a
    PytorchProjectionWrapper, ordered into scheduler-derived execution sets,
    and provides both an eager pytorch `forward` and LLVM code generation for
    compiled forward/backprop execution.
    """

    # sets up parameters of model & the information required for forward computation
    def __init__(self, composition, device, context=None):
        """Build wrappers, projections and execution sets for *composition*.

        Raises an Exception if the torch module is unavailable.
        """
        if not torch_available:
            raise Exception('Pytorch python module (torch) is not installed. Please install it with '
                            '`pip install torch` or `pip3 install torch`')

        super(PytorchModelCreator, self).__init__()

        # Maps Mechanism -> PytorchMechanismWrapper
        self.nodes = []
        self.component_map = {}

        # Maps Projections -> PytorchProjectionWrappers
        self.projections = []
        self.projection_map = {}

        self.params = nn.ParameterList()
        self.device = device
        self._composition = composition

        # Instantiate pytorch mechanisms (learning-role nodes are excluded).
        for node in set(composition.nodes) - set(composition.get_nodes_by_role(NodeRole.LEARNING)):
            pytorch_node = PytorchMechanismWrapper(node, self._composition._get_node_index(node), device, context=context)
            self.component_map[node] = pytorch_node
            self.nodes.append(pytorch_node)

        # Instantiate pytorch projections; only projections whose sender AND
        # receiver were wrapped above are represented.
        for projection in composition.projections:
            if projection.sender.owner in self.component_map and projection.receiver.owner in self.component_map:
                proj_send = self.component_map[projection.sender.owner]
                proj_recv = self.component_map[projection.receiver.owner]

                # Index of the sending OutputPort on its owner, used for routing.
                port_idx = projection.sender.owner.output_ports.index(projection.sender)
                new_proj = PytorchProjectionWrapper(projection, list(self._composition._inner_projections).index(projection), port_idx, device, sender=proj_send, receiver=proj_recv, context=context)
                proj_send.add_efferent(new_proj)
                proj_recv.add_afferent(new_proj)
                self.projection_map[projection] = new_proj
                self.projections.append(new_proj)
                # Register the weight matrix so torch optimizers see it.
                self.params.append(new_proj.matrix)

        # Run the scheduler under a throwaway Context to obtain execution sets.
        c = Context()
        try:
            composition.scheduler._init_counts(execution_id=c.execution_id, base_execution_id=context.execution_id)
        except graph_scheduler.SchedulerError:
            # called from LLVM, no base context is provided
            composition.scheduler._init_counts(execution_id=c.execution_id)

        # Setup execution sets
        # 1) Remove all learning-specific nodes
        self.execution_sets = [x - set(composition.get_nodes_by_role(NodeRole.LEARNING)) for x in composition.scheduler.run(context=c)]
        # 2) Convert to pytorchcomponent representation
        self.execution_sets = [{self.component_map[comp] for comp in s if comp in self.component_map} for s in self.execution_sets]
        # 3) Remove empty execution sets
        self.execution_sets = [x for x in self.execution_sets if len(x) > 0]

        # Discard the scheduler state created for the throwaway context.
        composition.scheduler._delete_counts(c.execution_id)

    # Deep copies share Components (they are owned by the Composition, not this model).
    __deepcopy__ = get_deepcopy_with_shared(shared_types=(Component, ComponentsMeta))

    # generates llvm function for self.forward
    def _gen_llvm_function(self, *, ctx: pnlvm.LLVMBuilderContext, tags: frozenset):
        """Generate the top-level LLVM function (forward, or training when
        'learning' is in *tags*) taking (state, params, data) pointers."""
        args = [ctx.get_state_struct_type(self._composition).as_pointer(),
                ctx.get_param_struct_type(self._composition).as_pointer(),
                ctx.get_data_struct_type(self._composition).as_pointer()
                ]
        builder = ctx.create_llvm_function(args, self)

        state, params, data = builder.function.args
        if "learning" in tags:
            self._gen_llvm_training_function_body(ctx, builder, state, params, data)
        else:
            # Model input lives in the data struct at the input_CIM's node index.
            model_input = builder.gep(data, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(self._composition._get_node_index(self._composition.input_CIM))])
            self._gen_llvm_forward_function_body(ctx, builder, state, params, model_input, data)

        builder.ret_void()
        return builder.function

    def _gen_llvm_forward_function_body(self, ctx, builder, state, params, arg_in, data):
        """Emit IR executing every node in scheduler order.

        Returns a dict mapping each node wrapper to an alloca holding its
        pre-activation input sum (z), which backprop consumes.
        """
        z_values = {}  # dict for storing values of terminal (output) nodes
        for current_exec_set in self.execution_sets:
            for component in current_exec_set:
                mech_input_ty = ctx.get_input_struct_type(component._mechanism)
                variable = builder.alloca(mech_input_ty)
                # z accumulator has the shape of a single input-port entry.
                z_values[component] = builder.alloca(mech_input_ty.elements[0].elements[0])
                # Zero-initialize the accumulator.
                builder.store(z_values[component].type.pointee(None), z_values[component])

                if NodeRole.INPUT in self._composition.get_roles_by_node(component._mechanism):
                    # INPUT nodes take their first input-port value from arg_in.
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(0)])
                    input_id = component._idx
                    mech_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(input_id)])
                    builder.store(builder.load(mech_in), input_ptr)
                for (proj_idx, proj) in enumerate(component.afferents):
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(proj_idx)])
                    proj_output = proj._gen_llvm_execute(ctx, builder, state, params, data)
                    # store in input ports struct
                    builder.store(builder.load(proj_output), input_ptr)
                    # HACK: Add to z_values struct
                    gen_inject_vec_add(ctx, builder, proj_output, z_values[component], z_values[component])
                component._gen_llvm_execute(ctx, builder, state, params, variable, data)

        return z_values

    # generates a function responsible for a single epoch of the training
    def _gen_llvm_training_backprop(self, ctx, optimizer, loss):
        """Emit the backprop LLVM function: forward pass, error propagation in
        reverse scheduler order, then accumulation of weight gradients into the
        optimizer's delta_w struct."""
        composition = self._composition
        args = [ctx.get_state_struct_type(composition).as_pointer(),
                ctx.get_param_struct_type(composition).as_pointer(),
                ctx.get_data_struct_type(composition).as_pointer(),
                optimizer._get_optimizer_struct_type(ctx).as_pointer(),
                ]
        name = self._composition.name + "_training_backprop"
        builder = ctx.create_llvm_function(args, self, name)
        llvm_func = builder.function
        for a in llvm_func.args:
            if isinstance(a.type, pnlvm.ir.PointerType):
                # Pointer args never alias; lets LLVM optimize loads/stores.
                a.attributes.add('noalias')

        state, params, data, optim_struct = llvm_func.args
        model_input = builder.gep(data, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(self._composition._get_node_index(self._composition.input_CIM))])
        model_output = data
        # setup useful mappings
        input_nodes = set(self._composition.get_nodes_by_role(NodeRole.INPUT))

        # initialize optimizer params:
        delta_w = builder.gep(optim_struct, [ctx.int32_ty(0), ctx.int32_ty(optimizer._DELTA_W_NUM)])

        # 2) call forward computation
        z_values = self._gen_llvm_forward_function_body(
            ctx, builder, state, params, model_input, data)

        # 3) compute errors
        loss_fn = ctx.import_llvm_function(loss)
        total_loss = builder.alloca(ctx.float_ty)
        builder.store(ctx.float_ty(0), total_loss)

        error_dict = {}
        # Walk execution sets in reverse so each node's efferent errors exist
        # in error_dict before the node itself is processed.
        for exec_set in reversed(self.execution_sets):
            for node in exec_set:
                if node._mechanism in input_nodes:
                    continue
                node_z_value = z_values[node]
                activation_func_derivative = node._gen_llvm_execute_derivative_func(ctx, builder, state, params, node_z_value)
                error_val = builder.alloca(z_values[node].type.pointee)
                error_dict[node] = error_val

                if NodeRole.OUTPUT in self._composition.get_roles_by_node(node._mechanism):
                    # We handle output layer here
                    # compute dC/da = a_l - y(x) (TODO: Allow other cost functions! This only applies to MSE)

                    # 1) Lookup desired target value
                    terminal_sequence = self._composition._terminal_backprop_sequences[node._mechanism]
                    target_idx = self._composition.get_nodes_by_role(
                        NodeRole.INPUT).index(terminal_sequence[TARGET_MECHANISM])
                    node_target = builder.gep(model_input, [ctx.int32_ty(0), ctx.int32_ty(target_idx)])

                    # 2) Lookup desired output value
                    node_output = builder.gep(model_output, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(node._idx), ctx.int32_ty(0)])

                    tmp_loss = loss.gen_inject_lossfunc_call(
                        ctx, builder, loss_fn, node_output, node_target)

                    pnlvm.helpers.printf_float_array(
                        builder, node_target, prefix=f"{node}\ttarget:\t")
                    pnlvm.helpers.printf_float_array(
                        builder, node_output, prefix=f"{node}\tvalue:\t")

                    pnlvm.helpers.printf(
                        builder, f"{node}\tloss:\t%f\n", tmp_loss, override_debug=False)
                    builder.store(builder.fadd(builder.load(
                        total_loss), tmp_loss), total_loss)
                    loss_derivative = loss._gen_inject_loss_differential(
                        ctx, builder, node_output, node_target)
                    # compute δ_l = dσ/da ⊙ σ'(z)
                    gen_inject_vec_hadamard(
                        ctx, builder, activation_func_derivative, loss_derivative, error_val)
                else:
                    # We propagate error backwards from next layer
                    for proj_idx, proj in enumerate(node.efferents):
                        efferent_node = proj.receiver
                        efferent_node_error = error_dict[efferent_node]
                        weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, params)
                        if proj_idx == 0:
                            # First efferent writes directly into error_val...
                            gen_inject_vxm_transposed(
                                ctx, builder, efferent_node_error, weights_llvmlite, error_val)
                        else:
                            # ...subsequent efferents are summed into it.
                            new_val = gen_inject_vxm_transposed(
                                ctx, builder, efferent_node_error, weights_llvmlite)
                            gen_inject_vec_add(
                                ctx, builder, new_val, error_val, error_val)
                    gen_inject_vec_hadamard(
                        ctx, builder, activation_func_derivative, error_val, error_val)

                pnlvm.helpers.printf_float_array(
                    builder, activation_func_derivative, prefix=f"{node}\tdSigma:\t")
                pnlvm.helpers.printf_float_array(
                    builder, error_val, prefix=f"{node}\terror:\t")

        # 4) compute weight gradients
        for (node, err_val) in error_dict.items():
            # NOTE(review): `node` here is a wrapper while input_nodes contains
            # Mechanisms (cf. `node._mechanism in input_nodes` above), so this
            # test looks always-false; harmless since input nodes never enter
            # error_dict, but confirm the intended comparison.
            if node in input_nodes:
                continue
            for proj in node.afferents:
                # get a_(l-1)
                afferent_node_activation = builder.gep(model_output, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(proj.sender._idx), ctx.int32_ty(0)])

                # get dimensions of weight matrix
                weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, params)
                pnlvm.helpers.printf_float_matrix(builder, weights_llvmlite, prefix= f"{proj.sender._mechanism} -> {proj.receiver._mechanism}\n", override_debug=False)
                # update delta_W
                node_delta_w = builder.gep(delta_w, [ctx.int32_ty(0), ctx.int32_ty(proj._idx)])

                dim_x, dim_y = proj.matrix.shape
                # delta_w[i][j] += a_(l-1)[i] * delta_l[j]  (outer product accumulation)
                with pnlvm.helpers.for_loop_zero_inc(builder, ctx.int32_ty(dim_x), "weight_update_loop_outer") as (b1, weight_row):
                    with pnlvm.helpers.for_loop_zero_inc(b1, ctx.int32_ty(dim_y), "weight_update_loop_inner") as (b2, weight_column):
                        a_val = b2.load(b2.gep(afferent_node_activation, [ctx.int32_ty(0), weight_row]))
                        d_val = b2.load(b2.gep(err_val, [ctx.int32_ty(0), weight_column]))
                        old_val = b2.load(b2.gep(node_delta_w, [ctx.int32_ty(0), weight_row, weight_column]))
                        new_val = b2.fadd(old_val, b2.fmul(a_val, d_val))
                        b2.store(new_val, b2.gep(node_delta_w, [ctx.int32_ty(0), weight_row, weight_column]))

        pnlvm.helpers.printf(builder, "TOTAL LOSS:\t%.20f\n", builder.load(total_loss), override_debug=False)
        builder.ret_void()
        return builder.function

    def _gen_llvm_training_function_body(self, ctx, builder, state, params, data):
        """Emit one training step: zero_grad, backprop, optimizer step."""
        composition = self._composition

        optimizer = self._get_compiled_optimizer()
        # setup loss
        loss_type = self._composition.loss_spec
        if loss_type == 'mse':
            loss = MSELoss()
        else:
            raise Exception("LOSS TYPE", loss_type, "NOT SUPPORTED")

        optimizer_step_f = ctx.import_llvm_function(optimizer)
        # Optimizer state is appended as the last element of the state struct.
        optimizer_struct_idx = len(state.type.pointee.elements) - 1
        optimizer_struct = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(optimizer_struct_idx)])
        optimizer_zero_grad = ctx.import_llvm_function(optimizer.zero_grad(ctx).name)
        backprop = ctx.import_llvm_function(self._gen_llvm_training_backprop(ctx, optimizer, loss).name)

        # # FIXME: converting this call to inlined code results in
        # # significant longer compilation times
        builder.call(optimizer_zero_grad, [optimizer_struct])
        builder.call(backprop, [state, params, data, optimizer_struct])
        builder.call(optimizer_step_f, [optimizer_struct, params])

    def _get_compiled_optimizer(self):
        """Return the compiled optimizer matching the Composition's optimizer_type."""
        # setup optimizer
        optimizer_type = self._composition.optimizer_type
        if optimizer_type == 'adam':
            optimizer = AdamOptimizer(self, lr=self._composition.learning_rate)
        elif optimizer_type == 'sgd':
            optimizer = SGDOptimizer(self, lr=self._composition.learning_rate)
        else:
            raise Exception("OPTIMIZER TYPE", optimizer_type, "NOT SUPPORTED")
        return optimizer

    # performs forward computation for the model
    @handle_external_context()
    def forward(self, inputs, context=None):
        """Execute the model eagerly in scheduler order.

        *inputs* maps each INPUT Mechanism to its input value.  Returns a dict
        mapping each OUTPUT Mechanism to its computed value.  Logs values and
        weights as a side effect.
        """
        outputs = {}  # dict for storing values of terminal (output) nodes
        for current_exec_set in self.execution_sets:
            for component in current_exec_set:
                if NodeRole.INPUT in self._composition.get_roles_by_node(component._mechanism):
                    component.execute(inputs[component._mechanism])
                else:
                    variable = component.collate_afferents()
                    component.execute(variable)

                # save value in output list if we're at a node in the last execution set
                if NodeRole.OUTPUT in self._composition.get_roles_by_node(component._mechanism):
                    outputs[component._mechanism] = component.value

        # NOTE: Context source needs to be set to COMMAND_LINE to force logs to update independently of timesteps
        old_source = context.source
        context.source = ContextFlags.COMMAND_LINE
        self.log_values()
        self.log_weights()
        context.source = old_source

        return outputs

    def detach_all(self):
        """Detach every projection matrix from the autograd graph."""
        for projection in self.projection_map.values():
            projection.matrix.detach()

    def copy_weights_to_psyneulink(self, context=None):
        """Copy learned torch matrices back into the PsyNeuLink projections."""
        for projection, pytorch_rep in self.projection_map.items():
            # NOTE(review): detach().cpu().numpy() is computed twice per
            # projection; could be hoisted to a local, behavior unchanged.
            projection.parameters.matrix._set(
                pytorch_rep.matrix.detach().cpu().numpy(), context)
            projection.parameter_ports['matrix'].parameters.value._set(
                pytorch_rep.matrix.detach().cpu().numpy(), context)

    def log_weights(self):
        """Log the matrix of every projection wrapper."""
        for proj in self.projections:
            proj.log_matrix()

    def log_values(self):
        """Log the value of every mechanism wrapper."""
        for node in self.nodes:
            node.log_value()