def __init__(self, parameter_or_parameters, body, conditions=tuple(), styles=None, requirements=tuple()):
    '''
    Initialize a Lambda function expression given parameter(s) and a body.

    Each parameter must be a Variable (each parameter is passed through
    getParamVar to obtain its Variable).  When there is a single
    parameter, there will be a 'parameter' attribute.  Either way, there
    will be a 'parameters' attribute that bundles the one or more
    parameters into a composite expression (an ExprList).  The 'body'
    attribute will be the lambda function body Expression (that may or
    may not be a Composite).  Zero or more 'conditions' may be provided;
    they are bundled into a composite expression stored as the
    'conditions' attribute.

    Raises:
        ValueError: if the parameter Variables are not unique.
        TypeError: if the body is not an Expression, or is directly an
            Iter (an Iter must be within an ExprList or ExprTensor).
        LambdaError: if any requirement of the body involves one of the
            parameter Variables.

    When relabeling the parameters to safe dummy variables produces
    different parameters, the resulting "generic" Lambda is stored as
    '_genericExpr'.
    '''
    from proveit._core_.expression.composite import compositeExpression, singleOrCompositeExpression, Iter
    if styles is None:
        styles = dict()
    self.parameters = compositeExpression(parameter_or_parameters)
    parameterVars = [getParamVar(parameter) for parameter in self.parameters]
    if len(self.parameters) == 1:
        # has a single parameter
        self.parameter = self.parameters[0]
        self.parameter_or_parameters = self.parameter
    else:
        self.parameter_or_parameters = self.parameters
    self.parameterVars = tuple(parameterVars)
    self.parameterVarSet = frozenset(parameterVars)
    if len(self.parameterVarSet) != len(self.parameters):
        raise ValueError('Lambda parameters Variables must be unique with respect to each other.')
    body = singleOrCompositeExpression(body)
    if not isinstance(body, Expression):
        raise TypeError('A Lambda body must be of type Expression')
    if isinstance(body, Iter):
        raise TypeError('An Iter must be within an ExprList or ExprTensor, not directly as a Lambda body')
    self.body = body
    self.conditions = compositeExpression(conditions)
    # Body requirements may not involve the parameter variables.
    for requirement in self.body.getRequirements():
        if not self.parameterVarSet.isdisjoint(requirement.freeVars()):
            raise LambdaError("Cannot generate a Lambda expression with parameter variables involved in Lambda body requirements: " + str(requirement))
    sub_exprs = [self.parameter_or_parameters, self.body]
    if len(self.conditions) > 0:
        sub_exprs.append(self.conditions)

    # Create a "generic" version (if not already) of the Lambda
    # expression since the choice of parameter labeling is irrelevant.
    generic_body_vars = self.body._genericExpr.usedVars()
    generic_condition_vars = self.conditions._genericExpr.usedVars()
    used_generic_vars = generic_body_vars.union(generic_condition_vars)
    generic_params = tuple(safeDummyVars(len(self.parameterVars),
                                         *(used_generic_vars - self.parameterVarSet)))
    if generic_params != self.parameterVars:
        relabel_map = {param: generic_param for param, generic_param
                       in zip(self.parameterVars, generic_params)}
        # Temporarily disable automation during the relabeling process.
        # Use try/finally so a failure while relabeling or constructing
        # the generic Lambda cannot leave automation globally disabled.
        prev_automation = defaults.automation
        defaults.automation = False
        try:
            generic_parameters = self.parameters._genericExpr.relabeled(relabel_map)
            generic_body = self.body._genericExpr.relabeled(relabel_map)
            generic_conditions = self.conditions._genericExpr.relabeled(relabel_map)
            self._genericExpr = Lambda(generic_parameters, generic_body,
                                       generic_conditions, styles=dict(styles),
                                       requirements=requirements)
        finally:
            defaults.automation = prev_automation  # restore to previous value

    Expression.__init__(self, ['Lambda'], sub_exprs, styles=styles,
                        requirements=requirements)
def safeDummyVars(self, n):
    '''
    Return `n` dummy Variables obtained from the label.var
    `safeDummyVars` helper, passing this expression as the one whose
    free variables the dummies should avoid.
    '''
    from proveit._core_.expression.label.var import \
        safeDummyVars as _label_safe_dummy_vars
    return _label_safe_dummy_vars(n, self)
