예제 #1
0
def check_tracing_params(accessor, param_key):
    """Check that reading a parameter through a tracing node records it under the open calculation.

    :param accessor: callable taking the tracing parameter node and returning the accessed value.
    :param param_key: key under which the tracer is expected to record the accessed parameter.
    """
    tracer = Tracer()
    # Parameter reads are attributed to the calculation currently in progress: 'A<2015-01>'.
    tracer.record_calculation_start('A', '2015-01')
    traced_node = TracingParameterNodeAtInstant(parameters('2015-01-01'), tracer)
    accessed_value = accessor(traced_node)
    recorded_parameters = tracer.trace['A<2015-01>']['parameters']
    assert_near(recorded_parameters[param_key], accessed_value)
예제 #2
0
def test_variable_stats():
    """Each calculation start counts as one request for the variable, whatever the period."""
    tracer = Tracer()
    requests = [("A", 2017), ("B", 2017), ("B", 2017), ("B", 2016)]
    for variable_name, requested_period in requests:
        tracer.record_calculation_start(variable_name, requested_period)

    expected_counts = (('B', 3), ('A', 1), ('C', 0))
    for variable_name, expected in expected_counts:
        assert tracer.usage_stats[variable_name]['nb_requests'] == expected
예제 #3
0
    def __init__(
        self,
        tax_benefit_system,
        simulation_json=None,
        debug=False,
        period=None,
        trace=False,
        opt_out_cache=False,
        memory_config=None,
    ):
        """
            Create an empty simulation

            To fill the simulation with input data, you can use the :any:`SimulationBuilder` or proceed manually.

            :param tax_benefit_system: the tax and benefit system to simulate; must not be None.
            :param simulation_json: deprecated; when given, a deprecation warning is emitted and the dict is
                passed to ``SimulationBuilder.build_from_entities`` to populate this simulation.
            :param debug: debug flag; it also switches tracing on.
            :param period: optional period; must be a ``periods.Period`` when truthy.
            :param trace: when truthy, calculations are recorded in a fresh :any:`Tracer`.
            :param opt_out_cache: stored as-is on the simulation.
            :param memory_config: stored as-is on the simulation.
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None
        if period:
            assert isinstance(period, periods.Period)
        self.period = period

        # Circular-definition detection (see formulas.py): every variable currently being calculated is
        # mapped to the periods it was requested for, e.g. {variable_name: [period1, period2]}.
        self.requested_periods_by_variable_name = {}
        self.max_nb_cycles = None

        self.debug = debug
        self.trace = trace or self.debug  # debugging implies tracing
        self.tracer = Tracer() if self.trace else None
        self.opt_out_cache = opt_out_cache

        self.memory_config = memory_config
        self._data_storage_dir = None
        self.instantiate_entities()

        if simulation_json is None:
            return

        deprecation_message = (
            "The 'simulation_json' argument of the Simulation is deprecated since version 25.0, and will be removed in the future. "
            "The proper way to init a simulation from a JSON-like dict is to use SimulationBuilder.build_from_entities. See <https://openfisca.org/doc/openfisca-python-api/simulation_builder.html#openfisca_core.simulation_builder.SimulationBuilder.build_from_dict>"
        )
        warnings.warn(deprecation_message, Warning)
        # Imported lazily, only on this deprecated path.
        from openfisca_core.simulation_builder import SimulationBuilder
        builder = SimulationBuilder()
        builder.build_from_entities(tax_benefit_system, simulation_json, simulation=self)
예제 #4
0
    def clone(self, debug = False, trace = False):
        """
            Copy the simulation just enough to be able to run the copy without modifying the original simulation

            Every instance attribute except ``debug``, ``trace`` and ``tracer`` is copied by reference; the
            person entity and every group entity are cloned and attached to the copy, together with their
            ``simulation.<entity key>`` shortcuts.

            :param debug: if True, the copy is put in debug mode, which (as in the constructor, where
                ``trace = trace or debug``) also turns tracing on.
            :param trace: if True, the copy traces its calculations.
            :returns: the cloned simulation.
        """
        new = empty_clone(self)
        new_dict = new.__dict__

        for key, value in self.__dict__.items():
            if key not in ('debug', 'trace', 'tracer'):
                new_dict[key] = value

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.key, new.persons)
        new.entities = {new.persons.key: new.persons}

        for entity_class in self.tax_benefit_system.group_entities:
            entity = self.entities[entity_class.key].clone(new)
            new.entities[entity.key] = entity
            setattr(new, entity_class.key, entity)  # create shortcut simulation.household (for instance)

        if debug:
            new_dict['debug'] = True
        if debug or trace:
            # Debugging implies tracing: otherwise the copy would own a tracer that is never fed.
            new_dict['trace'] = True
            # Continue the original trace when there is one. The original may have its debug/trace
            # flag on without a tracer (e.g. flags set after construction), hence the guard.
            original_tracer = getattr(self, 'tracer', None)
            if (self.debug or self.trace) and original_tracer is not None:
                new_dict['tracer'] = original_tracer.clone()
            else:
                new_dict['tracer'] = Tracer()

        return new
예제 #5
0
def test_log_format():
    """The computation log indents nested calculations and shows each result after '>>'."""
    tracer = Tracer()
    # A is being computed when B is requested, so B is nested one level deeper.
    tracer.record_calculation_start("A", 2017)
    tracer.record_calculation_start("B", 2017)
    tracer.record_calculation_end("B", 2017, 1)
    tracer.record_calculation_end("A", 2017, 2)

    log = tracer.computation_log()
    expected_lines = ['  A<2017> >> 2', '    B<2017> >> 1']
    for position, expected in enumerate(expected_lines):
        assert_equals(log[position], expected)
예제 #6
0
    def __init__(
            self,
            tax_benefit_system,
            simulation_json = None,
            debug = False,
            period = None,
            trace = False,
            opt_out_cache = False,
            memory_config = None,
            ):
        """
            Initialise a simulation, optionally from a JSON-like dictionary.

            ``simulation_json`` (possibly None) is handed to ``instantiate_entities``. Initialising a
            simulation from JSON is still experimental and aims at replacing the initialisation from
            `scenario.make_json_or_python_to_attributes`. Without it, the simulation starts empty.

            :param tax_benefit_system: the tax and benefit system to simulate; must not be None.
            :param simulation_json: optional JSON-like description of the simulation.
            :param debug: debug flag; it also switches tracing on.
            :param period: optional period; must be a ``periods.Period`` when truthy.
            :param trace: when truthy, calculations are recorded in a fresh :any:`Tracer`.
            :param opt_out_cache: stored as-is on the simulation.
            :param memory_config: stored as-is on the simulation.
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None
        if period:
            assert isinstance(period, periods.Period)
        self.period = period

        # Circular-definition detection (see formulas.py): every variable currently being calculated is
        # mapped to the periods it was requested for, e.g. {variable_name: [period1, period2]}.
        self.requested_periods_by_variable_name = {}
        self.max_nb_cycles = None

        self.debug = debug
        self.trace = trace or self.debug  # debugging implies tracing
        self.tracer = Tracer() if self.trace else None
        self.opt_out_cache = opt_out_cache

        self.memory_config = memory_config
        self._data_storage_dir = None
        self.instantiate_entities(simulation_json)
