Exemplo n.º 1
0
    def __init__(self):
        """
        Build an empty Model container. Every economic object (Country, Sector) is
        created inside some Model.

        The Model is the root object: unlike the other EconomicObject subclasses it has
        no parent.

        Several Model objects may exist at the same time without interacting (apart from
        side effects that flow through global Parameters).
        """
        EconomicObject.__init__(self)
        # Lifecycle / run configuration.
        self.State = 'Construction'
        self.MaxTime = 100
        self.RunSteps = None
        self.DefaultCurrency = 'LOCAL'
        # Structural containers.
        self.CountryList = []
        self.CurrencyZoneList = []
        self.ExternalSector = None
        # Inputs registered by the user or by sectors.
        self.Exogenous = []
        self.InitialConditions = []
        self.GlobalVariables = []
        self.RegisteredCashFlows = []
        self.IncomeExclusions = []
        self.Aliases = {}
        # Options that control time series retrieval.
        self.TimeSeriesCutoff = None
        self.TimeSeriesSupressTimeZero = False
        # Solver and generated output (solver is constructed before the block, as before).
        self.EquationSolver = sfc_models.equation_solver.EquationSolver()
        self.FinalEquations = ''
        self.FinalEquationBlock = EquationBlock()
Exemplo n.º 2
0
 def test_list(self):
     """GetEquationList() reports equation names in sorted order."""
     blk = EquationBlock()
     # Insert 'x' before 'a' so that the sorted result differs from insertion order.
     blk.AddEquation(Equation('x', rhs=[Term('y')]))
     blk.AddEquation(Equation('a'))
     expected = ['a', 'x']
     self.assertEqual(expected, blk.GetEquationList())
Exemplo n.º 3
0
    def _CreateFinalEquations(self):
        """
        Assemble every equation of the model into the final text block.

        Rows of (lhs, rhs, description) are gathered from each Sector (in country/sector
        order), then from the initial conditions, then from the global variables. The
        rows rebuild self.FinalEquationBlock and are also formatted into text, which is
        stored in self.FinalEquations.

        :return: str
        :raises Warning: if the model contains no equations at all.
        """
        Logger('Model._CreateFinalEquations()')
        rows = [row
                for country in self.CountryList
                for sec in country.SectorList
                for row in sec._CreateFinalEquations()]
        rows += self._GenerateInitialConditions()
        rows += self.GlobalVariables
        if not rows:
            self.FinalEquations = ''
            raise Warning('There are no equations in the system.')
        # Exogenous right-hand sides carry a marker that is removed before building
        # Equation objects.
        marker = 'EXOGENOUS'
        self.FinalEquationBlock = EquationBlock()
        for row in rows:
            rhs = row[1]
            if marker in rhs:
                rhs = rhs.replace(marker, '')
            self.FinalEquationBlock.AddEquation(
                Equation(row[0], desc=row[2], rhs=rhs))
        formatted = self._FinalEquationFormatting(rows)
        self.FinalEquations = formatted
        return formatted
Exemplo n.º 4
0
 def __init__(self, country, code, long_name='', has_F=True):
     """
     Create a Sector inside *country* and register it with that country.

     :param country: Country - parent object; supplies Code and CurrencyZone.
     :param code: str - local sector code.
     :param long_name: str - descriptive name; a default is generated when empty.
     :param has_F: bool - if True, create the financial-asset variables F, LAG_F
         and the income variable INC.
     """
     if long_name == '':
         long_name = 'Sector Object {0} in Country {1}'.format(
             code, country.Code)
     self.Code = code
     EconomicObject.__init__(self, country, code=code)
     self.CurrencyZone = country.CurrencyZone
     country._AddSector(self)
     # This is calculated by the Model
     self.FullCode = ''
     self.LongName = long_name
     self.HasF = has_F
     self.IsTaxable = False
     self.EquationBlock = EquationBlock()
     if has_F:
         # F starts out as LAG_F; further terms are presumably added later
         # (e.g. by cash flow registration) - TODO confirm.
         F = Equation('F', 'Financial assets')
         F.AddTerm('LAG_F')
         self.AddVariableFromEquation(F)
         INC = Equation('INC', 'Income (PreTax)', rhs=[])
         self.AddVariableFromEquation(INC)
         # A single literal: the old adjacent literals 'Previous period' 's ...'
         # were concatenated and silently dropped the apostrophe.
         self.AddVariable('LAG_F', "Previous period's financial assets.",
                          'F(k-1)')
Exemplo n.º 5
0
 def test_access(self):
     """An equation added to a block can be fetched back by its LHS name."""
     blk = EquationBlock()
     blk.AddEquation(Equation('x', rhs=[Term('y')]))
     fetched = blk['x']
     self.assertEqual('y', fetched.RHS())
