예제 #1
0
    def __init__(self, tax_benefit_system, populations):
        """
        Build a simulation from a tax and benefit system and its entity populations.

        Reserved for internal use: prefer :any:`SimulationBuilder`, which yields a
        Simulation whose Entities are consistent with each other.

        :param tax_benefit_system: the system providing variables; must not be None
        :param populations: mapping of entity key to population; must contain the
            key of ``tax_benefit_system.person_entity``
        """
        # --- model and populations ---
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None

        self.populations = populations
        person_key = tax_benefit_system.person_entity.key
        self.persons = self.populations[person_key]
        self.link_to_entities_instances()
        self.create_shortcuts()

        # (variable name, period) pairs to purge from the cache once calculation ends
        self.invalidated_caches = set()

        # --- debugging / tracing state ---
        self.debug = False
        self.trace = False
        self.tracer = SimpleTracer()
        self.opt_out_cache = False

        # Spiral-detection threshold; measure the performance impact before raising it above 1
        self.max_spiral_loops = 1
        self.memory_config = None
        self._data_storage_dir = None
예제 #2
0
 def trace(self, trace):
     """Store the tracing flag and install a FullTracer when it is truthy, a SimpleTracer otherwise."""
     self._trace = trace
     self.tracer = FullTracer() if trace else SimpleTracer()