예제 #7
0
def test_trace_enums():
    """Enum results are rendered in the computation log by their names, not their indices."""
    tracer = Tracer()
    tracer.record_calculation_start("A", 2017)
    tenant_array = HousingOccupancyStatus.encode(np.array(['tenant']))
    tracer.record_calculation_end("A", 2017, tenant_array)

    first_line = tracer.computation_log()[0]
    assert first_line == "  A<2017> >> ['tenant']"
예제 #8
0
def test_variable_stats():
    """Each calculation start counts as one request for the variable, whatever the period."""
    tracer = Tracer()
    for variable_name, requested_period in (("A", 2017), ("B", 2017), ("B", 2017), ("B", 2016)):
        tracer.record_calculation_start(variable_name, requested_period)

    for variable_name, expected in (('B', 3), ('A', 1), ('C', 0)):
        assert_equals(tracer.usage_stats[variable_name]['nb_requests'], expected)
예제 #9
0
def test_log_format():
    """The computation log indents nested calculations and shows each result after '>>'."""
    tracer = Tracer()
    # B is requested while A is being computed, hence the deeper indentation of B.
    tracer.record_calculation_start("A", 2017)
    tracer.record_calculation_start("B", 2017)
    tracer.record_calculation_end("B", 2017, 1)
    tracer.record_calculation_end("A", 2017, 2)

    log = tracer.computation_log()
    for position, expected in enumerate(['  A<2017> >> 2', '    B<2017> >> 1']):
        assert log[position] == expected
예제 #10
0
def test_consistency():
    """Ending a calculation other than the one in progress must raise a ValueError."""
    tracer = Tracer()
    tracer.record_calculation_start("rsa", 2017)
    try:
        tracer.record_calculation_end("unkwonn", 2017, 100)
    except ValueError:
        return  # expected: "unkwonn" does not match the open "rsa" calculation
    raise AssertionError("record_calculation_end should reject a variable that is not being calculated")