Exemplo n.º 6
0
class Model(EconomicObject):
    """
    Model class.

    All other entities live within a model. A Model holds the Country objects (which
    hold Sector objects) and the CurrencyZone objects. It also records user-registered
    inputs (exogenous variables, initial conditions, cash flows, aliases, global
    equations). In main(), it ties the sector equations together and passes them to
    self.EquationSolver.
    """
    def __init__(self):
        """
        Create the Model object. All the economic class objects (Country, Sector) live within
        a Model.

        Unlike other EconomicObject subclasses, the Model object is a base object, and has no parent.

        It is possible for two Model objects to coexist; there is no interaction between them
        (other than side effects from global Parameters).
        """
        EconomicObject.__init__(self)
        self.CountryList = []
        # (sector_fullcode_or_sector, varname, value_str) tuples; see AddExogenous().
        self.Exogenous = []
        # (sector_fullcode, varname, value_str) tuples; see AddInitialCondition().
        self.InitialConditions = []
        self.FinalEquations = ''
        self.MaxTime = 100
        # (source, target, amount_variable, is_income_source, is_income_dest) tuples.
        self.RegisteredCashFlows = []
        # alias -> (sector, local_variable_name); see _RegisterAlias().
        self.Aliases = {}
        self.TimeSeriesCutoff = None
        self.TimeSeriesSupressTimeZero = False
        self.EquationSolver = sfc_models.equation_solver.EquationSolver()
        # (var, eqn, description) tuples; see AddGlobalEquation().
        self.GlobalVariables = []
        # (sector, cash_flow_name) tuples; see AddCashFlowIncomeExclusion().
        self.IncomeExclusions = []
        self.CurrencyZoneList = []
        # main() moves this through 'Running' to 'Finished Running'.
        self.State = 'Construction'
        self.DefaultCurrency = 'LOCAL'
        # List of {step_name: callable} dicts, built lazily by _GetSteps().
        self.RunSteps = None
        self.ExternalSector = None
        self.FinalEquationBlock = EquationBlock()

    def main(self, base_file_name=None):  # pragma: no cover
        """
        Routine that does most of the work of model building. The model is built based upon
        the Sector objects that have been registered as children of this Model.

        The base_file_name is the base filename that is used for Logging operations; just used
        to call Logger.register_standard_logs(base_file_name). It is recommended that you call
        Logger.register_standard_logs() before main, so that Sector creation can be logged.

        The major operations:
        [1] Call GenerateEquations() on all Sectors. The fact that GenerateEquations() is only
         called now at the Sector level means that Sectors can be created independently, and the
        GenerateEquations() call which ties it to other sectors is only called after all other sectors
        are created. (There is at least one exception where Sectors have to be created in a specific
        order, which we want to avoid.)

        [2] Cleanup work: processing initial conditions, exogenous variables, replacing
            aliases in equations.

        [3] Tie the sector level equations into a single block of equations. (Currently strings, not
            Equation objects.)

        [4] The equations are passed to self.EquationSolver, and they are solved.

        The user can then use GetTimeSeries() to access the output time series (if they can be
        calculated.)

        A Warning raised during the run is logged and printed, and the (possibly partial)
        final equations are returned; any other exception is logged and re-raised.

        :param base_file_name: str
        :return: str
        """
        self.State = 'Running'
        try:
            if base_file_name is not None:
                Logger.register_standard_logs(base_file_name)
            Logger('Starting Model main()')
            self._GenerateFullSectorCodes()
            self._GenerateEquations()
            self._FixAliases()
            self._GenerateRegisteredCashFlows()
            self._ProcessExogenous()
            self.FinalEquations = self._CreateFinalEquations()
            self.EquationSolver.ParseString(self.FinalEquations)
            self.EquationSolver.SolveEquation()
            self.LogInfo()
        except Warning as e:
            self.LogInfo(ex=e)
            print('Warning triggered: ' + str(e))
            return self.FinalEquations
        except Exception as e:
            self.LogInfo(ex=e)
            raise
        finally:
            self.State = 'Finished Running'
            Logger(self.EquationSolver.GenerateCSVtext(), 'timeseries')
            Logger.cleanup()
        return self.FinalEquations

    def _GetSteps(self):  # pragma: no cover
        """
        This is experimental, for GUI use. Will integrate with main() later...

        Lazily builds self.RunSteps as a list of single-entry {name: callable} dicts and
        returns the (sorted) step names of each group.

        :return: list
        """
        if self.RunSteps is None:
            # NOTE(review): this ordering differs from main(), which generates equations
            # before fixing aliases and calls _FixAliases only once.
            self.RunSteps = []
            self.RunSteps.append(
                {'Generate Sector Codes': self._GenerateFullSectorCodes})
            self.RunSteps.append({'Fix Aliases': self._FixAliases})
            self.RunSteps.append(
                {'Generate Equations': self._GenerateEquationSteps})
            self.RunSteps.append(
                {'Process Cash Flows': self._GenerateRegisteredCashFlows})
            self.RunSteps.append({'Process Exogenous': self._ProcessExogenous})
            self.RunSteps.append({'Fix Aliases (Pass #2)': self._FixAliases})
            self.RunSteps.append(
                {'Final Equations': self._CreateFinalEquations})
            # _FinalSteps performs the same solve/log/cleanup tail as main().
            self.RunSteps.append({'Solve': self._FinalSteps})
        out = [list(x.keys()) for x in self.RunSteps]
        for x in out:
            x.sort()
        return out

    def _GenerateEquationSteps(self):  # pragma: no cover
        """
        Queue one equation-generation step per sector into the current step group.

        When run through _RunStep(), RunSteps[0] is the group whose
        'Generate Equations' entry was just popped. Refilling that same dict keeps the
        group alive until every sector's step has been run.
        """
        sector_list = self.GetSectors()
        for sec in sector_list:
            self.RunSteps[0][sec.FullCode] = sec._GenerateEquationsFrontEnd

    def _FinalSteps(self):  # pragma: no cover
        """Solve the final equations, log the results and close the logs."""
        self.EquationSolver.ParseString(self.FinalEquations)
        self.EquationSolver.SolveEquation()
        self.LogInfo()
        self.State = 'Finished Running'
        Logger(self.EquationSolver.GenerateCSVtext(), 'timeseries')
        Logger.cleanup()

    def _RunStep(self, command):  # pragma: no cover
        """
        Run a single named step from the current (first) step group.

        Any exception aborts all remaining steps and is re-raised.

        :param command: str
        :return: None
        """
        if len(self.RunSteps) == 0:
            self.State = 'Finished Running'
            return
        self.State = 'Running'
        # Will throw KeyError if not in command list
        func = self.RunSteps[0].pop(command)
        try:
            func()
        except BaseException:
            self.RunSteps = []
            self.State = 'Finished Running'
            raise
        if len(self.RunSteps[0]) == 0:
            self.RunSteps.pop(0)
        if len(self.RunSteps) == 0:
            self.State = 'Finished Running'

    def _RunAllSteps(self):  # pragma: no cover
        """Run every remaining step; assumes _GetSteps() has already been called."""
        while len(self.RunSteps) > 0:
            all_cmds = list(self.RunSteps[0].keys())
            self._RunStep(all_cmds[0])

    def AddExogenous(self, sector_fullcode, varname, value):
        """
        Add an exogenous variable to the model. Overwrites an existing variable definition.
        Need to use the full sector code.

        Exogenous variables are specified as time series, which are implemented as list variables ([1, 2, 3,...])
        The exogenous variable can either be specified as a string (which can be evaluated to a list), or else
        as a list object. The list object will be converted into a string representation using repr(), which
        means that it may be much longer than using something like '[20,] * 100'.

        At present, does not support the usage of specifying a constant value. For example value='20.' does
        not work, you need '[20.]*100'.

        :param sector_fullcode: str
        :param varname: str
        :param value: str
        :return:
        """
        Logger('Adding exogenous variable: {0} in {1}',
               priority=5,
               data_to_format=(varname, sector_fullcode))
        # If the user passes in a list or tuple, convert it to a string representation.
        if isinstance(value, (list, tuple)):
            value = repr(value)
        self.Exogenous.append((sector_fullcode, varname, value))

    def AddInitialCondition(self, sector_fullcode, varname, value):
        """
        Set the initial condition for a variable. Need to use the full code of the sector -
        or its ID (int).

        :param sector_fullcode: str
        :param varname: str
        :param value: float (anything float() accepts, e.g. a numeric string)
        :return:
        :raises ValueError: if value cannot be converted to a float.
        """
        Logger('Adding initial condition: {0} in {1}',
               priority=5,
               data_to_format=(varname, sector_fullcode))
        # Convert the "value" to a float, in case someone uses a string
        try:
            value = float(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(
                'The "value" parameter for initial conditions must be a float.'
            ) from err
        self.InitialConditions.append((sector_fullcode, varname, str(value)))

    def AddCashFlowIncomeExclusion(self, sector, cash_flow_name):
        """
        Specify that a cash flow is to be excluded from the list of flows that go into
        income.

        Examples of such flows within economics:
        - Household consumption of goods.
        - Business investment.
        - Dividend outflows for business (inflow is income for household).
        - (Financing flows - do not really appear?)

        Note that these exclusions are generally implemented at the sector code level.
        :param sector: Sector
        :param cash_flow_name: str
        :return: None
        """
        Logger('Registering cash flow exclusion: {0} for ID={1}',
               priority=5,
               data_to_format=(cash_flow_name, sector.ID))
        self.IncomeExclusions.append((sector, cash_flow_name))

    def _RegisterAlias(self, alias, sector, local_variable_name):
        """
        Used by Sector objects to register aliases for local variables.

        :param alias: str
        :param sector: Sector
        :param local_variable_name: str
        :return:
        """
        Logger('Registering alias {0} for {1} in ID={2}',
               priority=5,
               data_to_format=(alias, local_variable_name, sector.ID))
        self.Aliases[alias] = (sector, local_variable_name)

    def AddGlobalEquation(self, var, description, eqn):
        """
        Add a variable that is not associated with a sector.
        Typical example: 't'

        :param var: str
        :param description: str
        :param eqn: str
        :return: None
        """
        Logger('Registering global equation: {0} = {1}',
               priority=5,
               data_to_format=(var, eqn))
        self.GlobalVariables.append((var, eqn, description))

    def GetSectors(self):
        """
        Returns a list of Sector objects held within this Model, in country order.

        :return: list
        """
        return [sector
                for cntry in self.CountryList
                for sector in cntry.SectorList]

    def GetTimeSeries(self, series, cutoff=None, group_of_series='main'):
        """
        Convenience function to retrieve time series from the EquationSolver.

        Use cutoff to truncate the length of the output (points 0..cutoff are kept).

        If self.TimeSeriesSupressTimeZero is True, the first point is removed (the initial
        conditions period).

        The returned list is always a copy, so modifying it (or suppressing time zero)
        never alters the data held by self.EquationSolver.

        :param group_of_series: str - 'main' (default), 'step' or 'initial'
        :param series: str
        :param cutoff: int
        :return: list
        :raises KeyError: if the series does not exist.
        """
        if cutoff is None:
            cutoff = self.TimeSeriesCutoff
        try:
            series_holder = self.EquationSolver.TimeSeries
            if group_of_series == 'step':  # pragma: no cover [The GUI tells us quickly if this breaks]
                series_holder = self.EquationSolver.TimeSeriesStepTrace
            elif group_of_series == 'initial':  # pragma: no cover
                series_holder = self.EquationSolver.TimeSeriesInitialSteadyState
            stored = series_holder[series]
        except KeyError:
            raise KeyError(
                'Time series "{0}" does not exist'.format(series)) from None
        # Slice in both branches: previously the uncut series was returned by
        # reference, so pop(0) below deleted data from the solver itself.
        if cutoff is None:
            val = stored[:]
        else:
            val = stored[0:(cutoff + 1)]
        if self.TimeSeriesSupressTimeZero:
            val.pop(0)
        return val

    def _FixAliases(self):
        """
        Assign the proper names to variables in Sector objects (that were previously aliases).
        :return:
        """
        Logger('Fixing aliases (Model._FixAliases)', priority=3)
        lookup = {}
        for alias, (sector, varname) in self.Aliases.items():
            lookup[alias] = sector.GetVariableName(varname)
        for sector in self.GetSectors():
            sector._ReplaceAliases(lookup)

    def LogInfo(self, generate_full_codes=True, ex=None):  # pragma: no cover
        """
        Write information to a file; if there is an exception, dump the trace.
        The log will normally generate the full sector codes; set generate_full_codes=False
        to leave the Model full codes untouched.

        :param generate_full_codes: bool
        :param ex: Exception
        :return:
        """
        # Not covered with unit tests [for now]. Output format will change a lot.
        if ex is not None:
            Logger('\nError or Warning raised:\n')
            try:
                traceback.print_exc(file=Logger.get_handle())
            except KeyError:
                # Log was not registered; do nothing!
                return
        Logger('-' * 30)
        Logger('Starting LogInfo() Data Dump')
        Logger('-' * 30)
        if generate_full_codes:
            self._GenerateFullSectorCodes()
            for c in self.CountryList:
                Logger('Country: Code= "%s" %s\n' % (c.Code, c.LongName))
                Logger('=' * 60 + '\n\n')
                for s in c.SectorList:
                    Logger(s.Dump() + '\n')
        Logger('Writing LogInfo to log="eqn"')
        Logger('\n\nFinal Equations:\n', log='eqn')
        Logger(self.FinalEquations + '\n', log='eqn')
        parser = EquationParser()
        parser.ParseString(self.FinalEquations)
        parser.EquationReduction()
        Logger('\n\nReduced Equations', log='eqn')
        Logger(parser.DumpEquations(), log='eqn')
        if ex is not None:
            # NOTE(review): the traceback was already printed at the top of this
            # method; this second dump duplicates it in the log.
            Logger('\n\nError raised:\n')
            traceback.print_exc(file=Logger.get_handle())

    def _AddCountry(self, country):
        """
        Add a country to the list. This is called by the object constructor; users should
        not call this.

        :param country: Country
        :return: None
        :raises LogicError: if a country with the same Code is already present.
        """
        Logger('Adding Country: {0} ID={1}',
               data_to_format=(country.Code, country.ID))
        if country.Code in self:
            raise LogicError('Country with Code {0} already in Model'.format(
                country.Code))
        self.CountryList.append(country)
        # The most recently added country determines the default currency.
        self.DefaultCurrency = country.Currency
        czone = self._FitIntoCurrencyZone(country)
        country.CurrencyZone = czone

    def _FitIntoCurrencyZone(self, country):
        """
        Find whether Country fits into an existing CurrencyZone; if not,
        create a new one.
        :param country: Country
        :return: CurrencyZone
        """
        for czone in self.CurrencyZoneList:
            if country.Currency == czone.Currency:
                czone.CountryList.append(country)
                Logger('Fitting {0} into CurrencyZone {1}',
                       data_to_format=(country.LongName, czone.Currency))
                return czone
        Logger('Creating new currency zone {0}, adding {1} to it',
               data_to_format=(country.Currency, country.LongName))
        czone = CurrencyZone(self, country.Currency)
        czone.CountryList.append(country)
        self.CurrencyZoneList.append(czone)
        return czone

    def _GenerateFullSectorCodes(self):
        """
        Create full sector names (which is equal to '[country.Code]_[sector.Code]') if there is more
        than one country. Equals the sector code otherwise.

        :return: None
        """
        Logger('Generating FullSector codes (Model._GenerateFullSectorCodes()',
               priority=3)
        add_country_code = len(self.CountryList) > 1
        for cntry in self.CountryList:
            for sector in cntry.SectorList:
                if add_country_code:
                    sector.FullCode = cntry.Code + '_' + sector.Code
                else:
                    sector.FullCode = sector.Code

    @staticmethod
    def GetSectorCodeWithCountry(sector):
        """
        Return the sector code including the country information.
        Need to use this if we want the FullCode before the model information
        is bound in main().

        We should not need to use this function often; generally only when we need
        FullCodes in constructors. For example, the multi-supply business sector
        needs the full codes of markets passed into the constructor.
        :param sector: Sector
        :return: str
        """
        return '{0}_{1}'.format(sector.Parent.Code, sector.Code)

    def RegisterCashFlow(self,
                         source_sector,
                         target_sector,
                         amount_variable,
                         is_income_source=True,
                         is_income_dest=True):
        """
        Register a cash flow between two sectors.

        The amount_variable is the name of the local variable within the source sector.

        Only allowed across currency zones if an ExternalSector object has been defined;
        otherwise a LogicError is raised later, in _GenerateRegisteredCashFlows().

        The currency value of the amount_variable is assumed to be in the source currency.

        :param is_income_dest: bool - whether the inflow counts as income for the target.
        :param is_income_source: bool - whether the outflow counts against income for the source.
        :param source_sector: Sector
        :param target_sector: Sector
        :param amount_variable: str
        :return:
        """
        Logger('Cash flow registered {0}: {1} -> {2}  [ID: {3} -> {4}]',
               priority=3,
               data_to_format=(amount_variable, source_sector.Code,
                               target_sector.Code, source_sector.ID,
                               target_sector.ID))
        self.RegisteredCashFlows.append(
            (source_sector, target_sector, amount_variable, is_income_source,
             is_income_dest))

    def _GenerateRegisteredCashFlows(self):
        """
        Create cash flows based on those previously registered.

        The source sector gets a '-amount' flow. The target sector gets '+amount' for
        same-currency flows, or the term returned by the ExternalSector 'FX' object
        for cross-currency flows.

        :return:
        :raises LogicError: for a cross-currency flow with no ExternalSector defined.
        """
        Logger('Model._GenerateRegisteredCashFlows()')
        Logger('Adding {0} cash flows to sectors',
               priority=3,
               data_to_format=(len(self.RegisteredCashFlows), ))
        for (source_sector, target_sector, amount_variable, is_income_source,
             is_income_dest) in self.RegisteredCashFlows:
            is_cross_currency = source_sector.CurrencyZone != target_sector.CurrencyZone
            if is_cross_currency:
                if self.ExternalSector is None:
                    msg = """Only can have cross-currency flows if an ExternalSector object is created\nSource={0} Destination={1}""".format(
                        source_sector.FullCode, target_sector.FullCode)
                    raise LogicError(msg)
            full_variable_name = source_sector.GetVariableName(amount_variable)
            source_sector.AddCashFlow('-' + full_variable_name,
                                      eqn=None,
                                      is_income=is_income_source)
            if is_cross_currency:
                fx = self.ExternalSector['FX']
                fx._SendMoney(source_sector, full_variable_name)
                term = fx._ReceiveMoney(target_sector=target_sector,
                                        source_sector=source_sector,
                                        variable_name=full_variable_name)
            else:
                term = '+' + full_variable_name
            target_sector.AddCashFlow(term, eqn=None, is_income=is_income_dest)

    def LookupSector(self, fullcode):
        """
        Find a sector based on its FullCode.
        :param fullcode: str
        :return: Sector
        :raises KeyError: if no country contains a matching sector.
        """
        for cntry in self.CountryList:
            try:
                return cntry.LookupSector(fullcode, is_full_code=True)
            except KeyError:
                pass
        # format() rather than '+' so a non-str code (e.g. an int ID) still yields
        # a KeyError instead of a TypeError.
        raise KeyError('Sector with FullCode does not exist: {0}'.format(fullcode))

    def _ProcessExogenous(self):
        """
        Handles the exogenous variables: marks each registered variable's right-hand
        side with the 'EXOGENOUS ' prefix.

        :return: None
        :raises KeyError: if the sector does not have the named variable.
        """
        Logger('Processing {0} exogenous variables',
               priority=3,
               data_to_format=(len(self.Exogenous), ))
        for sector_code, varname, eqn in self.Exogenous:
            if isinstance(sector_code, str):
                sector = self.LookupSector(sector_code)
            else:
                sector = sector_code
            if varname not in sector.EquationBlock.Equations:
                raise KeyError('Sector %s does not have variable %s' %
                               (sector_code, varname))
            # Need to mark exogenous variables
            sector.SetEquationRightHandSide(varname, 'EXOGENOUS ' + eqn)

    def _GenerateInitialConditions(self):
        """
        Create block of equations for initial conditions.

        Validates that the variables exist.
        :return: list of (lhs, value, description) tuples
        :raises KeyError: if a sector or variable does not exist.
        """
        Logger('Generating {0} initial conditions',
               priority=3,
               data_to_format=(len(self.InitialConditions), ))
        out = []
        for sector_code, varname, value in self.InitialConditions:
            sector = self.LookupSector(sector_code)
            if varname not in sector.EquationBlock.Equations:
                raise KeyError('Sector %s does not have variable %s' %
                               (sector_code, varname))
            out.append(('%s(0)' % (sector.GetVariableName(varname), ), value,
                        'Initial Condition'))
        return out

    def _GenerateEquations(self):
        """
        Call _GenerateEquations on all child Sector objects.

        :return:
        """
        Logger('Model._GenerateEquations()', priority=1)
        for cntry in self.CountryList:
            for sector in cntry.SectorList:
                Logger('Calling _GenerateEquations on {0}  ({1})',
                       priority=3,
                       data_to_format=(sector.FullCode, type(sector)))
                sector._GenerateEquations()

    def DumpEquations(self):
        """
        Returns a string with basic information about the entities within this Model.
        Output is primitive, and aimed at debugging purposes. In other words, the format
        will change without warning.

        If you want information of a specific format, please create a specific reporting
        function for your needs.

        :return: str
        """
        return ''.join(sector.Dump() for sector in self.GetSectors())

    def _CreateFinalEquations(self):
        """
        Create Final equations.

        Final output, which is a text block of equations. Also rebuilds
        self.FinalEquationBlock and stores the text in self.FinalEquations.
        :return: str
        :raises Warning: if the model contains no equations.
        """
        Logger('Model._CreateFinalEquations()')
        out = []
        for cntry in self.CountryList:
            for sector in cntry.SectorList:
                out.extend(sector._CreateFinalEquations())
        out.extend(self._GenerateInitialConditions())
        out.extend(self.GlobalVariables)
        if len(out) == 0:
            self.FinalEquations = ''
            raise Warning('There are no equations in the system.')
        # Build the FinalEquationBlock
        self.FinalEquationBlock = EquationBlock()
        for row in out:
            if 'EXOGENOUS' in row[1]:
                eq = Equation(row[0],
                              desc=row[2],
                              rhs=row[1].replace('EXOGENOUS', ''))
            else:
                eq = Equation(row[0], desc=row[2], rhs=row[1])
            self.FinalEquationBlock.AddEquation(eq)
        out = self._FinalEquationFormatting(out)
        self.FinalEquations = out
        return out

    def _FinalEquationFormatting(self, out):
        """
        Convert equation information in list into formatted strings.

        Endogenous rows are listed first, then a '# Exogenous Variables' section (with the
        EXOGENOUS marker removed), then the MaxTime and Err_Tolerance settings.

        :param out: list of (lhs, rhs, description) tuples; must be non-empty.
        :return: str
        """
        Logger('_FinalEquationFormatting()', priority=5)
        endo = []
        exo = []
        for row in out:
            if 'EXOGENOUS' in row[1]:
                new_eqn = row[1].replace('EXOGENOUS', '')
                exo.append((row[0], new_eqn, row[2]))
            else:
                endo.append(row)
        max0 = max([len(x[0]) for x in out])
        max1 = max([len(x[1]) for x in out])
        formatter = '%<max0>s = %-<max1>s  # %s'
        formatter = formatter.replace('<max0>', str(max0))
        formatter = formatter.replace('<max1>', str(max1))
        endo = [formatter % x for x in endo]
        exo = [formatter % x for x in exo]
        s = '\n'.join(endo) + '\n\n# Exogenous Variables\n\n' + '\n'.join(exo)
        s += '\n\nMaxTime = {0}\nErr_Tolerance=1e-6'.format(self.MaxTime)
        return s

    def __getitem__(self, item):
        """
        Get a country using model[country_code] notation
        :param item: str
        :return: Country
        :raises KeyError: if no country has that code.
        """
        for obj in self.CountryList:
            if item == obj.Code:
                return obj
        raise KeyError('Country {0} not in Model'.format(item))

    def __contains__(self, item):
        """
        Is a Country object (or string code) in a Model?
        :param item: str or Country
        :return: bool
        """
        if isinstance(item, str):
            try:
                self[item]
            except KeyError:
                return False
            return True
        return item in self.CountryList
Exemplo n.º 7
0
class Sector(EconomicObject):
    """
    All sectors derive from this class.
    """
    def __init__(self, country, code, long_name='', has_F=True):
        """
        Create a Sector and register it with its parent Country.

        :param country: Country -- parent object; its Code and CurrencyZone are used
        :param code: str -- local sector code
        :param long_name: str -- descriptive name; a default is generated when empty
        :param has_F: bool -- if True, create the F, INC and LAG_F variables
        """
        if long_name == '':
            long_name = 'Sector Object {0} in Country {1}'.format(
                code, country.Code)
        # NOTE(review): Code is set before EconomicObject.__init__ runs;
        # presumably the base constructor relies on it -- confirm before reordering.
        self.Code = code
        EconomicObject.__init__(self, country, code=code)
        self.CurrencyZone = country.CurrencyZone
        country._AddSector(self)
        # This is calculated by the Model
        self.FullCode = ''
        self.LongName = long_name
        self.HasF = has_F
        self.IsTaxable = False
        self.EquationBlock = EquationBlock()
        if has_F:
            # F starts from last period's stock (LAG_F); cash flows are added later.
            F = Equation('F', 'Financial assets')
            F.AddTerm('LAG_F')
            self.AddVariableFromEquation(F)
            INC = Equation('INC', 'Income (PreTax)', rhs=[])
            self.AddVariableFromEquation(INC)
            # The description appears in the generated equation comments.
            self.AddVariable('LAG_F', "Previous period's financial assets.",
                             'F(k-1)')

    def AddVariable(self, varname, desc='', eqn=''):
        """
        Add a variable to the sector.
        The variable name (varname) is the local name; it will be decorated to create a
        full name. Equations within a sector can use the local name; other sectors need to
        use GetVariableName to get the full name.

        If the variable already exists, it is overwritten (and a log entry is made).

        :param varname: str
        :param desc: str
        :param eqn: str or Equation -- right-hand side text (stored as a single
                    blob Term), or a ready-made Equation that is stored as-is
        :return: None
        :raises ValueError: if varname contains '__'
        """
        if '__' in varname:
            raise ValueError('Cannot use "__" inside local variable names: ' +
                             varname)
        if desc is None:
            desc = ''
        if isinstance(eqn, Equation):
            # A pre-built Equation (or subclass) is stored unchanged; desc is not applied.
            equation = eqn
        else:
            equation = Equation(varname, desc, [
                Term(eqn, is_blob=True),
            ])
        if varname in self.GetVariables():
            Logger('[ID={0}] Variable Overwritten: {1}',
                   priority=3,
                   data_to_format=(self.ID, varname))
        self.EquationBlock.AddEquation(equation)
        Logger('[ID={0}] Variable Added: {1} = {2} # {3}',
               priority=2,
               data_to_format=(self.ID, varname, eqn, desc))

    def AddVariableFromEquation(self, eqn):
        """
        Register an Equation (or a bare variable name) as a sector variable.

        Transitional helper, to be used until the Equation member is replaced.
        A plain string is first turned into an Equation; the Equation's own
        LeftHandSide and Description are then passed to AddVariable.
        :param eqn: Equation or str
        :return: None
        """
        equation = Equation(eqn) if type(eqn) is str else eqn
        self.AddVariable(equation.LeftHandSide, equation.Description, equation)

    def SetEquationRightHandSide(self, varname, rhs):
        """
        Set the right hand side of the equation for an existing variable.

        Any existing terms are discarded and replaced by a single blob Term
        holding ``rhs``.
        :param varname: str
        :param rhs: str
        :return: None
        :raises KeyError: if the variable does not exist in this sector
        """
        # Only the lookup can signal a missing variable; a KeyError raised
        # elsewhere must not be reported as "does not exist".
        try:
            equation = self.EquationBlock[varname]
        except KeyError:
            raise KeyError('Variable {0} does not exist'.format(varname))
        # Could try: Equation.ParseString(rhs), but is too slow in unit tests...
        equation.TermList = [
            Term(rhs, is_blob=True),
        ]
        Logger('[ID={0}] Equation set: {1} = {2} ',
               priority=2,
               data_to_format=(self.ID, varname, rhs))

    def AddTermToEquation(self, varname, term):
        """
        Add a new term to an existing equation.

        The term variable may be either a string or (non-Blob) Term object;
        it is passed through Term() before being added.

        :param varname: str
        :param term: Term
        :return: None
        :raises KeyError: if the variable is not defined in this sector
        """
        term = Term(term)
        Logger('Adding term {0} to Equation {1} in Sector {2} [ID={3}]',
               priority=2,
               data_to_format=(term, varname, self.Code, self.ID))
        # Keep the try narrow: only the lookup signals an unknown variable, so a
        # KeyError raised inside AddTerm is not misreported as a missing variable.
        try:
            equation = self.EquationBlock[varname]
        except KeyError:
            raise KeyError('Variable {0} not in Sector {1}'.format(
                varname, self.Code))
        equation.AddTerm(term)

    def SetExogenous(self, varname, val):
        """
        Mark a sector variable as exogenous with the given value.

        The variable should already be defined (via AddVariable()); the request
        is forwarded to the owning Model's AddExogenous.
        :param varname: str
        :param val: str
        :return: None
        """
        model = self.GetModel()
        model.AddExogenous(self, varname, val)

    def GetVariables(self):
        """
        Return the sorted list of local variable names in this sector.

        Convenience wrapper around self.EquationBlock.GetEquationList(); the
        list is sorted so that results are deterministic (which simplifies testing).
        :return: list
        """
        block = self.EquationBlock
        variable_names = block.GetEquationList()
        return variable_names

    def GetVariableName(self, varname):
        """
        Get the full variable name associated with a local variable.

        Standard convention:
        {sector_fullcode}__{local variable name}.
        NOTE: the separator is a double underscore ('__'). Double underscores in
        local variable names or sector codes are therefore not allowed, and the
        presence of a double underscore marks a full variable name.

        NOTE: If the sector FullCode is not defined yet, a temporary alias
        ('_{ID}__{varname}') is created and registered with the Model. The Model
        object will ensure that all registered aliases are cleaned up.

        :param varname: str
        :return: str
        :raises KeyError: if varname is not a variable of this sector
        :raises ValueError: if FullCode or varname contains '__' (checked once FullCode is set)
        """
        if varname not in self.EquationBlock.GetEquationList():
            # FullCode is '' until the Model assigns it; fall back to the local
            # Code so the message still identifies the sector.
            raise KeyError('Variable %s not in sector %s' %
                           (varname, self.FullCode or self.Code))
        if self.FullCode == '':
            alias = '_{0}__{1}'.format(self.ID, varname)
            Logger('Registering alias: {0}',
                   priority=5,
                   data_to_format=(alias, ))
            self.GetModel()._RegisterAlias(alias, self, varname)
            return alias
        # Sanity checks: '__' is reserved as the separator.
        if '__' in self.FullCode:
            raise ValueError(
                'The use of "__" in sector codes is invalid: ' +
                self.FullCode)
        if '__' in varname:
            raise ValueError(
                'The use of "__" in variable local names is invalid: ' +
                varname)
        return self.FullCode + '__' + varname

    def IsSharedCurrencyZone(self, other):
        """
        Report whether this sector and ``other`` belong to the same CurrencyZone.

        Zones are compared through their ID attribute.
        :param other: Sector
        :return: bool
        """
        my_zone_id = self.CurrencyZone.ID
        their_zone_id = other.CurrencyZone.ID
        return my_zone_id == their_zone_id

    def _ReplaceAliases(self, lookup):
        """
        Rewrite alias tokens in this sector's equations using ``lookup``.

        Delegates to EquationBlock.ReplaceTokensFromLookup.
        :param lookup: dict
        :return: None
        """
        block = self.EquationBlock
        block.ReplaceTokensFromLookup(lookup)

    def AddCashFlow(self, term, eqn=None, desc=None, is_income=True):
        """
        Add a cash flow to the sector.

        The term is always added to the financial asset equation (F). It is also
        added to the income equation (INC) when is_income is True, unless the
        Model's IncomeExclusions list contains this sector/variable pair; the
        exclusion list overrides is_income, which carves out exceptions to the
        default assumption that cash flows are income.

        If eqn is given, the (unsigned) variable gets that equation: it is
        created if missing, or its right-hand side is set if it is currently
        '' or '0.0'.

        :param term: str
        :param eqn: str
        :param desc: str
        :param is_income: bool
        :return: None
        """
        term = term.strip()
        if not term:
            return
        parsed = Term(term)
        if not parsed.IsSimple:  # pragma: no cover  - Not implemented; cannot hit the line below.
            raise LogicError(
                'Must supply a single variable as the term to AddCashFlow')
        self.EquationBlock['F'].AddTerm(term)
        if is_income:
            # A matching (sector, variable) entry in the exclusion list overrides is_income.
            is_income = not any(
                obj.ID == self.ID and parsed.Term == excluded
                for obj, excluded in self.GetModel().IncomeExclusions)
        if is_income:
            self.EquationBlock['INC'].AddTerm(term)
        if eqn is None:
            return
        # parsed.Term is the variable name with the +/- sign removed.
        flow_var = parsed.Term
        if flow_var not in self.GetVariables():
            self.AddVariable(flow_var, desc, eqn)
            return
        current_rhs = self.EquationBlock[flow_var].RHS()
        if current_rhs in ('', '0.0'):
            self.SetEquationRightHandSide(flow_var, eqn)

    def AddInitialCondition(self, variable_name, value):
        """
        Record an initial condition for one of this sector's variables.

        Forwarded to the owning Model, keyed by this sector's ID.
        :param variable_name: str
        :param value: float
        :return: None
        """
        model = self.GetModel()
        model.AddInitialCondition(self.ID, variable_name, value)

    def _GenerateEquationsFrontEnd(self):  # pragma: no cover
        """
        Entry point for graphical front ends: log, then run _GenerateEquations().

        In Model.Main() the Model does this logging itself before calling the
        Sector, so this wrapper exists only for front-end callers.
        :return: None
        """
        sector_info = (self.Code, self.ID)
        Logger('Running _GenerateEquations on {0} [{1}]',
               priority=3,
               data_to_format=sector_info)
        self._GenerateEquations()

    def _GenerateEquations(self):
        """
        Hook for sector-specific equation generation.

        The base Sector adds nothing; derived classes override this.
        :return: None
        """
        return None

    def Dump(self):
        """
        Build a human-readable description of this sector, for debugging only.

        The layout is a header line, a rule of 60 dashes, then one line per
        equation. It may change over time, so do not parse it.

        :return: str
        """
        pieces = ['[%s] %s. FullCode = "%s" \n' % (self.Code, self.LongName,
                                                   self.FullCode),
                  '-' * 60 + '\n']
        block = self.EquationBlock
        for name in block.GetEquationList():
            pieces.append(str(block[name]) + '\n')
        return ''.join(pieces)

    def _CreateFinalEquations(self):
        """
        Returns the final set of equations, with the full names of variables.

        Each entry is a tuple of (full variable name, right-hand side with local
        names replaced by full names, '[local name] description'). Equations
        whose right-hand side is blank are skipped.
        :return: list
        """
        variable_names = self.EquationBlock.GetEquationList()
        # Resolve each local name exactly once: GetVariableName registers an
        # alias with the Model while FullCode is unset, so repeated calls only
        # repeat that work.
        lookup = {varname: self.GetVariableName(varname)
                  for varname in variable_names}
        out = []
        for varname in variable_names:
            eq = self.EquationBlock[varname]
            rhs = eq.GetRightHandSide()
            if len(rhs.strip()) == 0:  # pragma: no cover  [Does not happen any more; leave in just in case.]
                continue
            out.append((lookup[varname],
                        replace_token_from_lookup(rhs, lookup),
                        '[%s] %s' % (varname, eq.Description)))
        return out

    def GenerateAssetWeighting(self,
                               asset_weighting_dict,
                               residual_asset_code,
                               is_absolute_weighting=False):
        """
        Generate the asset weighting/allocation equations.

        If there are N assets, pass N-1 of them in asset_weighting_dict; the
        residual asset gets the rest (1.0 minus the other weights).

        asset_weighting_dict is a dictionary of the form
            {'asset1code': 'weighting equation',
             'asset2code': 'weighting2'}
        A list or tuple of (code, equation) pairs is also accepted.

        For each asset code X this adds WGT_X (the weight) and
        DEM_X = F * WGT_X (the demand for asset X).

        The is_absolute_weighting parameter is a placeholder; absolute
        weightings are not implemented, and passing True raises.

        Note that weightings are (normally) from 0-1.
        :param asset_weighting_dict: dict
        :param residual_asset_code: str
        :param is_absolute_weighting: bool
        :return: None
        :raises NotImplementedError: if is_absolute_weighting is True
        """
        if is_absolute_weighting:
            # TODO: Implement absolute weightings.
            raise NotImplementedError('Absolute weightings not implemented')
        if isinstance(asset_weighting_dict, (list, tuple)):
            # Allow asset_weighting_dict to be a sequence of (code, equation)
            # pairs; a repeated code keeps its last equation.
            asset_weighting_dict = dict(asset_weighting_dict)
        residual_weight = '1.0'
        for code, weight_eqn in asset_weighting_dict.items():
            # Weight variable = 'WGT_{CODE}'
            weight = 'WGT_' + code
            self.AddVariable(weight, 'Asset weight for ' + code, weight_eqn)
            self.AddVariable('DEM_' + code, 'Demand for asset ' + code,
                             'F * {0}'.format(weight))
            residual_weight += ' - ' + weight
        residual_weight_var = 'WGT_' + residual_asset_code
        self.AddVariable(residual_weight_var,
                         'Asset weight for ' + residual_asset_code,
                         residual_weight)
        self.AddVariable('DEM_' + residual_asset_code,
                         'Demand for asset ' + residual_asset_code,
                         'F * {0}'.format(residual_weight_var))