예제 #3
0
class Simulation:
    """
    Represents a simulation, and handles the calculation logic
    """
    def __init__(self, tax_benefit_system, populations):
        """
        This constructor is reserved for internal use; see :any:`SimulationBuilder`,
        which is the preferred way to obtain a Simulation initialized with a consistent
        set of Entities.

        :param tax_benefit_system: the system providing variables and parameters; must not be None
        :param populations: mapping of entity key to population; must contain the key of
            ``tax_benefit_system.person_entity``
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None

        self.populations = populations
        self.persons = self.populations[tax_benefit_system.person_entity.key]
        self.link_to_entities_instances()
        self.create_shortcuts()

        # (variable name, period) pairs whose cached values must be deleted once the
        # outermost calculation ends (see purge_cache_of_invalid_values)
        self.invalidated_caches = set()

        self.debug = False
        # The ``trace`` setter also installs a fresh SimpleTracer as ``self.tracer``
        self.trace = False
        self.opt_out_cache = False

        # controls the spirals detection; check for performance impact if > 1
        self.max_spiral_loops = 1
        self.memory_config = None
        self._data_storage_dir = None

    @property
    def trace(self):
        """Whether full tracing is enabled."""
        return self._trace

    @trace.setter
    def trace(self, trace):
        """Set the tracing flag and replace ``self.tracer`` with a new tracer of the matching kind."""
        self._trace = trace
        if trace:
            self.tracer = FullTracer()
        else:
            self.tracer = SimpleTracer()

    def link_to_entities_instances(self):
        """Point every population back to this simulation."""
        for _key, entity_instance in self.populations.items():
            entity_instance.simulation = self

    def create_shortcuts(self):
        """Expose each population as an attribute named after its entity key."""
        for _key, population in self.populations.items():
            # create shortcut simulation.person and simulation.household (for instance)
            setattr(self, population.entity.key, population)

    @property
    def data_storage_dir(self):
        """
        Temporary folder used to store intermediate calculation data in case the memory is saturated

        The folder is created lazily on first access, and a ``TempfileWarning`` is emitted
        because the caller is responsible for removing it.
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix="openfisca_")
            message = (
                "Intermediate results will be stored on disk in {} in case of memory overflow. "
                "You should remove this directory once you're done with your simulation."
            ).format(self._data_storage_dir)
            warnings.warn(message, TempfileWarning)
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name, period):
        """
        Calculate ``variable_name`` for ``period``.

        :param variable_name: name of the variable to compute
        :param period: a :class:`Period`, or any value accepted by ``periods.period``
        :returns: the computed (or cached) array
        """

        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        self.tracer.record_calculation_start(variable_name, period)

        try:
            result = self._calculate(variable_name, period)
            self.tracer.record_calculation_result(result)
            return result

        finally:
            # Always pop the tracer frame, even on error, so the stack stays balanced
            self.tracer.record_calculation_end()
            self.purge_cache_of_invalid_values()

    def _calculate(self, variable_name, period: Period):
        """
        Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

        :returns: A numpy array containing the result of the calculation
        """
        population = self.get_variable_population(variable_name)
        holder = population.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)

        self._check_period_consistency(period, variable)

        # First look for a value already cached
        cached_array = holder.get_array(period)
        if cached_array is not None:
            return cached_array

        array = None

        # Otherwise, try to run a formula
        try:
            self._check_for_cycle(variable.name, period)
            array = self._run_formula(variable, population, period)

            # If no result, use the default value and cache it
            if array is None:
                array = holder.default_array()

            array = self._cast_formula_result(array, variable)
            holder.put_in_cache(array, period)

        except SpiralError:
            # Spirals fall back to the default value, which is deliberately NOT cached
            array = holder.default_array()

        return array

    def purge_cache_of_invalid_values(self):
        """
        Delete the cached values marked in ``invalidated_caches``, then reset the set.

        Does nothing while a calculation is still in progress (non-empty tracer stack).
        """
        # We wait for the end of calculate(), signalled by an empty stack, before purging the cache
        if self.tracer.stack:
            return
        for (_name, _period) in self.invalidated_caches:
            holder = self.get_holder(_name)
            holder.delete_arrays(_period)
        self.invalidated_caches = set()

    def calculate_add(self, variable_name, period):
        """
        Calculate ``variable_name`` over ``period`` by summing its values on every sub-period
        matching the variable's ``definition_period``.

        :raises ValueError: if the variable is not defined daily, monthly or yearly, or if
            ``period`` is shorter than the variable's ``definition_period``
        """
        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)

        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        # Checked first: for constant variables this is the meaningful error, and it would
        # otherwise be shadowed by the period-length check below.
        if variable.definition_period not in [
                periods.DAY, periods.MONTH, periods.YEAR
        ]:
            raise ValueError(
                "Unable to sum constant variable '{}' over period {}: only variables defined daily, monthly, or yearly can be summed over time."
                .format(variable.name, period))

        # Check that the requested period matches definition_period
        if periods.unit_weight(
                variable.definition_period) > periods.unit_weight(period.unit):
            raise ValueError(
                "Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for {2}-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'."
                .format(variable.name, period, variable.definition_period))

        return sum(
            self.calculate(variable_name, sub_period) for sub_period in
            period.get_subperiods(variable.definition_period))

    def calculate_divide(self, variable_name, period):
        """
        Estimate a yearly variable on a one-month or one-year ``period``.

        For a month, the value of the enclosing year is divided by 12; for a year, the
        value is computed directly.

        :raises ValueError: if the variable is not defined yearly, or ``period`` is not a
            single month or a single year
        """
        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)

        if period is not None and not isinstance(period, Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if variable.definition_period != periods.YEAR:
            raise ValueError(
                "Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time."
                .format(variable_name, period))

        if period.size != 1:
            raise ValueError(
                "DIVIDE option can only be used for a one-year or a one-month requested period"
            )

        if period.unit == periods.MONTH:
            computation_period = period.this_year
            return self.calculate(variable_name,
                                  period=computation_period) / 12.
        elif period.unit == periods.YEAR:
            return self.calculate(variable_name, period)

        raise ValueError(
            "Unable to divide the value of '{}' to match period {}.".format(
                variable_name, period))

    def calculate_output(self, variable_name, period):
        """
        Calculate the value of a variable using the ``calculate_output`` attribute of the variable.

        Falls back to :meth:`calculate` when the variable defines no ``calculate_output``.
        """

        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """Return the parameters at ``formula_period``, wrapped so that accesses are reported to the tracer."""
        return TracingParameterNodeAtInstant(
            self.tax_benefit_system.get_parameters_at_instant(formula_period),
            self.tracer)

    def _run_formula(self, variable, population, period):
        """
        Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``population``.

        :returns: the raw formula result, or None if the variable has no formula for ``period``
        """

        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        # Formulas may be declared as f(population, period) or f(population, period, parameters)
        if formula.__code__.co_argcount == 2:
            array = formula(population, period)
        else:
            array = formula(population, period, parameters_at)

        return array

    def _check_period_consistency(self, period, variable):
        """
        Check that a period matches the variable definition_period

        :raises ValueError: if ``period`` cannot be used to compute ``variable``
        """
        if variable.definition_period == periods.ETERNITY:
            return  # For variables which values are constant in time, all periods are accepted

        if variable.definition_period == periods.MONTH and period.unit != periods.MONTH:
            raise ValueError(
                "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'."
                .format(variable.name, period))

        if variable.definition_period == periods.YEAR and period.unit != periods.YEAR:
            raise ValueError(
                "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'."
                .format(variable.name, period))

        if period.size != 1:
            # Name the variable's own unit in the message (previously DAY variables read "year")
            if variable.definition_period == periods.MONTH:
                unit_label = 'month'
            elif variable.definition_period == periods.DAY:
                unit_label = 'day'
            else:
                unit_label = 'year'
            raise ValueError(
                "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period."
                .format(variable.name, period, unit_label))

    def _cast_formula_result(self, value, variable):
        """
        Coerce a formula result to the variable's storage form.

        Enum results are encoded, scalars are broadcast to a population-sized array,
        and arrays are cast to ``variable.dtype`` when needed.
        """
        if variable.value_type == Enum and not isinstance(value, EnumArray):
            return variable.possible_values.encode(value)

        if not isinstance(value, numpy.ndarray):
            population = self.get_variable_population(variable.name)
            value = population.filled_array(value)

        if value.dtype != variable.dtype:
            return value.astype(variable.dtype)

        return value

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable: str, period):
        """
        Raise an exception in the case of a circular definition, where evaluating a variable for
        a given period loops around to evaluating the same variable/period pair. Also guards, as
        a heuristic, against "quasicircles", where the evaluation of a variable at a period involves
        the same variable at a different period.

        :raises CycleError: on a strict cycle
        :raises SpiralError: when the variable already appears ``max_spiral_loops`` times on the stack
        """
        # The last frame is the current calculation, so it should be ignored from cycle detection
        previous_periods = [
            frame['period'] for frame in self.tracer.stack[:-1]
            if frame['name'] == variable
        ]
        if period in previous_periods:
            raise CycleError(
                "Circular definition detected on formula {}@{}".format(
                    variable, period))
        spiral = len(previous_periods) >= self.max_spiral_loops
        if spiral:
            self.invalidate_spiral_variables(variable)
            message = "Quasicircular definition detected on formula {}@{} involving {}".format(
                variable, period, self.tracer.stack)
            raise SpiralError(message, variable)

    def invalidate_cache_entry(self, variable: str, period):
        """Mark the cached value of ``variable`` for ``period`` for deletion at the end of the calculation."""
        self.invalidated_caches.add((variable, period))

    def invalidate_spiral_variables(self, variable: str):
        """Mark for deletion every stack frame computed inside the spiral involving ``variable``."""
        # Visit the stack from the top (most recent frame) down; we know that we'll find
        # the variable implicated in the spiral (max_spiral_loops+1) times; we keep the
        # intermediate values computed (to avoid impacting performance) but we mark them
        # for deletion from the cache once the calculation ends.
        count = 0
        for frame in reversed(self.tracer.stack):
            self.invalidate_cache_entry(frame['name'], frame['period'])
            if frame['name'] == variable:
                count += 1
                if count > self.max_spiral_loops:
                    break

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name, period):
        """
        Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated).

        Unlike :meth:`.calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name):
        """
        Get the :obj:`.Holder` associated with the variable ``variable_name`` for the simulation
        """
        return self.get_variable_population(variable_name).get_holder(
            variable_name)

    def get_memory_usage(self, variables=None):
        """
        Get data about the virtual memory usage of the simulation

        :param variables: forwarded to each population's ``get_memory_usage``
        :returns: a dict with ``total_nb_bytes`` (summed over populations) and ``by_variable``
            (merged per-variable usage)
        """
        result = dict(total_nb_bytes=0, by_variable={})
        for entity in self.populations.values():
            entity_memory_usage = entity.get_memory_usage(variables=variables)
            result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes']
            result['by_variable'].update(entity_memory_usage['by_variable'])
        return result

    # ----- Misc ----- #

    def delete_arrays(self, variable, period=None):
        """
        Delete a variable's value for a given period

        :param variable: name of the variable whose values should be deleted
        :param period: the period for which the value should be deleted; if None, values
            for all periods are deleted

        Example:

        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.get_array('age', '2018-05')
        array([13, 14], dtype=int32)
        >>> simulation.delete_arrays('age', '2018-05')
        >>> simulation.get_array('age', '2018-04')
        array([12, 14], dtype=int32)
        >>> simulation.get_array('age', '2018-05') is None
        True
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.delete_arrays('age')
        >>> simulation.get_array('age', '2018-04') is None
        True
        >>> simulation.get_array('age', '2018-05') is None
        True
        """
        self.get_holder(variable).delete_arrays(period)

    def get_known_periods(self, variable):
        """
        Get the list of the variable's known periods, i.e. the periods for which its holder
        stores a value (for instance one set with :meth:`set_input`).

        :param variable: name of the variable to inspect

        Example:

        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.set_input('age', '2018-05', [13, 14])
        >>> simulation.get_known_periods('age')
        [Period((u'month', Instant((2018, 5, 1)), 1)), Period((u'month', Instant((2018, 4, 1)), 1))]
        """
        return self.get_holder(variable).get_known_periods()

    def set_input(self, variable_name, period, value):
        """
        Set a variable's value for a given period

        :param variable_name: name of the variable to be set
        :param value: the input value for the variable
        :param period: the period for which the value is set

        Inputs for a period starting after the variable's ``end`` date are silently ignored.

        Example:
        >>> from openfisca_country_template import CountryTaxBenefitSystem
        >>> simulation = Simulation(CountryTaxBenefitSystem())
        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.get_array('age', '2018-04')
        array([12, 14], dtype=int32)

        If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)
        period = periods.period(period)
        if ((variable.end is not None) and (period.start.date > variable.end)):
            return
        self.get_holder(variable_name).set_input(period, value)

    def get_variable_population(self, variable_name):
        """Return the population of the entity on which ``variable_name`` is defined."""
        variable = self.tax_benefit_system.get_variable(variable_name,
                                                        check_existence=True)
        return self.populations[variable.entity.key]

    def get_population(self, plural=None):
        """Return the population whose entity has the given ``plural`` name, or None if there is none."""
        return next((population for population in self.populations.values()
                     if population.entity.plural == plural), None)

    def get_entity(self, plural=None):
        """Return the entity whose ``plural`` name matches, or None if there is none."""
        population = self.get_population(plural)
        return population and population.entity

    def describe_entities(self):
        """Map each entity's plural name to the ids of its population."""
        return {
            population.entity.plural: population.ids
            for population in self.populations.values()
        }

    def clone(self, debug=False, trace=False):
        """
        Copy the simulation just enough to be able to run the copy without modifying the original simulation

        :param debug: value of the ``debug`` flag on the copy
        :param trace: whether the copy uses a full tracer
        :returns: the new Simulation
        """
        new = commons.empty_clone(self)
        new_dict = new.__dict__

        # Share attributes by reference, except the tracing state, which is reset below
        for key, value in self.__dict__.items():
            if key not in ('debug', 'trace', 'tracer'):
                new_dict[key] = value

        # invalidate_cache_entry mutates this set in place: sharing it would let spiral
        # invalidations made while running the copy delete the original's cached values
        # the next time the original purges its cache.
        new.invalidated_caches = set(self.invalidated_caches)

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.entity.key, new.persons)
        new.populations = {new.persons.entity.key: new.persons}

        for entity in self.tax_benefit_system.group_entities:
            population = self.populations[entity.key].clone(new)
            new.populations[entity.key] = population
            setattr(new, entity.key, population
                    )  # create shortcut simulation.household (for instance)

        new.debug = debug
        new.trace = trace

        return new
 def __init__(self):
     """Start with no trace trees, no calculation in progress, and a fresh SimpleTracer for the stack."""
     self._trees = []
     self._current_node = None
     self._simple_tracer = SimpleTracer()