예제 #11
0
class Simulation(object):
    """
        Represents a simulation, and handles the calculation logic
    """
    # Class-level defaults. ``__init__`` sets debug, period, tax_benefit_system and trace on each instance;
    # a clone that is not given debug/trace falls back on these values.
    debug = False
    period = None
    steps_count = 1
    tax_benefit_system = None
    trace = False

    # ----- Simulation construction ----- #

    def __init__(
            self,
            tax_benefit_system,
            simulation_json = None,
            debug = False,
            period = None,
            trace = False,
            opt_out_cache = False,
            memory_config = None,
            ):
        """
            Create an empty simulation

            To fill the simulation with input data, you can use the :any:`SimulationBuilder` or proceed manually.

            :param tax_benefit_system: the tax and benefit system to simulate; must not be None.
            :param simulation_json: deprecated; when given, a deprecation warning is emitted and the dict is
                passed to ``SimulationBuilder.build_from_entities`` to populate this simulation.
            :param debug: enables NaN checks on formula results (see ``_check_formula_result``); also enables tracing.
            :param period: optional period; must be a ``periods.Period`` when truthy.
            :param trace: when truthy, calculations are recorded in a fresh :any:`Tracer`.
            :param opt_out_cache: stored as-is on the simulation.
            :param memory_config: stored as-is on the simulation.
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None
        if period:
            assert isinstance(period, periods.Period)
        self.period = period

        # To keep track of the values (formulas and periods) being calculated to detect circular definitions.
        # See use in formulas.py.
        # The data structure of requested_periods_by_variable_name is: {variable_name: [period1, period2]}
        self.requested_periods_by_variable_name = {}
        self.max_nb_cycles = None

        self.debug = debug
        self.trace = trace or self.debug
        if self.trace:
            self.tracer = Tracer()
        else:
            self.tracer = None
        self.opt_out_cache = opt_out_cache

        self.memory_config = memory_config
        self._data_storage_dir = None
        self.instantiate_entities()

        if simulation_json is not None:
            warnings.warn(' '.join([
                "The 'simulation_json' argument of the Simulation is deprecated since version 25.0, and will be removed in the future.",
                "The proper way to init a simulation from a JSON-like dict is to use SimulationBuilder.build_from_entities. See <https://openfisca.org/doc/openfisca-python-api/simulation_builder.html#openfisca_core.simulation_builder.SimulationBuilder.build_from_dict>"
                ]),
                Warning
                )
            from openfisca_core.simulation_builder import SimulationBuilder
            SimulationBuilder().build_from_entities(tax_benefit_system, simulation_json, simulation = self)

    def instantiate_entities(self):
        """
            Create the person entity and one instance of each group entity of the tax and benefit system,
            store them in ``self.entities`` by key, and expose each one as ``simulation.<entity key>``.
        """
        self.persons = self.tax_benefit_system.person_entity(self)
        self.entities = {self.persons.key: self.persons}
        setattr(self, self.persons.key, self.persons)  # create shortcut simulation.person (for instance)

        for entity_class in self.tax_benefit_system.group_entities:
            entities = entity_class(self)
            self.entities[entity_class.key] = entities
            setattr(self, entity_class.key, entities)  # create shortcut simulation.household (for instance)

    @property
    def data_storage_dir(self):
        """
        Temporary folder used to store intermediate calculation data in case the memory is saturated

        The folder is created lazily (prefix ``openfisca_``) on first access, and a warning is logged
        since it is not removed automatically.
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_")
            # ``warning`` rather than the deprecated ``warn`` alias.
            log.warning((
                "Intermediate results will be stored on disk in {} in case of memory overflow. "
                "You should remove this directory once you're done with your simulation."
                ).format(self._data_storage_dir).encode('utf-8'))
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name, period, **parameters):
        """
            Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

            A value already cached for ``period`` is returned directly. Otherwise the formula is run; when
            it yields nothing, the variable's ``base_function`` and then the holder's default array are used.
            The result is cached.

            :param parameters: optional ``extra_params`` (tuple passed to the formula and the cache) and
                ``max_nb_cycles`` (number of tolerated nested cycles, see ``_check_for_cycle``).
            :returns: A numpy array containing the result of the calculation
        """
        entity = self.get_variable_entity(variable_name)
        holder = entity.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        if self.trace:
            self.tracer.record_calculation_start(variable.name, period, **parameters)

        self._check_period_consistency(period, variable)

        extra_params = parameters.get('extra_params', ())

        # First look for a value already cached
        cached_array = holder.get_array(period, extra_params)
        if cached_array is not None:
            if self.trace:
                self.tracer.record_calculation_end(variable.name, period, cached_array, **parameters)
            return cached_array

        max_nb_cycles = parameters.get('max_nb_cycles')
        if max_nb_cycles is not None:
            self.max_nb_cycles = max_nb_cycles

        # First, try to run a formula
        array = self._run_formula(variable, entity, period, extra_params, max_nb_cycles)

        # If no result, try a base function
        if array is None and variable.base_function:
            array = variable.base_function(holder, period, *extra_params)

        # If no result, use the default value
        if array is None:
            array = holder.default_array()

        self._clean_cycle_detection_data(variable.name)
        if max_nb_cycles is not None:
            self.max_nb_cycles = None

        holder.put_in_cache(array, period, extra_params)
        if self.trace:
            self.tracer.record_calculation_end(variable.name, period, array, **parameters)

        return array

    def calculate_add(self, variable_name, period, **parameters):
        """
            Sum the values of ``variable_name`` over the sub-periods of ``period``.

            :raises ValueError: if ``period`` is shorter than the variable's ``definition_period``, or if the
                variable is not defined daily, monthly or yearly.
        """
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if periods.unit_weight(variable.definition_period) > periods.unit_weight(period.unit):
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period,
                variable.definition_period
                ).encode('utf-8'))

        if variable.definition_period not in [periods.DAY, periods.MONTH, periods.YEAR]:
            raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined yearly can be divided over time.".format(
                variable.name,
                period).encode('utf-8')) if False else ValueError("Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time.".format(
                variable.name,
                period).encode('utf-8'))

        return sum(
            self.calculate(variable_name, sub_period, **parameters)
            for sub_period in period.get_subperiods(variable.definition_period)
            )

    def calculate_divide(self, variable_name, period, **parameters):
        """
            Estimate the value of a yearly variable for a one-month or one-year ``period``.

            For a month, the value of the enclosing year divided by 12 is returned.

            :raises ValueError: if the variable is not defined yearly, or ``period`` is not a single month or year.
        """
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if variable.definition_period != periods.YEAR:
            raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format(
                variable_name,
                period).encode('utf-8'))

        if period.size != 1:
            raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period")

        if period.unit == periods.MONTH:
            computation_period = period.this_year
            return self.calculate(variable_name, period = computation_period, **parameters) / 12.
        elif period.unit == periods.YEAR:
            return self.calculate(variable_name, period, **parameters)

        raise ValueError("Unable to divide the value of '{}' to match period {}.".format(
            variable_name,
            period).encode('utf-8'))

    def calculate_output(self, variable_name, period):
        """
            Calculate the value of a variable using the ``calculate_output`` attribute of the variable.

            Falls back on :any:`calculate` when the variable has no ``calculate_output``.
        """

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """
            Return the parameters at ``formula_period`` wrapped so that every access is recorded by ``self.tracer``.
        """
        return TracingParameterNodeAtInstant(
            self.tax_benefit_system.get_parameters_at_instant(formula_period),
            self.tracer
            )

    def _run_formula(self, variable, entity, period, extra_params, max_nb_cycles):
        """
            Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``entity``.

            Formulas taking two arguments are called as ``formula(entity, period)``; the others also get a
            parameters accessor and ``extra_params``.

            :returns: the checked and cast result, or None when there is no formula or when a cycle was hit
                while ``max_nb_cycles`` is set.
            :raises CycleError: when a cycle is hit and ``max_nb_cycles`` is None.
        """

        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        try:
            self._check_for_cycle(variable, period)
            if formula.__code__.co_argcount == 2:
                array = formula(entity, period)
            else:
                array = formula(entity, period, parameters_at, *extra_params)
        except CycleError as error:
            self._clean_cycle_detection_data(variable.name)
            if max_nb_cycles is None:
                if self.trace:
                    self.tracer.record_calculation_abortion(variable.name, period, extra_params = extra_params)
                # Re-raise until reaching the first variable called with max_nb_cycles != None in the stack.
                raise error
            self.max_nb_cycles = None
            return None

        self._check_formula_result(array, variable, entity, period)
        return self._cast_formula_result(array, variable)

    def _check_period_consistency(self, period, variable):
        """
            Check that a period matches the variable definition_period

            :raises ValueError: if a monthly (resp. yearly) variable is requested for a non-month (resp.
                non-year) period, or for a period spanning more than one unit.
        """
        if variable.definition_period == periods.ETERNITY:
            return  # For variables which values are constant in time, all periods are accepted

        if variable.definition_period == periods.MONTH and period.unit != periods.MONTH:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format(
                variable.name,
                period
                ).encode('utf-8'))

        if variable.definition_period == periods.YEAR and period.unit != periods.YEAR:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period
                ).encode('utf-8'))

        if period.size != 1:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
                variable.name,
                period,
                'month' if variable.definition_period == periods.MONTH else 'year'
                ).encode('utf-8'))

    def _check_formula_result(self, value, variable, entity, period):
        """
            Assert that a formula returned a numpy array with one value per member of ``entity``.

            In debug mode, also raise :any:`NaNCreationError` if the result contains NaN values.
        """

        assert isinstance(value, np.ndarray), (linesep.join([
            "You tried to compute the formula '{0}' for the period '{1}'.".format(variable.name, str(period)),
            "The formula '{0}@{1}' should return a Numpy array;".format(variable.name, str(period)),
            "instead it returned '{0}' of {1}.".format(value, type(value)),
            "Learn more about Numpy arrays and vectorial computing:",
            "<https://openfisca.org/doc/coding-the-legislation/25_vectorial_computing.html.>"
            ]))

        assert value.size == entity.count, \
            "Function {}@{}<{}>() --> <{}>{} returns an array of size {}, but size {} is expected for {}".format(
                variable.name, entity.key, str(period), str(period), stringify_array(value),
                value.size, entity.count, entity.key).encode('utf-8')

        if self.debug:
            try:
                # cf https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy
                if np.isnan(np.min(value)):
                    nan_count = np.count_nonzero(np.isnan(value))
                    raise NaNCreationError("Function {}@{}<{}>() --> <{}>{} returns {} NaN value(s)".format(
                        variable.name, entity.key, str(period), str(period), stringify_array(value),
                        nan_count).encode('utf-8'))
            except TypeError:
                # Non-numeric arrays cannot contain NaN: nothing to check.
                pass

    def _cast_formula_result(self, value, variable):
        """
            Convert a formula result to the variable's type: encode Enum results that are not already
            an ``EnumArray``, and cast other arrays to ``variable.dtype`` when needed.
        """
        if variable.value_type == Enum and not isinstance(value, EnumArray):
            return variable.possible_values.encode(value)

        if value.dtype != variable.dtype:
            return value.astype(variable.dtype)

        return value

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable, period):
        """
        Register that ``variable`` is being calculated for ``period``, or fail on a forbidden cycle.

        Raises an AssertionError on a pure circular definition (the variable is requested again for a
        period it is already being calculated for, or it is an ETERNITY variable requested again), and a
        :any:`CycleError` when no cycle is allowed (``max_nb_cycles`` is None) or when the number of nested
        requests exceeds ``self.max_nb_cycles``. Otherwise the period is recorded in
        ``requested_periods_by_variable_name``. Returns None.
        """
        def get_error_message():
            return "Circular definition detected on formula {}@{}. Formulas and periods involved: {}.".format(
                variable.name,
                period,
                ", ".join(sorted(set(
                    "{}@{}".format(requested_name, requested_period)
                    for requested_name, requested_period_list in requested_periods_by_variable_name.items()
                    for requested_period in requested_period_list
                    ))).encode('utf-8'),
                )
        requested_periods_by_variable_name = self.requested_periods_by_variable_name
        variable_name = variable.name
        if variable_name in requested_periods_by_variable_name:
            # Make sure the formula doesn't call itself for the same period it is being called for.
            # It would be a pure circular definition.
            requested_periods = requested_periods_by_variable_name[variable_name]
            assert period not in requested_periods and (variable.definition_period != periods.ETERNITY), get_error_message()
            if self.max_nb_cycles is None or len(requested_periods) > self.max_nb_cycles:
                message = get_error_message()
                if self.max_nb_cycles is None:
                    message += ' Hint: use "max_nb_cycles = 0" to get a default value, or "= N" to allow N cycles.'
                raise CycleError(message)
            else:
                requested_periods.append(period)
        else:
            requested_periods_by_variable_name[variable_name] = [period]

    def _clean_cycle_detection_data(self, variable_name):
        """
        When the value of a formula has been computed, remove the last period from
        requested_periods_by_variable_name[variable_name] and delete the latter if empty.
        """

        requested_periods_by_variable_name = self.requested_periods_by_variable_name
        if variable_name in requested_periods_by_variable_name:
            requested_periods_by_variable_name[variable_name].pop()
            if len(requested_periods_by_variable_name[variable_name]) == 0:
                del requested_periods_by_variable_name[variable_name]

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name, period):
        """
            Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated).

            Unlike :any:`calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name):
        """
            Get the :any:`Holder` associated with the variable ``variable_name`` for the simulation
        """
        return self.get_variable_entity(variable_name).get_holder(variable_name)

    def get_memory_usage(self, variables = None):
        """
            Get data about the virtual memory usage of the simulation

            :returns: a dict with the summed ``total_nb_bytes`` of all entities and the merged
                ``by_variable`` details they report.
        """
        result = dict(
            total_nb_bytes = 0,
            by_variable = {}
            )
        for entity in self.entities.values():
            entity_memory_usage = entity.get_memory_usage(variables = variables)
            result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes']
            result['by_variable'].update(entity_memory_usage['by_variable'])
        return result

    # ----- Misc ----- #

    def delete_arrays(self, variable, period = None):
        """
            Delete a variable's value for a given period

            :param variable: the name of the variable whose values should be deleted
            :param period: the period for which the value should be deleted; if None, the values for all periods are deleted (see example)

            Example:

            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> simulation = Simulation(CountryTaxBenefitSystem())
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.get_array('age', '2018-05')
            array([13, 14], dtype=int32)
            >>> simulation.delete_arrays('age', '2018-05')
            >>> simulation.get_array('age', '2018-04')
            array([12, 14], dtype=int32)
            >>> simulation.get_array('age', '2018-05') is None
            True
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.delete_arrays('age')
            >>> simulation.get_array('age', '2018-04') is None
            True
            >>> simulation.get_array('age', '2018-05') is None
            True
        """
        self.get_holder(variable).delete_arrays(period)

    def get_known_periods(self, variable):
        """
            Get the list of a variable's known periods, i.e. the periods for which a value has been initialized (see example)

            :param variable: the name of the variable

            Example:

            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> simulation = Simulation(CountryTaxBenefitSystem())
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.get_known_periods('age')
            [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))]

        """
        return self.get_holder(variable).get_known_periods()

    def set_input(self, variable, period, value):
        """
            Set a variable's value for a given period

            :param variable: the variable to be set
            :param value: the input value for the variable
            :param period: the period for which the value is set

            Example:
            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> simulation = Simulation(CountryTaxBenefitSystem())
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.get_array('age', '2018-04')
            array([12, 14], dtype=int32)

            If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        self.get_holder(variable).set_input(period, value)

    def get_variable_entity(self, variable_name):
        """
            Return the entity instance of this simulation on which ``variable_name`` is defined.
        """

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)
        return self.get_entity(variable.entity)

    def get_entity(self, entity_type = None, plural = None):
        """
            Return an entity instance, looked up by ``entity_type`` (its ``key``) or else by ``plural`` name.

            Returns None when the plural name matches no entity, or when neither argument is given.
        """
        if entity_type:
            return self.entities[entity_type.key]
        if plural:
            return next((entity for entity in self.entities.values() if entity.plural == plural), None)

    def clone(self, debug = False, trace = False):
        """
            Copy the simulation just enough to be able to run the copy without modifying the original simulation

            Every instance attribute except ``debug``, ``trace`` and ``tracer`` is copied by reference; the
            person entity and every group entity are cloned and attached to the copy, together with their
            ``simulation.<entity key>`` shortcuts.

            :param debug: if True, the copy is put in debug mode, which (as in ``__init__``, where
                ``trace = trace or debug``) also turns tracing on.
            :param trace: if True, the copy traces its calculations.
            :returns: the cloned simulation.
        """
        new = empty_clone(self)
        new_dict = new.__dict__

        for key, value in self.__dict__.items():
            if key not in ('debug', 'trace', 'tracer'):
                new_dict[key] = value

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.key, new.persons)
        new.entities = {new.persons.key: new.persons}

        for entity_class in self.tax_benefit_system.group_entities:
            entity = self.entities[entity_class.key].clone(new)
            new.entities[entity.key] = entity
            setattr(new, entity_class.key, entity)  # create shortcut simulation.household (for instance)

        if debug:
            new_dict['debug'] = True
        if debug or trace:
            # Debugging implies tracing: otherwise the copy would own a tracer that calculate() never feeds.
            new_dict['trace'] = True
            # Continue the original trace when there is one. The original may have its debug/trace
            # flag on without a tracer (e.g. flags set after construction), hence the guard.
            original_tracer = getattr(self, 'tracer', None)
            if (self.debug or self.trace) and original_tracer is not None:
                new_dict['tracer'] = original_tracer.clone()
            else:
                new_dict['tracer'] = Tracer()

        return new
