示例#1
0
 def __init__(self, parameter_or_parameters, body, conditions=tuple(), styles=None, requirements=tuple()):
     '''
     Initialize a Lambda function expression given parameter(s) and a body.
     Each parameter must be a Variable.
     When there is a single parameter, there will be a 'parameter'
     attribute. Either way, there will be a 'parameters' attribute
     that bundles the one or more Variables into an ExprList.
     The 'body' attribute will be the lambda function body
     Expression (that may or may not be a Composite).
     The 'conditions' attribute bundles the zero or more given
     condition expressions into a composite expression.
     'styles' (a fresh dict when None) and 'requirements' are passed
     along to Expression.__init__.

     Raises a ValueError if the parameter Variables are not unique,
     a TypeError if the body is not an Expression or is an Iter
     directly, and a LambdaError if any requirement of the body
     involves one of the parameter Variables.
     '''
     from proveit._core_.expression.composite import compositeExpression, singleOrCompositeExpression, Iter
     if styles is None: styles = dict()
     self.parameters = compositeExpression(parameter_or_parameters)
     parameterVars = [getParamVar(parameter) for parameter in self.parameters]
     if len(self.parameters) == 1:
         # has a single parameter
         self.parameter = self.parameters[0]
         self.parameter_or_parameters = self.parameter
     else:
         self.parameter_or_parameters = self.parameters
     self.parameterVars = tuple(parameterVars)
     self.parameterVarSet = frozenset(parameterVars)
     if len(self.parameterVarSet) != len(self.parameters):
         raise ValueError('Lambda parameters Variables must be unique with respect to each other.')
     body = singleOrCompositeExpression(body)
     if not isinstance(body, Expression):
         raise TypeError('A Lambda body must be of type Expression')
     if isinstance(body, Iter):
         raise TypeError('An Iter must be within an ExprList or ExprTensor, not directly as a Lambda body')
     self.body = body
     self.conditions = compositeExpression(conditions)
     # Requirements of the body must not depend on the parameters
     # being bound by this Lambda.
     for requirement in self.body.getRequirements():
         if not self.parameterVarSet.isdisjoint(requirement.freeVars()):
             raise LambdaError("Cannot generate a Lambda expression with parameter variables involved in Lambda body requirements: " + str(requirement))
     # The conditions are only included as a sub-expression when
     # there is at least one.
     sub_exprs = [self.parameter_or_parameters, self.body]
     if len(self.conditions)>0: sub_exprs.append(self.conditions)
     
     # Create a "generic" version (if not already) of the Lambda expression since the 
     # choice of parameter labeling is irrelevant.
     generic_body_vars = self.body._genericExpr.usedVars()
     generic_condition_vars = self.conditions._genericExpr.usedVars()
     used_generic_vars = generic_body_vars.union(generic_condition_vars)
     generic_params = tuple(safeDummyVars(len(self.parameterVars), *(used_generic_vars-self.parameterVarSet)))
     if generic_params != self.parameterVars:
         relabel_map = dict(zip(self.parameterVars, generic_params))
         # Temporarily disable automation during the relabeling
         # process.  The try/finally guarantees that the previous
         # setting is restored even if relabeling or constructing the
         # generic Lambda raises an exception.
         prev_automation = defaults.automation
         defaults.automation = False
         try:
             generic_parameters = self.parameters._genericExpr.relabeled(relabel_map)
             generic_body = self.body._genericExpr.relabeled(relabel_map)
             generic_conditions = self.conditions._genericExpr.relabeled(relabel_map)
             self._genericExpr = Lambda(generic_parameters, generic_body, generic_conditions, styles=dict(styles), requirements=requirements)
         finally:
             defaults.automation = prev_automation # restore to previous value
     
     Expression.__init__(self, ['Lambda'], sub_exprs, styles=styles, requirements=requirements)
示例#2
0
 def safeDummyVars(self, n):
     '''
     Delegate to the safeDummyVars function of
     proveit._core_.expression.label.var, passing along 'n' and this
     object, and return its result unchanged.
     (Presumably this yields 'n' dummy Variables that do not clash
     with Variables of this expression -- confirm against that
     function's documentation.)
     '''
     # Alias the imported function so that it is not confused with
     # this method of the same name.
     from proveit._core_.expression.label.var import safeDummyVars as _make_safe_dummy_vars
     dummy_vars = _make_safe_dummy_vars(n, self)
     return dummy_vars