def substituted(self, exprMap, relabelMap=None, reservedVars=None,
                assumptions=USE_DEFAULTS, requirements=None):
    '''
    Returns this expression with the substitutions made
    according to exprMap and/or relabeled according to relabelMap.
    Attempt to automatically expand the iteration if any Indexed
    sub-expressions substitute their variable for a composite
    (list or tensor).  Indexed should index variables that represent
    composites, but substituting the composite is a signal that
    an outer iteration should be expanded.  An exception is
    raised if this fails.

    The caller's relabelMap is never modified: a local copy is used,
    because the iteration parameters are dropped from it when building
    single-element entries of an expansion.

    Requirements generated along the way are checked against
    reservedVars and appended to 'requirements' when it is not None.

    Raises IterationError if only some of the axes can be expanded, or
    if a natural-number requirement on the entry ranges cannot be
    proven.  Other ProofFailure exceptions may propagate (e.g. when the
    last substitution parameter value cannot be shown to equal
    end+1 along an axis).
    '''
    from .composite import _generateCoordOrderAssumptions
    from proveit import ProofFailure, ExprArray
    from proveit.logic import Equals, InSet
    from proveit.number import Less, LessEq, dist_add, \
        zero, one, dist_subtract, Naturals, Integers
    from .composite import _simplifiedCoord
    from proveit._core_.expression.expr import _NoExpandedIteration
    from proveit._core_.expression.label.var import safeDummyVars

    self._checkRelabelMap(relabelMap)
    # Work on a local copy: parameters are popped from it below and that
    # must not leak back to the caller (who may reuse the same map for
    # sibling expressions).
    relabelMap = dict() if relabelMap is None else dict(relabelMap)

    assumptions = defaults.checkedAssumptions(assumptions)
    new_requirements = []
    iter_params = self.lambda_map.parameters
    iter_body = self.lambda_map.body
    ndims = self.ndims
    subbed_start = self.start_indices.substituted(
        exprMap, relabelMap, reservedVars, assumptions, new_requirements)
    subbed_end = self.end_indices.substituted(
        exprMap, relabelMap, reservedVars, assumptions, new_requirements)

    # Need to handle the change in scope within the lambda
    # expression.  The new parameters returned here are not used: they
    # aren't relevant after an expansion.
    _, inner_expr_map, inner_assumptions, inner_reservations \
        = self.lambda_map._innerScopeSub(exprMap, relabelMap, reservedVars,
                                         assumptions, new_requirements)

    # Get sorted substitution parameter start and end values
    # demarcating how the entry array must be split up for each axis.
    # None marks an axis that has not been expanded.
    all_entry_starts = [None] * ndims
    all_entry_ends = [None] * ndims
    do_expansion = False
    for axis in range(ndims):
        try:
            empty_eq = Equals(dist_add(subbed_end[axis], one),
                              subbed_start[axis])
            try:
                # Check if this is an empty iteration which
                # happens when end+1=start.
                empty_eq.prove(assumptions, automation=False)
                all_entry_starts[axis], all_entry_ends[axis] = [], []
                do_expansion = True
                continue
            except ProofFailure:
                pass
            param_vals = iter_body._iterSubParamVals(
                axis, iter_params[axis], subbed_start[axis], subbed_end[axis],
                inner_expr_map, relabelMap, inner_reservations,
                inner_assumptions, new_requirements)
            assert param_vals[0] == subbed_start[axis]
            if param_vals[-1] != subbed_end[axis]:
                # The last of the param_vals should either be
                # subbed_end[axis] or known to be
                # subbed_end[axis]+1.  Let's double-check.
                eq = Equals(dist_add(subbed_end[axis], one), param_vals[-1])
                eq.prove(assumptions, automation=False)
            # Populate the entry starts and ends using the
            # param_vals which indicate the start of each contained
            # entry plus the end of this iteration.
            axis_starts = []
            axis_ends = []
            all_entry_starts[axis] = axis_starts
            all_entry_ends[axis] = axis_ends
            for left, right in zip(param_vals[:-1], param_vals[1:]):
                axis_starts.append(left)
                try:
                    # Simple single-entry case (right = left+1): the
                    # start and end are the same.
                    eq = Equals(dist_add(left, one), right)
                    new_requirements.append(
                        eq.prove(assumptions, automation=False))
                    entry_end = left
                except ProofFailure:
                    # Not the simple case; perform the positive
                    # integrality check.
                    requirement = InSet(dist_subtract(right, left), Naturals)
                    # Knowing the simplification may help prove the
                    # requirement.
                    _simplifiedCoord(requirement, assumptions, [])
                    try:
                        new_requirements.append(
                            requirement.prove(assumptions))
                    except ProofFailure as e:
                        raise IterationError("Failed to prove requirement "
                                             "%s:\n%s" % (requirement, e))
                    if right == subbed_end[axis]:
                        # This last entry is the inclusive end
                        # rather than past the end, so it is an
                        # exception.
                        # NOTE(review): subbed_end[axis] may also be
                        # appended below as a separate singular entry;
                        # confirm the entries cannot overlap.
                        entry_end = right
                    else:
                        # Subtract one from the start of the next
                        # entry to get the end of this entry.
                        entry_end = _simplifiedCoord(
                            dist_subtract(right, one), assumptions,
                            new_requirements)
                axis_ends.append(entry_end)
            # See if we should add the end value as an extra
            # singular entry.  If param_vals[-1] is at the inclusive
            # end, then we have a singular final entry.
            if param_vals[-1] == subbed_end[axis]:
                end_val = subbed_end[axis]
                axis_starts.append(end_val)
                axis_ends.append(end_val)
            else:
                # Otherwise, the last param_val will be one after
                # the inclusive end which we will want to use below
                # when building the last iteration entry.
                axis_starts.append(param_vals[-1])
            do_expansion = True
        except EmptyIterException:
            # Indexing over a negative or empty range.  The only way this
            # should be allowed is if subbed_end+1=subbed_start.
            Equals(dist_add(subbed_end[axis], one),
                   subbed_start[axis]).prove(assumptions)
            all_entry_starts[axis], all_entry_ends[axis] = [], []
            do_expansion = True
        except _NoExpandedIteration:
            pass

    if do_expansion:
        # There are Indexed sub-Expressions whose variable is
        # being replaced with a Composite, so let us
        # expand the iteration for all of the relevant
        # iteration ranges.
        # We must have "substitution parameter values" along each axis:
        if None in all_entry_starts or None in all_entry_ends:
            raise IterationError("Must expand all axes or none of the "
                                 "axes, when substituting %s" % str(self))

        # Generate the expanded tuple/array as the substitution of 'self'.
        shape = [len(axis_ends) for axis_ends in all_entry_ends]
        entries = ExprArray.make_empty_entries(shape)
        indices_by_axis = [range(extent) for extent in shape]
        extended_inner_assumptions = list(inner_assumptions)
        for axis_starts in all_entry_starts:
            # Generate assumptions that order the successive entry start
            # parameter values.  (This is a requirement for iteration
            # instances and is a simple fact of succession for single
            # entries.)
            extended_inner_assumptions.extend(
                _generateCoordOrderAssumptions(axis_starts))

        # Iterate over each of the new entries, obtaining indices
        # into the per-axis entry starts/ends.
        for entry_indices in itertools.product(*indices_by_axis):
            entry_starts = [axis_starts[i] for axis_starts, i
                            in zip(all_entry_starts, entry_indices)]
            entry_ends = [axis_ends[i] for axis_ends, i
                          in zip(all_entry_ends, entry_indices)]
            # The entry is singular only if start == end along every
            # axis; otherwise we must do an iteration for this entry.
            is_singular_entry = not any(
                entry_start != entry_end
                for entry_start, entry_end in zip(entry_starts, entry_ends))
            if is_singular_entry:
                # Single element entry.
                # Generate the entry by making appropriate
                # parameter substitutions for the iteration body.
                entry_inner_expr_map = dict(inner_expr_map)
                entry_inner_expr_map.update(
                    {param: arg for param, arg in zip(iter_params, entry_starts)})
                # The parameters are substituted directly above, so drop
                # any relabeling of them (local copy only).
                for param in iter_params:
                    relabelMap.pop(param, None)
                entry = iter_body.substituted(
                    entry_inner_expr_map, relabelMap, inner_reservations,
                    extended_inner_assumptions, new_requirements)
            else:
                # Iteration entry.
                # Shift the iteration parameter so that the
                # iteration will have the same start-indices
                # for this sub-range (like shifting a viewing
                # window, moving the origin to the start of the
                # sub-range).
                # NOTE(review): this uses self.start_indices rather than
                # subbed_start; confirm that start indices needing
                # substitution are handled as intended.
                # Generate "safe" new parameters that avoid the free
                # variables of these expressions:
                unsafe_var_exprs = [self]
                unsafe_var_exprs.extend(exprMap.values())
                unsafe_var_exprs.extend(relabelMap.values())
                unsafe_var_exprs.extend(entry_starts)
                unsafe_var_exprs.extend(entry_ends)
                new_params = safeDummyVars(ndims, *unsafe_var_exprs)
                # Make assumptions that place the parameter(s) in the
                # appropriate range and at an integral coordinate position.
                # Note, it is possible that this actually represents an
                # empty range and that these assumptions are contradictory;
                # but this still suits our purposes regardless.
                range_expr_map = dict(inner_expr_map)
                range_assumptions = []
                shifted_entry_ends = []
                for axis, (param, new_param, entry_start, entry_end) in \
                        enumerate(zip(iter_params, new_params,
                                      entry_starts, entry_ends)):
                    axis_starts = all_entry_starts[axis]
                    start_idx = self.start_indices[axis]
                    shift = _simplifiedCoord(
                        dist_subtract(entry_start, start_idx),
                        assumptions, new_requirements)
                    if shift != zero:
                        shifted_param = dist_add(new_param, shift)
                    else:
                        shifted_param = new_param
                    range_expr_map[param] = shifted_param
                    shifted_entry_ends.append(_simplifiedCoord(
                        dist_subtract(entry_end, shift),
                        assumptions, new_requirements))
                    range_assumptions.append(InSet(new_param, Integers))
                    range_assumptions.append(LessEq(entry_start, shifted_param))
                    # Assume differences with each of the previous
                    # range starts are natural numbers as should be
                    # the case given requirements that have been met.
                    next_index = entry_indices[axis] + 1
                    for prev_start in axis_starts[:next_index]:
                        range_assumptions.append(InSet(
                            dist_subtract(shifted_param, prev_start),
                            Naturals))
                    if next_index < len(axis_starts):
                        # Stay strictly before the next entry's start.
                        range_assumptions.append(
                            Less(shifted_param, axis_starts[next_index]))
                    else:
                        # This is the singular final entry along this axis
                        # (it ends at the inclusive end), so there is no
                        # next start; bound by the inclusive entry end.
                        range_assumptions.append(
                            LessEq(shifted_param, entry_end))
                # Perform the substitution.
                # The fact that our "new parameters" are "safe"
                # alleviates the need to reserve anything extra.
                range_lambda_body = iter_body.substituted(
                    range_expr_map, relabelMap, reservedVars,
                    extended_inner_assumptions + range_assumptions,
                    new_requirements)
                # Any requirements that involve the new parameters
                # are a direct consequence of the iteration range
                # and are not external requirements:
                new_requirements = [
                    requirement for requirement in new_requirements
                    if requirement.freeVars().isdisjoint(new_params)]
                entry = Iter(new_params, range_lambda_body,
                             self.start_indices, shifted_entry_ends)
            # Set this entry in the entries array.
            ExprArray.set_entry(entries, entry_indices, entry)
        subbed_self = compositeExpression(entries)
    else:
        # No Indexed sub-Expressions whose variable is
        # replaced with a Composite, so let us not expand the
        # iteration.  Just do an ordinary substitution.
        new_requirements = []  # Fresh new requirements.
        subbed_map = self.lambda_map.substituted(
            exprMap, relabelMap, reservedVars, assumptions, new_requirements)
        subbed_self = Iter(subbed_map.parameters, subbed_map.body,
                           subbed_start, subbed_end)

    for requirement in new_requirements:
        # Make sure requirements don't use reserved variable in a
        # nested scope.
        requirement._restrictionChecked(reservedVars)
    if requirements is not None:
        requirements += new_requirements  # append new requirements
    return subbed_self