예제 #12
0
def check_tracing_params(accessor, param_key):
    """Assert that a parameter read through a tracing node is recorded for the open calculation 'A<2015-01>'.

    :param accessor: callable that reads a parameter from the tracing node and returns its value.
    :param param_key: key under which the tracer should have stored that parameter.
    """
    tracer = Tracer()
    tracer.record_calculation_start('A', '2015-01')
    node_at_instant = parameters('2015-01-01')
    tracing_node = TracingParameterNodeAtInstant(node_at_instant, tracer)
    read_value = accessor(tracing_node)
    calculation_trace = tracer.trace['A<2015-01>']
    assert_near(calculation_trace['parameters'][param_key], read_value)
예제 #13
0
def test_consistency():
    """Ending a calculation other than the one in progress must raise a ValueError."""
    tracer = Tracer()
    tracer.record_calculation_start("rsa", 2017)
    # Only the mismatched end call is expected to fail: a ValueError raised while
    # setting up the tracer must not be able to make this test pass.
    with pytest.raises(ValueError):
        tracer.record_calculation_end("unkwonn", 2017, 100)
# -*- coding: utf-8 -*-

from nose.tools import raises

from openfisca_core.tracers import Tracer

# Module-level tracer, kept in case other tests of this module rely on it.
tracer = Tracer()