class FullTracer:
    """
    Tracer that records every calculation as a tree of ``TraceNode`` objects.

    The calculation stack itself is delegated to an inner ``SimpleTracer``.
    """

    def __init__(self):
        self._trees = []
        self._current_node = None
        self._simple_tracer = SimpleTracer()

    def record_calculation_start(self, variable: str, period):
        """Push a calculation frame, open a tree node for it, and stamp its start time."""
        self._simple_tracer.record_calculation_start(variable, period)
        self._enter_calculation(variable, period)
        self._record_start_time()

    def _enter_calculation(self, variable: str, period):
        """Create a node for the calculation, attach it to its parent (or as a new root), and make it current."""
        parent = self._current_node
        node = TraceNode(name=variable, period=period, parent=parent)
        if parent is not None:
            parent.append_child(node)
        else:
            self._trees.append(node)
        self._current_node = node

    def record_parameter_access(self, parameter: str, period, value):
        """Attach a parameter-access node to the calculation currently in progress."""
        parameter_nodes = self._current_node.parameters
        parameter_nodes.append(TraceNode(name=parameter, period=period, value=value))

    def _record_start_time(self, time_in_s: typing.Optional[float] = None):
        """Stamp the current node's start time (now, unless ``time_in_s`` is given)."""
        self._current_node.start = self._get_time_in_sec() if time_in_s is None else time_in_s

    def record_calculation_result(self, value: numpy.ndarray):
        """Store the computed value on the current node."""
        self._current_node.value = value

    def record_calculation_end(self):
        """Pop the calculation frame, stamp the end time, and move back to the parent node."""
        self._simple_tracer.record_calculation_end()
        self._record_end_time()
        self._exit_calculation()

    def _record_end_time(self, time_in_s: typing.Optional[float] = None):
        """Stamp the current node's end time (now, unless ``time_in_s`` is given)."""
        self._current_node.end = self._get_time_in_sec() if time_in_s is None else time_in_s

    def _exit_calculation(self):
        """Make the parent of the current node the new current node."""
        self._current_node = self._current_node.parent

    @property
    def stack(self):
        """Current calculation stack, as kept by the inner SimpleTracer."""
        return self._simple_tracer.stack

    @property
    def trees(self):
        """Root nodes of every recorded calculation tree."""
        return self._trees

    @property
    def computation_log(self):
        return ComputationLog(self)

    @property
    def performance_log(self):
        return PerformanceLog(self)

    @property
    def flat_trace(self):
        return FlatTrace(self)

    def _get_time_in_sec(self) -> float:
        """Wall-clock time in seconds, derived from ``time.time_ns``."""
        return time.time_ns() / (10**9)

    def print_computation_log(self, aggregate=False):
        self.computation_log.print_log(aggregate)

    def generate_performance_graph(self, dir_path: str) -> None:
        self.performance_log.generate_graph(dir_path)

    def generate_performance_tables(self, dir_path: str) -> None:
        self.performance_log.generate_performance_tables(dir_path)

    def _get_nb_requests(self, tree, variable: str):
        """Count the nodes named ``variable`` in ``tree`` (the node itself included)."""
        own = 1 if tree.name == variable else 0
        return own + sum(
            self._get_nb_requests(child, variable) for child in tree.children)

    def get_nb_requests(self, variable: str):
        """Count how many times ``variable`` was requested across all recorded trees."""
        return sum(
            self._get_nb_requests(root, variable) for root in self.trees)

    def get_flat_trace(self):
        return self.flat_trace.get_trace()

    def get_serialized_flat_trace(self):
        return self.flat_trace.get_serialized_trace()

    def browse_trace(self) -> typing.Iterator[TraceNode]:
        """Yield every recorded node, depth-first, each parent before its children."""
        def _walk(current):
            yield current
            for child in current.children:
                yield from _walk(child)

        for root in self._trees:
            yield from _walk(root)
