class PredictionVector():
        """Maintain a `vector <PredictionVector.vector>` of terms for a regression model specified by a list of
        `specified_terms <PredictionVector.specified_terms>`.

        Terms are maintained in lists indexed by the `PV` Enum and, in "flattened" form, within fields of a 1d
        array in `vector <PredictionVector.vector>` indexed by the slices listed in the `idx <PredictionVector.idx>`
        attribute.

        Arguments
        ---------

        feature_values : 2d ndarray
            arrays of features to assign as the `PV.F` term of `terms <PredictionVector.terms>`.

        control_signals : List[ControlSignal]
            list containing the `ControlSignals <ControlSignal>` of an `OptimizationControlMechanism`;  the `variable
            <Projection_Base.variable>` of each is assigned as the `PV.C` term of `terms <PredictionVector.terms>`.

        specified_terms : List[PV]
            terms to include in `vector <PredictionVector.vector>`; entries must be members of the `PV` Enum.

        Attributes
        ----------

        specified_terms : List[PV]
            terms included as predictors, specified using members of the `PV` Enum.

        terms : List[ndarray]
            current value of ndarray terms, some of which are used to compute other terms. Only entries for terms in
            `specified_terms <specified_terms>` are assigned values; others are assigned `None`.

        num : List[int]
            number of arrays in outer dimension (axis 0) of each ndarray in `terms <PredictionVector.terms>`.
            Only entries for terms in `specified_terms <PredictionVector.specified_terms>` are assigned values;
            others are assigned `None`.

        num_elems : List[int]
            number of elements in flattened array for each ndarray in `terms <PredictionVector.terms>`.
            Only entries for terms in `specified_terms <PredictionVector.specified_terms>` are assigned values;
            others are assigned `None`.

        labels : List[str]
            label of each item in `terms <PredictionVector.terms>`. Only entries for terms in `specified_terms
            <PredictionVector.specified_terms>` are assigned values; others are assigned `None`.

        vector : ndarray
            contains the flattened array for all ndarrays in `terms <PredictionVector.terms>`.  Contains only
            the terms specified in `specified_terms <PredictionVector.specified_terms>`.  Indices for the fields
            corresponding to each term are listed in `idx <PredictionVector.idx>`.

        idx : List[slice]
            indices of `vector <PredictionVector.vector>` for the flattened version of each nd term in
            `terms <PredictionVector.terms>`. Only entries for terms in `specified_terms
            <PredictionVector.specified_terms>` are assigned values; others are assigned `None`.

        """

        _deepcopy_shared_keys = ['control_signal_functions', '_compute_costs']

        def __init__(self, feature_values, control_signals, specified_terms):

            # Get variable for control_signals specified in constructor
            control_allocation = []
            for c in control_signals:
                if isinstance(c, ControlSignal):
                    try:
                        v = c.variable
                    except Exception:
                        v = c.defaults.variable
                elif isinstance(c, type):
                    if issubclass(c, ControlSignal):
                        v = c.class_defaults.variable
                    else:  # If a class other than ControlSignal was specified, typecheck should have found it
                        assert False, "PROGRAM ERROR: unrecognized specification for {} arg of {}: {}".\
                                                      format(repr(CONTROL_SIGNALS), self.name, c)
                else:
                    port_spec_dict = _parse_port_spec(port_type=ControlSignal,
                                                      owner=self,
                                                      port_spec=c)
                    v = port_spec_dict[VARIABLE]
                    v = v or ControlSignal.defaults.variable
                control_allocation.append(v)
            # Get primary function and compute_costs function for each ControlSignal (called in compute_terms)
            self.control_signal_functions = [
                c.function for c in control_signals
            ]
            self._compute_costs = [c.compute_costs for c in control_signals]

            def get_intrxn_labels(x):
                return [s for s in powerset(x) if len(s) > 1]

            def error_for_too_few_terms(term):
                spec_type = {'FF': 'feature_values', 'CC': 'control_signals'}
                raise RegressionCFAError(
                    "Specification of {} for {} arg of {} "
                    "requires at least two {} be specified".format(
                        'PV.' + term, repr(PREDICTION_TERMS), self.name,
                        spec_type[term]))

            F = PV.F.value
            C = PV.C.value
            FF = PV.FF.value
            CC = PV.CC.value
            FC = PV.FC.value
            FFC = PV.FFC.value
            FCC = PV.FCC.value
            FFCC = PV.FFCC.value
            COST = PV.COST.value

            self.specified_terms = specified_terms
            self.terms = [None] * len(PV)
            self.idx = [None] * len(PV)
            self.num = [None] * len(PV)
            self.num_elems = [None] * len(PV)
            self.labels = [None] * len(PV)

            # MAIN EFFECT TERMS (unflattened)

            # Feature_values
            self.terms[F] = f = feature_values
            self.num[F] = len(f)  # feature_values are arrays
            self.num_elems[F] = len(
                f.reshape(-1))  # num of total elements assigned to vector
            self.labels[F] = ['f' + str(i) for i in range(0, len(f))]

            # Placeholder until control_signals are instantiated
            self.terms[C] = c = np.array([[0]] * len(control_allocation))
            self.num[C] = len(c)
            self.num_elems[C] = len(c.reshape(-1))
            self.labels[C] = [
                'c' + str(i) for i in range(0, len(control_allocation))
            ]

            # Costs
            # Placeholder until control_signals are instantiated
            self.terms[COST] = cst = np.array([[0]] * len(control_allocation))
            self.num[COST] = self.num[C]
            self.num_elems[COST] = len(cst.reshape(-1))
            self.labels[COST] = [
                'cst' + str(i) for i in range(0, self.num[COST])
            ]

            # INTERACTION TERMS (unflattened)

            # Interactions among feature vectors
            if any(term in specified_terms
                   for term in [PV.FF, PV.FFC, PV.FFCC]):
                if len(f) < 2:
                    error_for_too_few_terms('FF')
                self.terms[FF] = ff = np.array(
                    tensor_power(f, levels=range(2,
                                                 len(f) + 1)))
                self.num[FF] = len(ff)
                self.num_elems[FF] = len(ff.reshape(-1))
                self.labels[FF] = get_intrxn_labels(self.labels[F])

            # Interactions among values of control_signals
            if any(term in specified_terms
                   for term in [PV.CC, PV.FCC, PV.FFCC]):
                if len(c) < 2:
                    error_for_too_few_terms('CC')
                self.terms[CC] = cc = np.array(
                    tensor_power(c, levels=range(2,
                                                 len(c) + 1)))
                self.num[CC] = len(cc)
                self.num_elems[CC] = len(cc.reshape(-1))
                self.labels[CC] = get_intrxn_labels(self.labels[C])

            # feature-control interactions
            if any(term in specified_terms
                   for term in [PV.FC, PV.FCC, PV.FFCC]):
                self.terms[FC] = fc = np.tensordot(f, c, axes=0)
                self.num[FC] = len(fc)  # number of arrays along axis 0, per the docstring
                self.num_elems[FC] = len(fc.reshape(-1))
                self.labels[FC] = list(product(self.labels[F], self.labels[C]))

            # feature-feature-control interactions
            if any(term in specified_terms for term in [PV.FFC, PV.FFCC]):
                if len(f) < 2:
                    error_for_too_few_terms('FF')
                self.terms[FFC] = ffc = np.tensordot(ff, c, axes=0)
                self.num[FFC] = len(ffc)
                self.num_elems[FFC] = len(ffc.reshape(-1))
                self.labels[FFC] = list(
                    product(self.labels[FF], self.labels[C]))

            # feature-control-control interactions
            if any(term in specified_terms for term in [PV.FCC, PV.FFCC]):
                if len(c) < 2:
                    error_for_too_few_terms('CC')
                self.terms[FCC] = fcc = np.tensordot(f, cc, axes=0)
                self.num[FCC] = len(fcc)
                self.num_elems[FCC] = len(fcc.reshape(-1))
                self.labels[FCC] = list(
                    product(self.labels[F], self.labels[CC]))

            # feature-feature-control-control interactions
            if PV.FFCC in specified_terms:
                if len(f) < 2:
                    error_for_too_few_terms('FF')
                if len(c) < 2:
                    error_for_too_few_terms('CC')
                self.terms[FFCC] = ffcc = np.tensordot(ff, cc, axes=0)
                self.num[FFCC] = len(ffcc)
                self.num_elems[FFCC] = len(ffcc.reshape(-1))
                self.labels[FFCC] = list(
                    product(self.labels[FF], self.labels[CC]))

            # Construct "flattened" vector based on specified terms, and assign indices (as slices)
            i = 0
            for t in range(len(PV)):
                if t in [term.value for term in specified_terms]:
                    self.idx[t] = slice(i, i + self.num_elems[t])
                    i += self.num_elems[t]

            self.vector = np.zeros(i)

        def __call__(self, terms: tc.any(PV, list)) -> tc.any(PV, tuple):
            """Return subvector(s) for specified term(s)"""
            if not isinstance(terms, list):
                return self.idx[terms.value]
            else:
                return tuple(self.idx[pv_member.value] for pv_member in terms)

        __deepcopy__ = get_deepcopy_with_shared(
            shared_keys=_deepcopy_shared_keys)

        # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
        def update_vector(self, variable, feature_values=None, context=None):
            """Update vector with flattened versions of values returned from the `compute_terms
            <PredictionVector.compute_terms>` method of the `prediction_vector
            <RegressionCFA.prediction_vector>`.

            Updates `vector <PredictionVector.vector>` with current values of variable and, optionally,
            feature_values.

            """

            # # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
            # if reference_variable is not None:
            #     self.reference_variable = reference_variable

            if feature_values is not None:
                self.terms[PV.F.value] = np.array(feature_values)
            # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
            computed_terms = self.compute_terms(np.array(variable),
                                                context=context)

            # Assign flattened versions of specified terms to vector
            for k, v in computed_terms.items():
                if k in self.specified_terms:
                    self.vector[self.idx[k.value]] = v.reshape(-1)

        def compute_terms(self, control_allocation, context=None):
            """Calculate interaction terms.

            Results are returned in a dict; entries are keyed using names of terms listed in the `PV` Enum.
            Values of entries are nd arrays.
            """

            # FIX: 11/9/19 LOCALLY MANAGE STATEFULNESS OF ControlSignals AND costs
            # ref_variables = ref_variables or self.reference_variable
            # self.reference_variable = ref_variables

            terms = self.specified_terms
            computed_terms = {}

            # No need to calculate features, so just get values
            computed_terms[PV.F] = f = self.terms[PV.F.value]

            # Compute value of each control_signal from its variable
            c = [None] * len(control_allocation)
            for i, var in enumerate(control_allocation):
                c[i] = self.control_signal_functions[i](var, context=context)
            computed_terms[PV.C] = c = np.array(c)

            # Compute costs for new control_signal values
            if PV.COST in terms:
                # computed_terms[PV.COST] = -(np.exp(0.25*c-3))
                # computed_terms[PV.COST] = -(np.exp(0.25*c-3) + (np.exp(0.25*np.abs(c-self.control_signal_change)-3)))
                costs = [None] * len(c)
                for i, val in enumerate(c):
                    # MODIFIED 11/9/18 OLD:
                    costs[i] = -(self._compute_costs[i](val, context=context))
                    # # MODIFIED 11/9/18 NEW: [JDC]
                    # costs[i] = -(self._compute_costs[i](val, ref_variables[i]))
                    # MODIFIED 11/9/18 END
                computed_terms[PV.COST] = np.array(costs)

            # Compute interaction terms that are used
            if any(term in terms for term in [PV.FF, PV.FFC, PV.FFCC]):
                computed_terms[PV.FF] = ff = np.array(
                    tensor_power(f, range(2, self.num[PV.F.value] + 1)))
            if any(term in terms for term in [PV.CC, PV.FCC, PV.FFCC]):
                computed_terms[PV.CC] = cc = np.array(
                    tensor_power(c, range(2, self.num[PV.C.value] + 1)))
            if any(term in terms for term in [PV.FC, PV.FCC, PV.FFCC]):
                computed_terms[PV.FC] = np.tensordot(f, c, axes=0)
            if any(term in terms for term in [PV.FFC, PV.FFCC]):
                computed_terms[PV.FFC] = np.tensordot(ff, c, axes=0)
            if any(term in terms for term in [PV.FCC, PV.FFCC]):
                computed_terms[PV.FCC] = np.tensordot(f, cc, axes=0)
            if PV.FFCC in terms:
                computed_terms[PV.FFCC] = np.tensordot(ff, cc, axes=0)

            return computed_terms
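A minimal, self-contained sketch of the flatten-and-slice layout that `PredictionVector` builds, with hypothetical term names and shapes (not the real `PV` entries), illustrating how `idx` maps each term onto `vector`:

    import numpy as np

    # Each term is flattened and assigned a contiguous slice of one 1d vector,
    # mirroring PredictionVector.idx and PredictionVector.vector.
    terms = {'f': np.ones((2, 3)), 'c': np.zeros((2, 1))}

    idx, i = {}, 0
    for name, arr in terms.items():
        n = arr.reshape(-1).size
        idx[name] = slice(i, i + n)
        i += n

    vector = np.zeros(i)
    for name, arr in terms.items():
        vector[idx[name]] = arr.reshape(-1)

    assert vector[idx['f']].size == 6 and vector[idx['c']].size == 2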
Example #2
class Context():
    """Used to indicate the state of initialization and phase of execution of a Component, as well as the source of
    call of a method;  also used to specify and identify `conditions <Log_Conditions>` for `logging <Log>`.


    Attributes
    ----------

    owner : Component
        Component to which the Context belongs.

    flags : binary vector
        represents the current operating context of the `owner <Context.owner>`; contains three fields
        `initialization_status <Context.initialization_status>`, `execution_phase <Context.execution_phase>`,
        and `source <Context.source>` (described below).

    flags_string : str
        contains the names of the flags currently set in each of the fields of the `flags <Context.flags>` attribute;
        note that this is *not* the same as the `string <Context.string>` attribute (see `note <Context_String_Note>`).

    initialization_status : field of flags attribute
        indicates the state of initialization of the Component;
        one and only one of the following flags is always set:

            * `DEFERRED_INIT <ContextFlags.DEFERRED_INIT>`
            * `INITIALIZING <ContextFlags.INITIALIZING>`
            * `VALIDATING <ContextFlags.VALIDATING>`
            * `INITIALIZED <ContextFlags.INITIALIZED>`
            * `REINITIALIZED <ContextFlags.REINITIALIZED>`

    execution_phase : field of flags attribute
        indicates the phase of execution of the Component;
        one or more of the following flags can be set:

            * `PROCESSING <ContextFlags.PROCESSING>`
            * `LEARNING <ContextFlags.LEARNING>`
            * `CONTROL <ContextFlags.CONTROL>`
            * `SIMULATION <ContextFlags.SIMULATION>`

        If no flags are set, the Component is not being executed at the current time, and `flags_string
        <Context.flags_string>` will include *IDLE* in the string.  In some circumstances all of the
        `execution_phase <Context.execution_phase>` flags may be set, in which case `flags_string
        <Context.flags_string>` will include *EXECUTING* in the string.

    source : field of the flags attribute
        indicates the source of a call to a method belonging to or referencing the Component;
        one of the following flags is always set:

            * `CONSTRUCTOR <ContextFlags.CONSTRUCTOR>`
            * `COMMAND_LINE <ContextFlags.COMMAND_LINE>`
            * `COMPONENT <ContextFlags.COMPONENT>`
            * `COMPOSITION <ContextFlags.COMPOSITION>`

    composition : Composition
      the `Composition <Composition>` in which the `owner <Context.owner>` is currently being executed.

    execution_id : UUID
      the execution_id assigned to the Component by the Composition in which it is currently being executed.

    execution_time : TimeScale
      current time of the `Scheduler` running the Composition within which the Component is currently being executed.

    string : str
      contains message(s) relevant to a method of the Component currently invoked or that is referencing the Component.
      In general, this contains a copy of the **context** argument passed to a method of the Component or one that
      references it, but it is possible that future uses will involve other messages.  Note that this is *not* the
      same as the `flags_string <Context.flags_string>` attribute (see `note <Context_String_Note>`).

    """

    __name__ = 'Context'
    _deepcopy_shared_keys = {'owner', 'composition', '_composition'}

    def __init__(self,
                 owner=None,
                 composition=None,
                 flags=None,
                 initialization_status=ContextFlags.UNINITIALIZED,
                 execution_phase=None,
                 # source=ContextFlags.COMPONENT,
                 source=ContextFlags.NONE,
                 execution_id:UUID=None,
                 string:str='', time=None):

        self.owner = owner
        self.composition = composition
        self.initialization_status = initialization_status
        self.execution_phase = execution_phase
        self.source = source
        if flags:
            if (initialization_status != ContextFlags.UNINITIALIZED and
                    not (flags & ContextFlags.INITIALIZATION_MASK & initialization_status)):
                raise ContextError("Conflict in assignment to flags ({}) and status ({}) arguments of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.INITIALIZATION_MASK),
                                          ContextFlags._get_context_string(flags, INITIALIZATION_STATUS),
                                          self.owner.name))
            if (execution_phase and not (flags & ContextFlags.EXECUTION_PHASE_MASK & execution_phase)):
                raise ContextError("Conflict in assignment to flags ({}) and execution_phase ({}) arguments "
                                   "of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.EXECUTION_PHASE_MASK),
                                          ContextFlags._get_context_string(flags, EXECUTION_PHASE), self.owner.name))
            if (source != ContextFlags.COMPONENT) and not (flags & ContextFlags.SOURCE_MASK & source):
                raise ContextError("Conflict in assignment to flags ({}) and source ({}) arguments of Context for {}".
                                   format(ContextFlags._get_context_string(flags & ContextFlags.SOURCE_MASK),
                                          ContextFlags._get_context_string(flags, SOURCE),
                                          self.owner.name))
        self.execution_id = execution_id
        self.execution_time = None
        self.string = string

    __deepcopy__ = get_deepcopy_with_shared(_deepcopy_shared_keys)

    @property
    def composition(self):
        try:
            return self._composition
        except AttributeError:
            self._composition = None
            return self._composition

    @composition.setter
    def composition(self, composition):
        # from psyneulink.core.compositions.composition import Composition
        # if isinstance(composition, Composition):
        if (
            composition is None
            or composition.__class__.__name__ in {
                'Composition', 'SystemComposition', 'PathwayComposition', 'AutodiffComposition', 'System', 'Process'
            }
        ):
            self._composition = composition
        else:
            raise ContextError("Assignment to context.composition for {} ({}) "
                               "must be a Composition (or \'None\').".format(self.owner.name, composition))

    @property
    def flags(self):
        try:
            return self._flags
        except AttributeError:
            self._flags = ContextFlags.UNINITIALIZED | ContextFlags.COMPONENT
            return self._flags

    @flags.setter
    def flags(self, flags):
        if isinstance(flags, (ContextFlags, int)):
            self._flags = flags
        else:
            raise ContextError("\'{}\'{} argument in call to {} must be a {} or an int".
                               format(FLAGS, flags, self.__name__, ContextFlags.__name__))

    @property
    def flags_string(self):
        return ContextFlags._get_context_string(self.flags)

    @property
    def initialization_status(self):
        return self.flags & ContextFlags.INITIALIZATION_MASK

    @initialization_status.setter
    def initialization_status(self, flag):
        """Check that a flag is one and only one status flag """
        flag &= ContextFlags.INITIALIZATION_MASK
        if flag in INITIALIZATION_STATUS_FLAGS:
            self.flags &= ContextFlags.UNINITIALIZED
            self.flags |= flag
        elif not flag or flag is ContextFlags.UNINITIALIZED:
            self.flags &= ContextFlags.UNINITIALIZED
        elif not (flag & ContextFlags.INITIALIZATION_MASK):
            raise ContextError("Attempt to assign a flag ({}) to {}.context.flags "
                               "that is not an initialization status flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.initialization_status".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def execution_phase(self):
        v = self.flags & ContextFlags.EXECUTION_PHASE_MASK
        if v == 0:
            return ContextFlags.IDLE
        else:
            return v


    @execution_phase.setter
    def execution_phase(self, flag):
        """Check that a flag is one and only one execution_phase flag """
        if flag in EXECUTION_PHASE_FLAGS:
            # self.flags |= flag
            self.flags &= ContextFlags.IDLE
            self.flags |= flag
        elif not flag or flag is ContextFlags.IDLE:
            self.flags &= ContextFlags.IDLE
        elif flag is ContextFlags.EXECUTING:
            self.flags |= flag
        elif not (flag & ContextFlags.EXECUTION_PHASE_MASK):
            raise ContextError("Attempt to assign a flag ({}) to {}.context.execution_phase "
                               "that is not an execution phase flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.execution_phase".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def source(self):
        return self.flags & ContextFlags.SOURCE_MASK

    @source.setter
    def source(self, flag):
        """Check that a flag is one and only one source flag """
        if flag in SOURCE_FLAGS:
            self.flags &= ContextFlags.NONE
            self.flags |= flag
        elif not flag or flag is ContextFlags.NONE:
            self.flags &= ContextFlags.NONE
        elif not flag & ContextFlags.SOURCE_MASK:
            raise ContextError("Attempt to assign a flag ({}) to {}.context.source that is not a source flag".
                               format(ContextFlags._get_context_string(flag), self.owner.name))
        else:
            raise ContextError("Attempt to assign more than one flag ({}) to {}.context.source".
                               format(ContextFlags._get_context_string(flag), self.owner.name))

    @property
    def execution_time(self):
        try:
            return self._execution_time
        except AttributeError:
            return None

    @execution_time.setter
    def execution_time(self, time):
        self._execution_time = time

    def update_execution_time(self):
        if self.execution_phase & ContextFlags.EXECUTING:
            self.execution_time = _get_time(self.owner, self.flags)
        else:
            raise ContextError("PROGRAM ERROR: attempt to call update_execution_time for {} "
                               "when 'EXECUTING' was not in its context".format(self.owner.name))

    def add_to_string(self, string):
        if self.string is None:
            self.string = string
        else:
            self.string = '{0} {1} {2}'.format(self.string, SEPARATOR_BAR, string)
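The three `flags` fields above are packed into one integer and recovered with bit masks; a minimal, self-contained sketch of that scheme (the bit positions below are hypothetical, not the real `ContextFlags` layout):

    from enum import IntFlag

    class MiniFlags(IntFlag):
        UNINITIALIZED = 0
        INITIALIZED   = 1 << 0    # initialization_status field
        PROCESSING    = 1 << 4    # execution_phase field
        COMMAND_LINE  = 1 << 8    # source field
        INITIALIZATION_MASK  = 0b000000001111
        EXECUTION_PHASE_MASK = 0b000011110000
        SOURCE_MASK          = 0b111100000000

    flags = MiniFlags.INITIALIZED | MiniFlags.PROCESSING | MiniFlags.COMMAND_LINE
    # Each property reads its own field by masking, as initialization_status et al. do above:
    assert flags & MiniFlags.INITIALIZATION_MASK == MiniFlags.INITIALIZED
    assert flags & MiniFlags.EXECUTION_PHASE_MASK == MiniFlags.PROCESSING
    assert flags & MiniFlags.SOURCE_MASK == MiniFlags.COMMAND_LINE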
Example #3
class Context():
    """Used to indicate the state of initialization and phase of execution of a Component, as well as the source of
    call of a method;  also used to specify and identify `conditions <Log_Conditions>` for `logging <Log>`.


    Attributes
    ----------

    owner : Component
        Component to which the Context belongs.

    flags : binary vector
        represents the current operating context of the `owner <Context.owner>`; contains two fields,
        `execution_phase <Context.execution_phase>` and `source <Context.source>` (described below).

    flags_string : str
        contains the names of the flags currently set in each of the fields of the `flags <Context.flags>` attribute;
        note that this is *not* the same as the `string <Context.string>` attribute (see `note <Context_String_Note>`).

    execution_phase : field of flags attribute
        indicates the phase of execution of the Component;
        one or more of the following flags can be set:

            * `PREPARING <ContextFlags.PREPARING>`
            * `PROCESSING <ContextFlags.PROCESSING>`
            * `LEARNING <ContextFlags.LEARNING>`
            * `CONTROL <ContextFlags.CONTROL>`
            * `IDLE <ContextFlags.IDLE>`

        If `IDLE` is set, the Component is not being executed at the current time, and `flags_string
        <Context.flags_string>` will include *IDLE* in the string.  In some circumstances all of the
        `execution_phase <Context.execution_phase>` flags may be set (other than *IDLE* and *PREPARING*),
        in which case `flags_string <Context.flags_string>` will include *EXECUTING* in the string.

    source : field of the flags attribute
        indicates the source of a call to a method belonging to or referencing the Component;
        one of the following flags is always set:

            * `CONSTRUCTOR <ContextFlags.CONSTRUCTOR>`
            * `COMMAND_LINE <ContextFlags.COMMAND_LINE>`
            * `COMPONENT <ContextFlags.COMPONENT>`
            * `COMPOSITION <ContextFlags.COMPOSITION>`

    composition : Composition
      the `Composition <Composition>` in which the `owner <Context.owner>` is currently being executed.

    execution_id : str
      the execution_id assigned to the Component by the Composition in which it is currently being executed.

    execution_time : TimeScale
      current time of the `Scheduler` running the Composition within which the Component is currently being executed.

    string : str
      contains message(s) relevant to a method of the Component currently invoked or that is referencing the Component.
      In general, this contains a copy of the **context** argument passed to a method of the Component or one that
      references it, but it is possible that future uses will involve other messages.  Note that this is *not* the
      same as the `flags_string <Context.flags_string>` attribute (see `note <Context_String_Note>`).

    rpc_pipeline : Queue
      queue to populate with messages for external environment in cases where execution was triggered via RPC call
      (e.g. through PsyNeuLinkView).

    """

    __name__ = 'Context'
    _deepcopy_shared_keys = {'owner', 'composition', '_composition'}

    def __init__(
            self,
            owner=None,
            composition=None,
            flags=None,
            execution_phase=ContextFlags.IDLE,
            # source=ContextFlags.COMPONENT,
            source=ContextFlags.NONE,
            runmode=ContextFlags.DEFAULT_MODE,
            execution_id=None,
            string: str = '',
            time=None,
            rpc_pipeline: Queue = None):

        self.owner = owner
        self.composition = composition
        self._execution_phase = execution_phase
        self._source = source
        self._runmode = runmode

        if flags:
            if (execution_phase
                    and not (flags & ContextFlags.EXECUTION_PHASE_MASK
                             & execution_phase)):
                raise ContextError(
                    "Conflict in assignment to flags ({}) and execution_phase ({}) arguments "
                    "of Context for {}".format(
                        ContextFlags._get_context_string(
                            flags & ContextFlags.EXECUTION_PHASE_MASK),
                        ContextFlags._get_context_string(
                            flags, EXECUTION_PHASE), self.owner.name))
            if (source != ContextFlags.COMPONENT
                ) and not (flags & ContextFlags.SOURCE_MASK & source):
                raise ContextError(
                    "Conflict in assignment to flags ({}) and source ({}) arguments of Context for {}"
                    .format(
                        ContextFlags._get_context_string(
                            flags & ContextFlags.SOURCE_MASK),
                        ContextFlags._get_context_string(flags, SOURCE),
                        self.owner.name))
        self.execution_id = execution_id
        self.execution_time = None
        self.string = string
        self.rpc_pipeline = rpc_pipeline

    __deepcopy__ = get_deepcopy_with_shared(_deepcopy_shared_keys)

    @property
    def composition(self):
        try:
            return self._composition
        except AttributeError:
            self._composition = None
            return self._composition

    @composition.setter
    def composition(self, composition):
        # from psyneulink.core.compositions.composition import Composition
        # if isinstance(composition, Composition):
        if (composition is None or composition.__class__.__name__
                in {'Composition', 'AutodiffComposition'}):
            self._composition = composition
        else:
            raise ContextError(
                f"Assignment to context.composition for {self.owner.name} ({composition}) "
                "must be a Composition (or 'None').")

    @property
    def flags(self):
        return self.execution_phase | self.source

    @flags.setter
    def flags(self, flags: ContextFlags):
        if isinstance(flags, (ContextFlags, int)):
            self.execution_phase = flags & ContextFlags.EXECUTION_PHASE_MASK
            self.source = flags & ContextFlags.SOURCE_MASK
        else:
            raise ContextError(
                "\'{}\'{} argument in call to {} must be a {} or an int".
                format(FLAGS, flags, self.__name__, ContextFlags.__name__))

    @property
    def flags_string(self):
        return ContextFlags._get_context_string(self.flags)

    @property
    def execution_phase(self):
        return self._execution_phase

    @execution_phase.setter
    def execution_phase(self, flag):
        """Check that flag is a valid execution_phase flag assignment"""
        if not flag:
            self._execution_phase = ContextFlags.IDLE
        elif flag not in EXECUTION_PHASE_FLAGS:
            raise ContextError(
                f"Attempt to assign more than one non-SIMULATION flag ({str(flag)}) to execution_phase"
            )
        elif (flag & ~ContextFlags.EXECUTION_PHASE_MASK):
            raise ContextError(
                "Attempt to assign a flag ({}) to execution_phase "
                "that is not an execution phase flag".format(str(flag)))
        else:
            self._execution_phase = flag

    @property
    def source(self):
        return self._source

    @source.setter
    def source(self, flag):
        """Check that a flag is one and only one source flag"""
        if flag in SOURCE_FLAGS:
            self._source = flag
        elif not flag:
            self._source = ContextFlags.NONE
        elif not flag & ContextFlags.SOURCE_MASK:
            raise ContextError(
                "Attempt to assign a flag ({}) to source that is not a source flag"
                .format(str(flag)))
        else:
            raise ContextError(
                "Attempt to assign more than one flag ({}) to source".format(
                    str(flag)))

    @property
    def runmode(self):
        return self._runmode

    @runmode.setter
    def runmode(self, flag):
        """Check that a flag is one and only one run mode flag"""
        if (flag in RUN_MODE_FLAGS
                or (flag & ~ContextFlags.SIMULATION_MODE) in RUN_MODE_FLAGS):
            self._runmode = flag
        elif not flag:
            self._runmode = ContextFlags.DEFAULT_MODE
        elif not flag & ContextFlags.RUN_MODE_MASK:
            raise ContextError(
                "Attempt to assign a flag ({}) to run mode that is not a run mode flag"
                .format(str(flag)))
        else:
            raise ContextError(
                "Attempt to assign more than one non-SIMULATION flag ({}) to run mode"
                .format(str(flag)))

    @property
    def execution_time(self):
        try:
            return self._execution_time
        except AttributeError:
            return None

    @execution_time.setter
    def execution_time(self, time):
        self._execution_time = time

    def update_execution_time(self):
        if self.execution_phase & ContextFlags.EXECUTING:
            self.execution_time = _get_time(self.owner, self.flags)
        else:
            raise ContextError(
                "PROGRAM ERROR: attempt to call update_execution_time for {} "
                "when 'EXECUTING' was not in its context".format(
                    self.owner.name))

    def add_to_string(self, string):
        if self.string is None:
            self.string = string
        else:
            self.string = '{0} {1} {2}'.format(self.string, SEPARATOR_BAR,
                                               string)

    def _change_flags(
            self,
            *flags,
            operation=lambda attr, blank_flag, *flags: NotImplemented):
        # split by flag type to avoid extra costly binary operations on enum flags
        if all([flag in EXECUTION_PHASE_FLAGS for flag in flags]):
            self.execution_phase = operation(self.execution_phase,
                                             ContextFlags.IDLE, *flags)
        elif all([flag in SOURCE_FLAGS for flag in flags]):
            self.source = operation(self.source, ContextFlags.NONE, *flags)
        elif all([flag in RUN_MODE_FLAGS for flag in flags]):
            self.runmode = operation(self.runmode, ContextFlags.DEFAULT_MODE,
                                     *flags)
        else:
            raise ContextError(
                f'Flags must all correspond to one of: execution_phase, source, run mode'
            )

    def add_flag(self, flag: ContextFlags):
        def add(attr, blank_flag, flag):
            return (attr & ~blank_flag) | flag

        self._change_flags(flag, operation=add)

    def remove_flag(self, flag: ContextFlags):
        def remove(attr, blank_flag, flag):
            if attr & flag:
                res = (attr | flag) ^ flag
                if res is ContextFlags.UNSET:
                    res = blank_flag
                return res
            else:
                return attr

        self._change_flags(flag, operation=remove)

    def replace_flag(self, old: ContextFlags, new: ContextFlags):
        def replace(attr, blank_flag, old, new):
            return (attr & ~old) | new

        self._change_flags(old, new, operation=replace)
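The `add_flag` / `remove_flag` / `replace_flag` operations above are plain bit edits applied to one field at a time; the same arithmetic on bare ints (hypothetical bit values):

    attr, blank, flag = 0b0101, 0b0000, 0b0100

    added    = (attr & ~blank) | flag       # add_flag: set the bit
    removed  = (attr | flag) ^ flag         # remove_flag: clear the bit
    replaced = (attr & ~0b0100) | 0b0010    # replace_flag: swap old for new
    assert added == 0b0101 and removed == 0b0001 and replaced == 0b0011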
Example #4
class PytorchModelCreator(torch.nn.Module):
    # sets up parameters of model & the information required for forward computation
    def __init__(self, composition, device, context=None):

        if not torch_available:
            raise Exception('Pytorch python module (torch) is not installed. Please install it with '
                            '`pip install torch` or `pip3 install torch`')

        super(PytorchModelCreator, self).__init__()

        # Maps Mechanism -> PytorchMechanismWrapper
        self.nodes = []
        self.component_map = {}

        # Maps Projections -> PytorchProjectionWrappers
        self.projections = []
        self.projection_map = {}

        self.params = nn.ParameterList()
        self.device = device
        self._composition = composition

        # Instantiate pytorch mechanisms
        for node in set(composition.nodes) - set(composition.get_nodes_by_role(NodeRole.LEARNING)):
            pytorch_node = PytorchMechanismWrapper(node, self._composition._get_node_index(node), device, context=context)
            self.component_map[node] = pytorch_node
            self.nodes.append(pytorch_node)

        # Instantiate pytorch projections
        for projection in composition.projections:
            if projection.sender.owner in self.component_map and projection.receiver.owner in self.component_map:
                proj_send = self.component_map[projection.sender.owner]
                proj_recv = self.component_map[projection.receiver.owner]

                port_idx = projection.sender.owner.output_ports.index(projection.sender)
                new_proj = PytorchProjectionWrapper(projection, list(self._composition._inner_projections).index(projection), port_idx, device, sender=proj_send, receiver=proj_recv, context=context)
                proj_send.add_efferent(new_proj)
                proj_recv.add_afferent(new_proj)
                self.projection_map[projection] = new_proj
                self.projections.append(new_proj)
                self.params.append(new_proj.matrix)

        c = Context()
        try:
            composition.scheduler._init_counts(execution_id=c.execution_id, base_execution_id=context.execution_id)
        except graph_scheduler.SchedulerError:
            # called from LLVM, no base context is provided
            composition.scheduler._init_counts(execution_id=c.execution_id)

        # Setup execution sets
        # 1) Remove all learning-specific nodes
        self.execution_sets = [x - set(composition.get_nodes_by_role(NodeRole.LEARNING)) for x in composition.scheduler.run(context=c)]
        # 2) Convert to pytorchcomponent representation
        self.execution_sets = [{self.component_map[comp] for comp in s if comp in self.component_map} for s in self.execution_sets]
        # 3) Remove empty execution sets
        self.execution_sets = [x for x in self.execution_sets if len(x) > 0]

        composition.scheduler._delete_counts(c.execution_id)

    __deepcopy__ = get_deepcopy_with_shared(shared_types=(Component, ComponentsMeta))

    # generates llvm function for self.forward
    def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
        args = [ctx.get_state_struct_type(self._composition).as_pointer(),
                ctx.get_param_struct_type(self._composition).as_pointer(),
                ctx.get_data_struct_type(self._composition).as_pointer()
                ]
        builder = ctx.create_llvm_function(args, self)

        state, params, data = builder.function.args
        if "learning" in tags:
            self._gen_llvm_training_function_body(ctx, builder, state, params, data)
        else:
            model_input = builder.gep(data, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(self._composition._get_node_index(self._composition.input_CIM))])
            self._gen_llvm_forward_function_body(ctx, builder, state, params, model_input, data)

        builder.ret_void()
        return builder.function

    def _gen_llvm_forward_function_body(self, ctx, builder, state, params, arg_in, data):
        z_values = {}  # maps each executed node to its accumulated net input (pre-activation sum)
        for current_exec_set in self.execution_sets:
            for component in current_exec_set:
                mech_input_ty = ctx.get_input_struct_type(component._mechanism)
                variable = builder.alloca(mech_input_ty)
                z_values[component] = builder.alloca(mech_input_ty.elements[0].elements[0])
                builder.store(z_values[component].type.pointee(None), z_values[component])

                if NodeRole.INPUT in self._composition.get_roles_by_node(component._mechanism):
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(0)])
                    input_id = component._idx
                    mech_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(input_id)])
                    builder.store(builder.load(mech_in), input_ptr)
                for (proj_idx, proj) in enumerate(component.afferents):
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(proj_idx)])
                    proj_output = proj._gen_llvm_execute(ctx, builder, state, params, data)
                    # store in input ports struct
                    builder.store(builder.load(proj_output), input_ptr)
                    # HACK: Add to z_values struct
                    gen_inject_vec_add(ctx, builder, proj_output, z_values[component], z_values[component])
                component._gen_llvm_execute(ctx, builder, state, params, variable, data)

        return z_values

    # generates a function responsible for a single epoch of the training
    def _gen_llvm_training_backprop(self, ctx, optimizer, loss):
        composition = self._composition
        args = [ctx.get_state_struct_type(composition).as_pointer(),
                ctx.get_param_struct_type(composition).as_pointer(),
                ctx.get_data_struct_type(composition).as_pointer(),
                optimizer._get_optimizer_struct_type(ctx).as_pointer(),
                ]
        name = self._composition.name + "_training_backprop"
        builder = ctx.create_llvm_function(args, self, name)
        llvm_func = builder.function
        for a in llvm_func.args:
            if isinstance(a.type, pnlvm.ir.PointerType):
                a.attributes.add('noalias')

        state, params, data, optim_struct = llvm_func.args
        model_input = builder.gep(data, [ctx.int32_ty(0),
                                         ctx.int32_ty(0),
                                         ctx.int32_ty(self._composition._get_node_index(self._composition.input_CIM))])
        model_output = data
        # setup useful mappings
        input_nodes = set(self._composition.get_nodes_by_role(NodeRole.INPUT))

        # 1) initialize optimizer params
        delta_w = builder.gep(optim_struct, [ctx.int32_ty(0), ctx.int32_ty(optimizer._DELTA_W_NUM)])

        # 2) call forward computation
        z_values = self._gen_llvm_forward_function_body(
            ctx, builder, state, params, model_input, data)

        # 3) compute errors
        loss_fn = ctx.import_llvm_function(loss)
        total_loss = builder.alloca(ctx.float_ty)
        builder.store(ctx.float_ty(0), total_loss)

        error_dict = {}
        for exec_set in reversed(self.execution_sets):
            for node in exec_set:
                if node._mechanism in input_nodes:
                    continue
                node_z_value = z_values[node]
                activation_func_derivative = node._gen_llvm_execute_derivative_func(ctx, builder, state, params, node_z_value)
                error_val = builder.alloca(z_values[node].type.pointee)
                error_dict[node] = error_val

                if NodeRole.OUTPUT in self._composition.get_roles_by_node(node._mechanism):
                    # We handle output layer here
                    # compute dC/da = a_l - y(x) (TODO: Allow other cost functions! This only applies to MSE)

                    # 1) Lookup desired target value
                    terminal_sequence = self._composition._terminal_backprop_sequences[node._mechanism]
                    target_idx = self._composition.get_nodes_by_role(
                        NodeRole.INPUT).index(terminal_sequence[TARGET_MECHANISM])
                    node_target = builder.gep(model_input, [ctx.int32_ty(0), ctx.int32_ty(target_idx)])

                    # 2) Lookup desired output value
                    node_output = builder.gep(model_output, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(node._idx), ctx.int32_ty(0)])

                    tmp_loss = loss.gen_inject_lossfunc_call(
                        ctx, builder, loss_fn, node_output, node_target)

                    pnlvm.helpers.printf_float_array(
                        builder, node_target, prefix=f"{node}\ttarget:\t")
                    pnlvm.helpers.printf_float_array(
                        builder, node_output, prefix=f"{node}\tvalue:\t")

                    pnlvm.helpers.printf(
                        builder, f"{node}\tloss:\t%f\n", tmp_loss, override_debug=False)
                    builder.store(builder.fadd(builder.load(
                        total_loss), tmp_loss), total_loss)
                    loss_derivative = loss._gen_inject_loss_differential(
                        ctx, builder, node_output, node_target)
                    # compute δ_l = dC/da ⊙ σ'(z)

                    gen_inject_vec_hadamard(
                        ctx, builder, activation_func_derivative, loss_derivative, error_val)

                else:
                    # We propagate error backwards from next layer
                    for proj_idx, proj in enumerate(node.efferents):
                        efferent_node = proj.receiver
                        efferent_node_error = error_dict[efferent_node]

                        weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, params)

                        if proj_idx == 0:
                            gen_inject_vxm_transposed(
                                ctx, builder, efferent_node_error, weights_llvmlite, error_val)
                        else:
                            new_val = gen_inject_vxm_transposed(
                                ctx, builder, efferent_node_error, weights_llvmlite)

                            gen_inject_vec_add(
                                ctx, builder, new_val, error_val, error_val)

                    gen_inject_vec_hadamard(
                        ctx, builder, activation_func_derivative, error_val, error_val)

                pnlvm.helpers.printf_float_array(
                    builder, activation_func_derivative, prefix=f"{node}\tdSigma:\t")
                pnlvm.helpers.printf_float_array(
                    builder, error_val, prefix=f"{node}\terror:\t")

        # 4) compute weight gradients
        for (node, err_val) in error_dict.items():
            if node in input_nodes:
                continue
            for proj in node.afferents:
                # get a_(l-1)
                afferent_node_activation = builder.gep(model_output, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(proj.sender._idx), ctx.int32_ty(0)])

                # get dimensions of weight matrix
                weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, params)
                pnlvm.helpers.printf_float_matrix(builder, weights_llvmlite, prefix=f"{proj.sender._mechanism} -> {proj.receiver._mechanism}\n", override_debug=False)
                # update delta_W
                node_delta_w = builder.gep(delta_w, [ctx.int32_ty(0), ctx.int32_ty(proj._idx)])

                dim_x, dim_y = proj.matrix.shape
                with pnlvm.helpers.for_loop_zero_inc(builder, ctx.int32_ty(dim_x), "weight_update_loop_outer") as (b1, weight_row):
                    with pnlvm.helpers.for_loop_zero_inc(b1, ctx.int32_ty(dim_y), "weight_update_loop_inner") as (b2, weight_column):
                        a_val = b2.load(b2.gep(afferent_node_activation,
                                               [ctx.int32_ty(0), weight_row]))
                        d_val = b2.load(b2.gep(err_val,
                                               [ctx.int32_ty(0), weight_column]))
                        old_val = b2.load(b2.gep(node_delta_w,
                                                 [ctx.int32_ty(0), weight_row, weight_column]))
                        new_val = b2.fadd(old_val, b2.fmul(a_val, d_val))
                        b2.store(new_val, b2.gep(node_delta_w,
                                                 [ctx.int32_ty(0), weight_row, weight_column]))

        pnlvm.helpers.printf(builder, "TOTAL LOSS:\t%.20f\n",
                             builder.load(total_loss), override_debug=False)
        builder.ret_void()

        return builder.function

    def _gen_llvm_training_function_body(self, ctx, builder, state, params, data):
        composition = self._composition

        optimizer = self._get_compiled_optimizer()
        # setup loss
        loss_type = self._composition.loss_spec
        if loss_type == 'mse':
            loss = MSELoss()
        else:
            raise Exception(f"Loss type '{loss_type}' not supported")

        optimizer_step_f = ctx.import_llvm_function(optimizer)
        optimizer_struct_idx = len(state.type.pointee.elements) - 1
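        # the optimizer state struct occupies the last slot of the composition state struct,
        # hence the index computed above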
        optimizer_struct = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(optimizer_struct_idx)])
        optimizer_zero_grad = ctx.import_llvm_function(optimizer.zero_grad(ctx).name)
        backprop = ctx.import_llvm_function(self._gen_llvm_training_backprop(ctx, optimizer, loss).name)

        # FIXME: converting this call to inlined code results in
        # significantly longer compilation times
        builder.call(optimizer_zero_grad, [optimizer_struct])
        builder.call(backprop, [state, params, data,
                                optimizer_struct])
        builder.call(optimizer_step_f, [optimizer_struct, params])

    def _get_compiled_optimizer(self):
        # setup optimizer
        optimizer_type = self._composition.optimizer_type
        if optimizer_type == 'adam':
            optimizer = AdamOptimizer(self, lr=self._composition.learning_rate)
        elif optimizer_type == 'sgd':
            optimizer = SGDOptimizer(self, lr=self._composition.learning_rate)
        else:
            raise Exception(f"Optimizer type '{optimizer_type}' not supported")
        return optimizer

    # performs forward computation for the model
    @handle_external_context()
    def forward(self, inputs, context=None):
        outputs = {}  # dict for storing values of terminal (output) nodes
        for current_exec_set in self.execution_sets:
            for component in current_exec_set:
                if NodeRole.INPUT in self._composition.get_roles_by_node(component._mechanism):
                    component.execute(inputs[component._mechanism])
                else:
                    variable = component.collate_afferents()
                    component.execute(variable)

                # save value in output list if we're at a node in the last execution set
                if NodeRole.OUTPUT in self._composition.get_roles_by_node(component._mechanism):
                    outputs[component._mechanism] = component.value

        # NOTE: Context source needs to be set to COMMAND_LINE to force logs to update independently of timesteps
        old_source = context.source
        context.source = ContextFlags.COMMAND_LINE
        self.log_values()
        self.log_weights()
        context.source = old_source

        return outputs

    def detach_all(self):
        for projection in self.projection_map.values():
            projection.matrix.detach()

    def copy_weights_to_psyneulink(self, context=None):
        for projection, pytorch_rep in self.projection_map.items():
            projection.parameters.matrix._set(
                pytorch_rep.matrix.detach().cpu().numpy(), context)
            projection.parameter_ports['matrix'].parameters.value._set(
                pytorch_rep.matrix.detach().cpu().numpy(), context)

    def log_weights(self):
        for proj in self.projections:
            proj.log_matrix()

    def log_values(self):
        for node in self.nodes:
            node.log_value()
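The nested weight-update loops in `_gen_llvm_training_backprop` accumulate, for each projection, the outer product of the presynaptic activation and the receiving node's error signal; a NumPy rendering of the same update (shapes are hypothetical):

    import numpy as np

    a_prev  = np.array([0.2, 0.7, 0.1])   # a_(l-1): activation of the sender
    delta   = np.array([0.05, -0.03])     # delta_l: error at the receiver
    delta_w = np.zeros((3, 2))            # same shape as proj.matrix

    # The inner loops compute delta_w[row, col] += a_prev[row] * delta[col]:
    delta_w += np.outer(a_prev, delta)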