@raises(ValueError)
def test_consistency():
    """Ending a calculation other than the one in progress must raise a ValueError."""
    # Use a fresh tracer: the shared module-level one would otherwise keep the unfinished
    # "rsa" calculation after this test and leak that state into later tests.
    tracer = Tracer()
    tracer.record_calculation_start("rsa", "2016-01")
    tracer.record_calculation_end("unkwonn", "2016-01", 100)
예제 #15
0
 def trace(self, trace):
     """Store the tracing flag; a truthy value attaches a brand-new Tracer, a falsy one drops the tracer."""
     self._trace = trace
     self.tracer = Tracer() if trace else None
예제 #16
0
class Simulation(object):
    """
        Represents a simulation, and handles the calculation logic
    """

    def __init__(
            self,
            tax_benefit_system,
            populations
            ):
        """
            This constructor is reserved for internal use; see :any:`SimulationBuilder`,
            which is the preferred way to obtain a Simulation initialized with a consistent
            set of Entities.
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None

        self.populations = populations
        self.persons = self.populations[tax_benefit_system.person_entity.key]
        self.link_to_entities_instances()
        self.create_shortcuts()

        # To keep track of the values (formulas and periods) being calculated to detect circular definitions.
        # The data structure of computation_stack is:
        # [('variable_name', 'period1'), ('variable_name', 'period2')]
        self.computation_stack = []
        self.invalidated_caches = set()

        self.debug = False
        self.trace = False
        self.opt_out_cache = False

        # controls the spirals detection; check for performance impact if > 1
        self.max_spiral_loops = 1
        self.memory_config = None
        self._data_storage_dir = None

    @property
    def trace(self):
        return self._trace

    @trace.setter
    def trace(self, trace):
        self._trace = trace
        if trace:
            self.tracer = Tracer()
        else:
            self.tracer = None

    def link_to_entities_instances(self):
        for _key, entity_instance in self.populations.items():
            entity_instance.simulation = self

    def create_shortcuts(self):
        for _key, population in self.populations.items():
            # create shortcut simulation.person and simulation.household (for instance)
            setattr(self, population.entity.key, population)

    @property
    def data_storage_dir(self):
        """
        Temporary folder used to store intermediate calculation data in case the memory is saturated
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_")
            log.warn((
                "Intermediate results will be stored on disk in {} in case of memory overflow. "
                "You should remove this directory once you're done with your simulation."
                ).format(self._data_storage_dir))
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name, period, **parameters):
        """
            Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

            A value already cached in the variable's holder is returned as is. Otherwise the formula
            is run (the holder's default array is used when it returns None), the result is cast to
            the variable's type and cached. If a quasicircular (spiral) definition is detected, the
            holder's default array is returned without being cached.

            :param variable_name: name of the variable to calculate
            :param period: a :any:`Period`, or a value accepted by ``periods.period``
            :param parameters: extra keyword arguments, only forwarded to the tracer
            :returns: A numpy array containing the result of the calculation
            :raises ValueError: if ``period`` does not match the variable's ``definition_period``
            :raises CycleError: if the variable is already being calculated for the same period
        """
        population = self.get_variable_population(variable_name)
        holder = population.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # The start is traced before the period check: if that check raises, no matching end is recorded.
        if self.trace:
            self.tracer.record_calculation_start(variable.name, period, **parameters)

        self._check_period_consistency(period, variable)

        # First look for a value already cached
        cached_array = holder.get_array(period)
        if cached_array is not None:
            if self.trace:
                self.tracer.record_calculation_end(variable.name, period, cached_array, **parameters)
            return cached_array

        array = None

        # Otherwise, try to run a formula
        try:
            # Pushes (variable.name, str(period)) on the computation stack, then may raise CycleError / SpiralError.
            self._check_for_cycle(variable, period)
            array = self._run_formula(variable, population, period)

            # If no result, use the default value and cache it
            if array is None:
                array = holder.default_array()

            array = self._cast_formula_result(array, variable)

            holder.put_in_cache(array, period)
        except SpiralError:
            # Spiral detected: fall back to the default value, which is NOT cached.
            array = holder.default_array()
        finally:
            # Runs on success and on any exception: close the trace entry (array may still be None)
            # and pop the frame pushed by _check_for_cycle.
            if self.trace:
                self.tracer.record_calculation_end(variable.name, period, array, **parameters)
            self._clean_cycle_detection_data(variable.name)

        # Only effective once the outermost calculation has returned (empty computation stack).
        self.purge_cache_of_invalid_values()

        return array

    def purge_cache_of_invalid_values(self):
        # We wait for the end of calculate(), signalled by an empty stack, before purging the cache
        if self.computation_stack:
            return
        for (_name, _period) in self.invalidated_caches:
            holder = self.get_holder(_name)
            holder.delete_arrays(_period)

    def calculate_add(self, variable_name, period, **parameters):
        """
            Sum the values of ``variable_name`` over every sub-period of ``period``.

            ``period`` is split into sub-periods of the variable's ``definition_period``; each one is
            computed with :any:`calculate` and the results are added up.

            :raises ValueError: if ``period`` is shorter than the variable's ``definition_period``,
                or if the variable is not defined per day, month or year.
        """
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        needs_parsing = period is not None and not isinstance(period, periods.Period)
        if needs_parsing:
            period = periods.period(period)

        definition_period = variable.definition_period

        # The requested period must be at least as long as the definition period
        if periods.unit_weight(definition_period) > periods.unit_weight(period.unit):
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period,
                definition_period
                ))

        summable_units = [periods.DAY, periods.MONTH, periods.YEAR]
        if definition_period not in summable_units:
            raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time.".format(
                variable.name,
                period))

        sub_periods = period.get_subperiods(definition_period)
        return sum(self.calculate(variable_name, sub_period, **parameters) for sub_period in sub_periods)

    def calculate_divide(self, variable_name, period, **parameters):
        """
            Calculate a yearly variable on a year, or estimate it on a month as its yearly value divided by 12.

            :raises ValueError: if the variable is not defined yearly, if ``period`` spans more than
                one unit, or if ``period`` is neither a month nor a year.
        """
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Only yearly variables can be divided over time
        if variable.definition_period != periods.YEAR:
            raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format(
                variable_name,
                period))

        if period.size != 1:
            raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period")

        unit = period.unit
        if unit == periods.YEAR:
            return self.calculate(variable_name, period, **parameters)
        if unit == periods.MONTH:
            yearly_value = self.calculate(variable_name, period = period.this_year, **parameters)
            return yearly_value / 12.

        raise ValueError("Unable to divide the value of '{}' to match period {}.".format(
            variable_name,
            period))

    def calculate_output(self, variable_name, period):
        """
            Calculate the value of a variable using the ``calculate_output`` attribute of the variable.
        """

        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """
            Return the parameters at ``formula_period`` wrapped in a TracingParameterNodeAtInstant
            bound to ``self.tracer``, so that parameter reads can be recorded by the tracer.
        """
        parameters_at_instant = self.tax_benefit_system.get_parameters_at_instant(formula_period)
        return TracingParameterNodeAtInstant(parameters_at_instant, self.tracer)

    def _run_formula(self, variable, population, period):
        """
            Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.
        """

        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        if formula.__code__.co_argcount == 2:
            array = formula(population, period)
        else:
            array = formula(population, period, parameters_at)

        return array

    def _check_period_consistency(self, period, variable):
        """
            Check that ``period`` matches the ``definition_period`` of ``variable``.

            ETERNITY variables accept any period. MONTH (resp. YEAR) variables require a period whose
            unit is a month (resp. a year). Every non-eternal variable requires a single-unit period.

            :raises ValueError: with an ADD / DIVIDE hint when the period does not match
        """
        definition_period = variable.definition_period

        if definition_period == periods.ETERNITY:
            return  # Values constant in time: every period is accepted

        if definition_period == periods.MONTH and period.unit != periods.MONTH:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format(
                variable.name,
                period
                ))

        if definition_period == periods.YEAR and period.unit != periods.YEAR:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period
                ))

        # NOTE(review): a definition_period other than ETERNITY/MONTH/YEAR (e.g. DAY) is only checked
        # by the size test below, and its message then says 'year'. Confirm whether that is intended.
        if period.size != 1:
            unit_label = 'month' if definition_period == periods.MONTH else 'year'
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
                variable.name,
                period,
                unit_label
                ))

    def _cast_formula_result(self, value, variable):
        """
            Turn a raw formula result into an array matching the variable's type.

            For Enum variables, a result that is not already an EnumArray is encoded with the
            variable's possible values. Any other non-ndarray result is passed to the population's
            ``filled_array``; the array is then cast to the variable's dtype when they differ.
        """
        if variable.value_type == Enum and not isinstance(value, EnumArray):
            return variable.possible_values.encode(value)

        result = value
        if not isinstance(result, np.ndarray):
            population = self.get_variable_population(variable.name)
            result = population.filled_array(result)

        if result.dtype != variable.dtype:
            result = result.astype(variable.dtype)
        return result

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable, period):
        """
        Raise an exception in the case of a circular definition, where evaluating a variable for
        a given period loops around to evaluating the same variable/period pair. Also guards, as
        a heuristic, against "quasicircles", where the evaluation of a variable at a period involves
        the same variable at a different period.
        """
        previous_periods = [_period for (_name, _period) in self.computation_stack if _name == variable.name]
        self.computation_stack.append((variable.name, str(period)))
        if str(period) in previous_periods:
            raise CycleError("Circular definition detected on formula {}@{}".format(variable.name, period))
        spiral = len(previous_periods) >= self.max_spiral_loops
        if spiral:
            self.invalidate_spiral_variables(variable)
            message = "Quasicircular definition detected on formula {}@{} involving {}".format(variable.name, period, self.computation_stack)
            raise SpiralError(message, variable.name)

    def invalidate_spiral_variables(self, variable):
        # Visit the stack, from the bottom (most recent) up; we know that we'll find
        # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
        # intermediate values computed (to avoid impacting performance) but we mark them
        # for deletion from the cache once the calculation ends.
        count = 0
        for frame in reversed(self.computation_stack):
            self.invalidated_caches.add(frame)
            if frame[0] == variable.name:
                count += 1
                if count > self.max_spiral_loops:
                    break

    def _clean_cycle_detection_data(self, variable_name):
        """
        When the value of a formula have been computed, remove the period from
        requested_periods_by_variable_name[variable_name] and delete the latter if empty.
        """
        self.computation_stack.pop()

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name, period):
        """
            Return the value of ``variable_name`` for ``period``, if this value is alreay in the cache (if it has been set as an input or previously calculated).

            Unlike :any:`calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name):
        """
            Get the :any:`Holder` associated with the variable ``variable_name`` for the simulation
        """
        return self.get_variable_population(variable_name).get_holder(variable_name)

    def get_memory_usage(self, variables = None):
        """
            Get data about the virtual memory usage of the simulation
        """
        result = dict(
            total_nb_bytes = 0,
            by_variable = {}
            )
        for entity in self.populations.values():
            entity_memory_usage = entity.get_memory_usage(variables = variables)
            result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes']
            result['by_variable'].update(entity_memory_usage['by_variable'])
        return result

    # ----- Misc ----- #

    def delete_arrays(self, variable, period = None):
        """
            Delete a variable's value for a given period

            :param variable: the variable to be set
            :param period: the period for which the value should be deleted

            Example:

            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> simulation = Simulation(CountryTaxBenefitSystem())
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.get_array('age', '2018-05')
            array([13, 14], dtype=int32)
            >>> simulation.delete_arrays('age', '2018-05')
            >>> simulation.get_array('age', '2018-04')
            array([12, 14], dtype=int32)
            >>> simulation.get_array('age', '2018-05') is None
            True
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.delete_arrays('age')
            >>> simulation.get_array('age', '2018-04') is None
            True
            >>> simulation.get_array('age', '2018-05') is None
            True
        """
        self.get_holder(variable).delete_arrays(period)

    def get_known_periods(self, variable):
        """
            Get a list variable's known period, i.e. the periods where a value has been initialized and

            :param variable: the variable to be set

            Example:

            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> simulation = Simulation(CountryTaxBenefitSystem())
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.set_input('age', '2018-05', [13, 14])
            >>> simulation.get_known_periods('age')
            [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))]

        """
        return self.get_holder(variable).get_known_periods()

    def set_input(self, variable_name, period, value):
        """
            Set a variable's value for a given period

            :param variable_name: the name of the variable to be set
            :param period: the period for which the value is set (any value accepted by ``periods.period``)
            :param value: the input value for the variable

            If the variable has an ``end`` date and ``period`` starts after it, the input is silently ignored.

            Example:
            >>> from openfisca_country_template import CountryTaxBenefitSystem
            >>> from openfisca_core.simulation_builder import SimulationBuilder
            >>> simulation = SimulationBuilder().build_default_simulation(CountryTaxBenefitSystem(), 2)
            >>> simulation.set_input('age', '2018-04', [12, 14])
            >>> simulation.get_array('age', '2018-04')
            array([12, 14], dtype=int32)

            If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)
        period = periods.period(period)
        # Inputs for periods starting after the variable's end date are dropped without error.
        if ((variable.end is not None) and (period.start.date > variable.end)):
            return
        self.get_holder(variable_name).set_input(period, value)

    def get_variable_population(self, variable_name):
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)
        return self.populations[variable.entity.key]

    def get_population(self, plural = None):
        return next((population for population in self.populations.values() if population.entity.plural == plural), None)

    def get_entity(self, plural = None):
        population = self.get_population(plural)
        return population and population.entity

    def describe_entities(self):
        return {population.entity.plural: population.ids for population in self.populations.values()}

    def clone(self, debug = False, trace = False):
        """
            Copy the simulation just enough to be able to run the copy without modifying the original simulation

            Populations are cloned. Other attributes are shared with the original, except the
            mutable cycle-detection state (``computation_stack`` and ``invalidated_caches``), which
            is copied so that calculations on the copy cannot alter the original's.

            :param debug: value of ``debug`` on the copy
            :param trace: value of ``trace`` on the copy; a truthy value gives the copy a fresh tracer
            :returns: the new simulation
        """
        new = empty_clone(self)
        new_dict = new.__dict__

        for key, value in self.__dict__.items():
            if key not in ('debug', 'trace', 'tracer'):
                new_dict[key] = value

        # The loop above only copies references: without these copies, calculate() on the clone
        # would push onto (and purge through) the original simulation's containers.
        new.computation_stack = list(self.computation_stack)
        new.invalidated_caches = set(self.invalidated_caches)

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.entity.key, new.persons)
        new.populations = {new.persons.entity.key: new.persons}

        for entity in self.tax_benefit_system.group_entities:
            population = self.populations[entity.key].clone(new)
            new.populations[entity.key] = population
            setattr(new, entity.key, population)  # create shortcut simulation.household (for instance)

        new.debug = debug
        new.trace = trace  # goes through the property setter, which also sets new.tracer

        return new