示例#3
0
    def substituted(self,
                    exprMap,
                    relabelMap=None,
                    reservedVars=None,
                    assumptions=USE_DEFAULTS,
                    requirements=None):
        '''
        Returns this expression with the substitutions made 
        according to exprMap and/or relabeled according to relabelMap.
        Attempt to automatically expand the iteration if any Indexed 
        sub-expressions substitute their variable for a composite
        (list or tensor).  Indexed should index variables that represent
        composites, but substituting the composite is a signal that
        an outer iteration should be expanded.  An exception is
        raised if this fails.

        When the iteration is expanded, the result is a composite
        expression built from the expanded entries (each entry being
        either a single substituted instance of the iteration body or
        a new Iter over a sub-range).  Otherwise the result is an Iter
        with substituted parameters, body, and start/end indices.

        If 'requirements' is not None, the requirements generated
        by this substitution are appended to it.  The given
        'relabelMap' is not modified (a private copy is used).

        Raises an IterationError if a range requirement cannot be
        proven or if only some of the axes can be expanded.
        '''
        from .composite import _generateCoordOrderAssumptions
        from proveit import ProofFailure, ExprArray
        from proveit.logic import Equals, InSet
        from proveit.number import Less, LessEq, dist_add, \
            zero, one, dist_subtract, Naturals, Integers
        from .composite import _simplifiedCoord
        from proveit._core_.expression.expr import _NoExpandedIteration
        from proveit._core_.expression.label.var import safeDummyVars

        self._checkRelabelMap(relabelMap)
        # Work on a private copy: iteration parameters are popped from
        # the relabeling map below and that must not leak back into the
        # caller's dictionary.
        relabelMap = dict() if relabelMap is None else dict(relabelMap)

        assumptions = defaults.checkedAssumptions(assumptions)
        new_requirements = []
        iter_params = self.lambda_map.parameters
        iter_body = self.lambda_map.body
        ndims = self.ndims
        subbed_start = self.start_indices.substituted(exprMap, relabelMap,
                                                      reservedVars,
                                                      assumptions,
                                                      new_requirements)
        subbed_end = self.end_indices.substituted(exprMap, relabelMap,
                                                  reservedVars, assumptions,
                                                  new_requirements)

        # Need to handle the change in scope within the lambda
        # expression.  The 'new_params' returned here are not used;
        # when expanding, fresh "safe" parameters are generated below
        # for each iteration entry instead.
        new_params, inner_expr_map, inner_assumptions, inner_reservations \
            = self.lambda_map._innerScopeSub(exprMap, relabelMap,
                  reservedVars, assumptions, new_requirements)

        # Get sorted substitution parameter start and end
        # values demarcating how the entry array must be split up for
        # each axis.
        all_entry_starts = [None] * ndims
        all_entry_ends = [None] * ndims
        do_expansion = False
        for axis in range(ndims):
            try:
                empty_eq = Equals(dist_add(subbed_end[axis], one),
                                  subbed_start[axis])
                try:
                    # Check if this is an empty iteration which
                    # happens when end+1=start.
                    empty_eq.prove(assumptions, automation=False)
                    all_entry_starts[axis] = all_entry_ends[axis] = []
                    do_expansion = True
                    continue
                except ProofFailure:
                    pass
                param_vals = \
                    iter_body._iterSubParamVals(axis, iter_params[axis],
                                                subbed_start[axis],
                                                subbed_end[axis],
                                                inner_expr_map, relabelMap,
                                                inner_reservations,
                                                inner_assumptions,
                                                new_requirements)
                assert param_vals[0] == subbed_start[axis]
                if param_vals[-1] != subbed_end[axis]:
                    # The last of the param_vals should either be
                    # subbed_end[axis] or known to be
                    # subbed_end[axis]+1.  Let's double-check.
                    eq = Equals(dist_add(subbed_end[axis], one),
                                param_vals[-1])
                    eq.prove(assumptions, automation=False)
                # Populate the entry starts and ends using the
                # param_vals which indicate the start of each contained
                # entry plus the end of this iteration.
                all_entry_starts[axis] = []
                all_entry_ends[axis] = []
                for left, right in zip(param_vals[:-1], param_vals[1:]):
                    all_entry_starts[axis].append(left)
                    try:
                        eq = Equals(dist_add(left, one), right)
                        # Prove once and record the result as a
                        # requirement.
                        new_requirements.append(
                            eq.prove(assumptions, automation=False))
                        # Simple single-entry case: the start and end
                        # are the same.
                        entry_end = left
                    except Exception:
                        # Not the simple case (left+1=right could not
                        # be established); perform the positive
                        # integrality check.
                        requirement = InSet(dist_subtract(right, left),
                                            Naturals)
                        # Knowing the simplification may help prove the
                        # requirement.  A scratch list is passed so
                        # nothing from this attempt is recorded.
                        _simplifiedCoord(requirement, assumptions, [])
                        try:
                            new_requirements.append(
                                requirement.prove(assumptions))
                        except ProofFailure as e:
                            raise IterationError("Failed to prove requirement "
                                                 "%s:\n%s" % (requirement, e))
                        if right == subbed_end[axis]:
                            # This last entry is the inclusive end
                            # rather than past the end, so it is an
                            # exception.
                            entry_end = right
                        else:
                            # Subtract one from the start of the next
                            # entry to get the end of this entry.
                            entry_end = dist_subtract(right, one)
                            # Use 'new_requirements' like the other
                            # _simplifiedCoord calls in this method;
                            # the caller-supplied 'requirements' may
                            # be None and is only extended at the end.
                            entry_end = _simplifiedCoord(
                                entry_end, assumptions, new_requirements)
                    all_entry_ends[axis].append(entry_end)
                # See if we should add the end value as an extra
                # singular entry.  If param_vals[-1] is at the inclusive
                # end, then we have a singular final entry.
                if param_vals[-1] == subbed_end[axis]:
                    end_val = subbed_end[axis]
                    all_entry_starts[axis].append(end_val)
                    all_entry_ends[axis].append(end_val)
                else:
                    # Otherwise, the last param_val will be one after
                    # the inclusive end which we will want to use below
                    # when building the last iteration entry.
                    all_entry_starts[axis].append(param_vals[-1])
                do_expansion = True
            except EmptyIterException:
                # Indexing over a negative or empty range.  The only way this
                # should be allowed is if subbed_end+1=subbed_start.
                Equals(dist_add(subbed_end[axis], one),
                       subbed_start[axis]).prove(assumptions)
                all_entry_starts[axis] = all_entry_ends[axis] = []
                do_expansion = True
            except _NoExpandedIteration:
                # Nothing to expand along this axis.
                pass

        if do_expansion:
            # There are Indexed sub-Expressions whose variable is
            # being replaced with a Composite, so let us
            # expand the iteration for all of the relevant
            # iteration ranges.

            # We must have "substitution parameter values" along each
            # axis:
            if None in all_entry_starts or None in all_entry_ends:
                raise IterationError("Must expand all axes or none of the "
                                     "axes, when substituting %s" % str(self))

            # Generate the expanded tuple/array as the substitution
            # of 'self'.
            shape = [len(all_entry_ends[axis]) for axis in range(ndims)]
            entries = ExprArray.make_empty_entries(shape)
            indices_by_axis = [range(extent) for extent in shape]

            extended_inner_assumptions = list(inner_assumptions)
            for axis_starts in all_entry_starts:
                # Add the assumptions that
                # _generateCoordOrderAssumptions produces for the
                # successive entry start values along this axis.
                # (Presumably these order the start values; see that
                # function for the precise form.)
                extended_inner_assumptions.extend(
                    _generateCoordOrderAssumptions(axis_starts))

            # Iterate over each of the new entries, obtaining indices
            # into all_entry_starts/all_entry_ends for the start and
            # end parameters of the entry.
            for entry_indices in itertools.product(*indices_by_axis):
                entry_starts = [axis_starts[i] for axis_starts, i in \
                                zip(all_entry_starts, entry_indices)]
                entry_ends = [axis_ends[i] for axis_ends, i in \
                                zip(all_entry_ends, entry_indices)]

                is_singular_entry = True
                for entry_start, entry_end in zip(entry_starts, entry_ends):
                    # Note that empty ranges will be skipped because
                    # equivalent parameter values should be skipped in
                    # the param_vals above.
                    if entry_start != entry_end:
                        # Not a singular entry along this axis, so
                        # it is not a singular entry.  We must do an
                        # iteration for this entry.
                        is_singular_entry = False

                if is_singular_entry:
                    # Single element entry.

                    # Generate the entry by making appropriate
                    # parameter substitutions for the iteration body.
                    entry_inner_expr_map = dict(inner_expr_map)
                    entry_inner_expr_map.update({
                        param: arg
                        for param, arg in zip(iter_params, entry_starts)
                    })
                    # The iteration parameters are being substituted,
                    # so they are removed from the (private copy of
                    # the) relabeling map.
                    for param in iter_params:
                        relabelMap.pop(param, None)
                    entry = iter_body.substituted(entry_inner_expr_map,
                                                  relabelMap,
                                                  inner_reservations,
                                                  extended_inner_assumptions,
                                                  new_requirements)
                else:
                    # Iteration entry.
                    # Shift the iteration parameter so that the
                    # iteration will have the same start-indices
                    # for this sub-range (like shifting a viewing
                    # window, moving the origin to the start of the
                    # sub-range).

                    # Generate "safe" new parameters (the Variables are
                    # not used for anything that might conflict).
                    # Avoid using free variables from these expressions:
                    unsafe_var_exprs = [self]
                    unsafe_var_exprs.extend(exprMap.values())
                    unsafe_var_exprs.extend(relabelMap.values())
                    unsafe_var_exprs.extend(entry_starts)
                    unsafe_var_exprs.extend(entry_ends)
                    new_params = safeDummyVars(ndims, *unsafe_var_exprs)

                    # Make assumptions that places the parameter(s) in the
                    # appropriate range and at an integral coordinate position.
                    # Note, it is possible that this actually represents an
                    # empty range and that these assumptions are contradictory;
                    # but this still suits our purposes regardless.
                    # Also, we will choose to shift the parameter so it
                    # starts at the start index of the iteration.
                    range_expr_map = dict(inner_expr_map)
                    range_assumptions = []
                    shifted_entry_ends = []
                    for axis, (param, new_param, entry_start, entry_end) \
                            in enumerate(zip(iter_params, new_params,
                                             entry_starts, entry_ends)):
                        start_idx = self.start_indices[axis]
                        shift = dist_subtract(entry_start, start_idx)
                        shift = _simplifiedCoord(shift, assumptions,
                                                 new_requirements)
                        if shift != zero:
                            shifted_param = dist_add(new_param, shift)
                        else:
                            shifted_param = new_param
                        range_expr_map[param] = shifted_param
                        shifted_end = dist_subtract(entry_end, shift)
                        shifted_end = _simplifiedCoord(shifted_end,
                                                       assumptions,
                                                       new_requirements)
                        shifted_entry_ends.append(shifted_end)
                        assumption = InSet(new_param, Integers)
                        range_assumptions.append(assumption)
                        assumption = LessEq(entry_start, shifted_param)
                        range_assumptions.append(assumption)
                        # Assume differences with each of the previous
                        # range starts are natural numbers as should be
                        # the case given requirements that have been
                        # met.
                        next_index = entry_indices[axis] + 1
                        prev_starts = all_entry_starts[axis][:next_index]
                        for prev_start in prev_starts:
                            assumption = InSet(
                                dist_subtract(shifted_param, prev_start),
                                Naturals)
                            range_assumptions.append(assumption)
                        # The shifted parameter comes before the start
                        # of the next entry along this axis.
                        next_start = all_entry_starts[axis][next_index]
                        assumption = Less(shifted_param, next_start)
                        range_assumptions.append(assumption)

                    # Perform the substitution.
                    # The fact that our "new parameters" are "safe"
                    # alleviates the need to reserve anything extra.
                    range_lambda_body = iter_body.substituted(
                        range_expr_map, relabelMap, reservedVars,
                        extended_inner_assumptions + range_assumptions,
                        new_requirements)
                    # Any requirements that involve the new parameters
                    # are a direct consequence of the iteration range
                    # and are not external requirements:
                    new_requirements = \
                        [requirement for requirement in new_requirements
                         if requirement.freeVars().isdisjoint(new_params)]
                    entry = Iter(new_params, range_lambda_body,
                                 self.start_indices, shifted_entry_ends)
                # Set this entry in the entries array.
                ExprArray.set_entry(entries, entry_indices, entry)
            subbed_self = compositeExpression(entries)
        else:
            # No Indexed sub-Expressions whose variable is
            # replaced with a Composite, so let us not expand the
            # iteration.  Just do an ordinary substitution.
            new_requirements = []  # Fresh new requirements.
            subbed_map = self.lambda_map.substituted(exprMap, relabelMap,
                                                     reservedVars, assumptions,
                                                     new_requirements)
            subbed_self = Iter(subbed_map.parameters, subbed_map.body,
                               subbed_start, subbed_end)

        for requirement in new_requirements:
            # Make sure requirements don't use reserved variable in a
            # nested scope.
            requirement._restrictionChecked(reservedVars)
        if requirements is not None:
            requirements += new_requirements  # append new requirements

        return subbed_self