예제 #6
0
    def enter_calculation(self, variable, period):
        """Stub hook: only remember that it was called; ``variable`` and ``period`` are ignored."""
        # NOTE(review): Simulation.calculate in this file calls tracer.record_calculation_start,
        # not enter_calculation; confirm this stub matches the tracer API under test.
        del variable, period  # unused by the stub
        setattr(self, 'entered', True)

    def record_calculation_result(self, value):
        """Stub hook: only remember that a result was recorded; ``value`` is ignored."""
        del value  # unused by the stub
        setattr(self, 'recorded_result', True)

    def exit_calculation(self):
        """Stub hook: only remember that it was called."""
        # NOTE(review): Simulation.calculate in this file calls tracer.record_calculation_end,
        # not exit_calculation; confirm this stub matches the tracer API under test.
        setattr(self, 'exited', True)


@fixture
def tracer():
    """Return a new FullTracer instance for the requesting test."""
    new_tracer = FullTracer()
    return new_tracer


@mark.parametrize("tracer", [SimpleTracer(), FullTracer()])
def test_stack_one_level(tracer):
    """Starting then ending one calculation pushes and pops exactly one stack frame."""
    # Use the public tracer API: FullTracer exposes record_calculation_start /
    # record_calculation_end (forwarded to SimpleTracer) and has no
    # enter_calculation / exit_calculation methods.
    tracer.record_calculation_start('a', 2017)
    assert len(tracer.stack) == 1
    assert tracer.stack == [{'name': 'a', 'period': 2017}]

    tracer.record_calculation_end()
    assert tracer.stack == []


@mark.parametrize("tracer", [SimpleTracer(), FullTracer()])
def test_stack_two_levels(tracer):
    tracer.enter_calculation('a', 2017)
    tracer.enter_calculation('b', 2017)
    assert len(tracer.stack) == 2
    assert tracer.stack == [{