def _slice(expr, args, kwargs, context):
    """Build LLL for ``slice(b, start=..., len=...)`` over a byte array.

    The result is written into a fresh memory placeholder and returned as a
    memory ``ByteArrayType`` whose max length is ``len`` when it is a literal,
    otherwise the source's max length.

    Raises:
        TypeMismatchException: if ``start`` or ``len`` is not unit-compatible
            with a plain ``int128``.
    """
    sub, start, length = args[0], kwargs['start'], kwargs['len']
    if not are_units_compatible(start.typ, BaseType('int128')):
        raise TypeMismatchException("Type for slice start index must be a unitless number", expr)
    # Expression representing the length of the slice
    if not are_units_compatible(length.typ, BaseType('int128')):
        raise TypeMismatchException("Type for slice length must be a unitless number", expr)
    # Node representing the position of the output in memory
    np = context.new_placeholder(ByteArrayType(maxlen=sub.typ.maxlen + 32))
    placeholder_node = LLLnode.from_list(np, typ=sub.typ, location='memory')
    placeholder_plus_32_node = LLLnode.from_list(np + 32, typ=sub.typ, location='memory')
    # Copies over bytearray data
    if sub.location == 'storage':
        # Storage layout: length word at sha3_32(sub), data words from sha3_32(sub) + 1.
        adj_sub = LLLnode.from_list(
            ['add', ['sha3_32', sub], ['add', ['div', '_start', 32], 1]],
            typ=sub.typ, location=sub.location
        )
    else:
        adj_sub = LLLnode.from_list(
            ['add', sub, ['add', ['sub', '_start', ['mod', '_start', 32]], 32]],
            typ=sub.typ, location=sub.location
        )
    copier = make_byte_slice_copier(placeholder_plus_32_node, adj_sub, ['add', '_length', 32], sub.typ.maxlen)
    # New maximum length in the type of the result
    newmaxlen = length.value if not len(length.args) else sub.typ.maxlen
    # Runtime length of the source bytes. Storage byte arrays keep their length
    # in the slot at sha3_32(sub) (same layout adj_sub uses above); reading it
    # with mload would read unrelated memory at the slot number.
    if sub.location == 'storage':
        maxlen = ['sload', ['sha3_32', sub]]
    else:
        maxlen = ['mload', Expr(sub, context=context).lll_node]
    out = [
        'with', '_start', start, [
            'with', '_length', length, [
                'with', '_opos', ['add', placeholder_node, ['mod', '_start', 32]], [
                    'with', '_maxlen', maxlen, [
                        'seq',
                        # start and length are int128 and may be negative at runtime;
                        # bounding each one separately keeps start + length from
                        # wrapping around 2**256 and slipping past the final check.
                        ['assert', ['le', '_start', '_maxlen']],
                        ['assert', ['le', '_length', '_maxlen']],
                        ['assert', ['le', ['add', '_start', '_length'], '_maxlen']],
                        copier,
                        ['mstore', '_opos', '_length'],
                        '_opos',
                    ],
                ],
            ],
        ],
    ]
    return LLLnode.from_list(out, typ=ByteArrayType(newmaxlen), location='memory', pos=getpos(expr))
def minmax(expr, args, kwargs, context, is_min):
    """Build LLL for the ``min`` (``is_min=True``) / ``max`` builtins.

    The generated code is ``if (_l CMP _r) then _r else _l``, where CMP is a
    "greater than" compare for min and "less than" for max.

    Operands must have the same base type, or both be uint256-compatible
    (a uint256, or an int128 literal in uint256 range).

    Raises:
        TypeMismatchException: on incompatible units or operand types.
    """
    def _can_compare_with_uint256(operand):
        if operand.typ.typ == 'uint256':
            return True
        elif operand.typ.typ == 'int128' and operand.typ.is_literal and SizeLimits.in_bounds('uint256', operand.value):  # noqa: E501
            return True
        return False

    left, right = args[0], args[1]
    if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ):  # noqa: E501
        raise TypeMismatchException("Units must be compatible", expr)
    unsigned_comparator = 'gt' if is_min else 'lt'
    signed_comparator = 's' + unsigned_comparator
    if left.typ.typ == right.typ.typ:
        comparator = unsigned_comparator if left.typ.typ == 'uint256' else signed_comparator
        o = ['if', [comparator, '_l', '_r'], '_r', '_l']
        otyp = left.typ
        otyp.is_literal = False
    elif _can_compare_with_uint256(left) and _can_compare_with_uint256(right):
        # Both operands are representable as uint256, so the comparison must be
        # unsigned regardless of which side is the int128 literal; a signed compare
        # would misorder uint256 values >= 2**255.
        o = ['if', [unsigned_comparator, '_l', '_r'], '_r', '_l']
        if right.typ.typ == 'uint256':
            otyp = right.typ
        else:
            otyp = left.typ
        otyp.is_literal = False
    else:
        raise TypeMismatchException(
            "Minmax types incompatible: %s %s" % (left.typ.typ, right.typ.typ),
            expr,
        )
    return LLLnode.from_list(
        ['with', '_l', left, ['with', '_r', right, o]],
        typ=otyp,
        pos=getpos(expr),
    )
def minmax(expr, args, kwargs, context, is_min):
    """Build LLL for the ``min`` (``is_min=True``) / ``max`` builtins.

    Both operands must share the same base type. uint256 operands are compared
    unsigned; every other type uses the signed comparison opcodes.

    Raises:
        TypeMismatchException: on incompatible units or differing operand types.
    """
    left, right = args[0], args[1]
    if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(
            right.typ, left.typ):
        raise TypeMismatchException("Units must be compatible", expr)
    if left.typ.typ == 'uint256':
        comparator = 'gt' if is_min else 'lt'
    else:
        comparator = 'sgt' if is_min else 'slt'
    if left.typ.typ == right.typ.typ:
        # min: return _r when _l > _r, else _l (mirror image for max).
        o = ['if', [comparator, '_l', '_r'], '_r', '_l']
        otyp = left.typ
        otyp.is_literal = False
    else:
        # Mixed operand types (e.g. int128 with decimal) are not supported here.
        raise TypeMismatchException(
            "Minmax types incompatible: %s %s" % (left.typ.typ, right.typ.typ), expr)
    return LLLnode.from_list(['with', '_l', left, ['with', '_r', right, o]], typ=otyp, pos=getpos(expr))
def base_type_conversion(orig, frm, to, pos=None):
    """Convert the LLL value ``orig`` of base type ``frm`` to base type ``to``.

    Supported conversions:
      * identical base type with compatible units (retyped in place);
      * int128 -> decimal (scaled by DECIMAL_DIVISOR);
      * uint256 -> int128 (clamped against the value stored at MemoryPositions.MAXNUM);
      * NullType -> zero of any supported base type.

    Raises:
        TypeMismatchException: for non-base types or unsupported conversions.
    """
    orig = unwrap_location(orig)
    if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType):
        raise TypeMismatchException(
            "Base type conversion from or to non-base type: %r %r" % (frm, to), pos)
    elif is_base_type(frm, to.typ) and are_units_compatible(frm, to):
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    elif is_base_type(frm, 'int128') and is_base_type(
            to, 'decimal') and are_units_compatible(frm, to):
        return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR],
                                 typ=BaseType('decimal', to.unit, to.positional))
    elif is_base_type(frm, 'uint256') and is_base_type(
            to, 'int128') and are_units_compatible(frm, to):
        # assumes uclample(x, bound) asserts x <= bound (unsigned) and yields x;
        # MAXNUM presumably holds the int128 maximum -- TODO confirm.
        # Keep the target's unit/positional, consistent with the decimal branch.
        return LLLnode.from_list(
            ['uclample', orig, ['mload', MemoryPositions.MAXNUM]],
            typ=BaseType('int128', to.unit, to.positional))
    elif isinstance(frm, NullType):
        if to.typ not in ('int128', 'bool', 'uint256', 'address', 'bytes32', 'decimal'):
            # This is only to future proof the use of base_type_conversion.
            raise TypeMismatchException(
                "Cannot convert null-type object to type %r" % to, pos)  # pragma: no cover
        return LLLnode.from_list(0, typ=to)
    else:
        raise TypeMismatchException(
            "Typecasting from base type %r to %r unavailable" % (frm, to), pos)
def minmax(expr, args, kwargs, context, comparator):
    """Build LLL for ``min``/``max`` given the unsigned ``comparator`` ('gt'/'lt').

    Like-typed non-uint256 operands use the signed variant of ``comparator``.
    uint256 operands, and mixes of uint256 with an in-range int128 literal,
    use the unsigned opcode.

    Raises:
        TypeMismatchException: on incompatible units or operand types.
    """
    def _can_compare_with_uint256(operand):
        if operand.typ.typ == 'uint256':
            return True
        elif operand.typ.typ == 'int128' and operand.typ.is_literal and SizeLimits.in_bounds('uint256', operand.value):  # noqa: E501
            return True
        return False

    left, right = args[0], args[1]
    if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ):  # noqa: E501
        raise TypeMismatchException("Units must be compatible", expr)
    if left.typ.typ == right.typ.typ:
        if left.typ.typ != 'uint256':
            # if comparing like types that are not uint256, use SLT or SGT
            comparator = f's{comparator}'
        o = ['if', [comparator, '_l', '_r'], '_r', '_l']
        otyp = left.typ
        otyp.is_literal = False
    elif _can_compare_with_uint256(left) and _can_compare_with_uint256(right):
        o = ['if', [comparator, '_l', '_r'], '_r', '_l']
        if right.typ.typ == 'uint256':
            otyp = right.typ
        else:
            otyp = left.typ
        otyp.is_literal = False
    else:
        # Pass the source node so the error points at the offending call.
        raise TypeMismatchException(
            f"Minmax types incompatible: {left.typ.typ} {right.typ.typ}",
            expr,
        )
    return LLLnode.from_list(
        ['with', '_l', left, ['with', '_r', right, o]],
        typ=otyp,
        pos=getpos(expr),
    )
def base_type_conversion(orig, frm, to, pos):
    """Convert the LLL value ``orig`` of base type ``frm`` to base type ``to``.

    Integer literals are range-checked against their own type first. An int128
    literal may be retyped as uint256 only when it is non-negative and in range.

    Raises:
        InvalidLiteralException: for out-of-range integer literals.
        TypeMismatchException: for non-base types or unsupported conversions.
    """
    orig = unwrap_location(orig)
    if getattr(frm, 'is_literal', False) and frm.typ in ('int128', 'uint256') and not SizeLimits.in_bounds(frm.typ, orig.value):
        raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
    if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType):
        raise TypeMismatchException("Base type conversion from or to non-base type: %r %r" % (frm, to), pos)
    elif is_base_type(frm, to.typ) and are_units_compatible(frm, to):
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    elif is_base_type(frm, 'int128') and is_base_type(to, 'decimal') and are_units_compatible(frm, to):
        return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR], typ=BaseType('decimal', to.unit, to.positional))
    elif isinstance(frm, NullType):
        if to.typ not in ('int128', 'bool', 'uint256', 'address', 'bytes32', 'decimal'):
            # This is only to future proof the use of base_type_conversion.
            raise TypeMismatchException("Cannot convert null-type object to type %r" % to, pos)  # pragma: no cover
        return LLLnode.from_list(0, typ=to)
    elif isinstance(to, ContractType) and frm.typ == 'address':
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    # Integer literal conversion.
    elif (frm.typ, to.typ, frm.is_literal) == ('int128', 'uint256', True):
        # The int128 range check above still admits negative values, which
        # cannot be represented as uint256.
        if not SizeLimits.in_bounds('uint256', orig.value):
            raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    else:
        raise TypeMismatchException("Typecasting from base type %r to %r unavailable" % (frm, to), pos)
def base_type_conversion(orig, frm, to, pos, in_function_call=False):
    """Convert the LLL value ``orig`` of base type ``frm`` to base type ``to``.

    Args:
        orig: LLL node holding the value (its location is unwrapped first).
        frm: source type.
        to: target type.
        pos: source position used in raised errors.
        in_function_call: when True, integer literals must carry exactly the
            target's unit and positional flags.

    Raises:
        InvalidLiteralException: for out-of-range literals or unit mismatches
            on literal function-call arguments.
        TypeMismatchException: for non-base types or unsupported conversions.
    """
    orig = unwrap_location(orig)
    if getattr(frm, 'is_literal', False) and frm.typ in ('int128', 'uint256'):
        if not SizeLimits.in_bounds(frm.typ, orig.value):
            raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
        # Special Case: Literals in function calls should always convey unit type as well.
        # Only inspect units when the target is a base type; other targets are
        # rejected with a TypeMismatchException just below.
        if in_function_call and isinstance(to, BaseType) and \
                not (frm.unit == to.unit and frm.positional == to.positional):
            raise InvalidLiteralException("Function calls require explicit unit definitions on calls, expected %r" % to, pos)
    if not isinstance(frm, (BaseType, NullType)) or not isinstance(to, BaseType):
        raise TypeMismatchException("Base type conversion from or to non-base type: %r %r" % (frm, to), pos)
    elif is_base_type(frm, to.typ) and are_units_compatible(frm, to):
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    elif is_base_type(frm, 'int128') and is_base_type(to, 'decimal') and are_units_compatible(frm, to):
        return LLLnode.from_list(['mul', orig, DECIMAL_DIVISOR], typ=BaseType('decimal', to.unit, to.positional))
    elif isinstance(frm, NullType):
        if to.typ not in ('int128', 'bool', 'uint256', 'address', 'bytes32', 'decimal'):
            # This is only to future proof the use of base_type_conversion.
            raise TypeMismatchException("Cannot convert null-type object to type %r" % to, pos)  # pragma: no cover
        return LLLnode.from_list(0, typ=to)
    elif isinstance(to, ContractType) and frm.typ == 'address':
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    # Integer literal conversion.
    elif (frm.typ, to.typ, frm.is_literal) == ('int128', 'uint256', True):
        # A negative int128 literal passes the int128 range check but is not a uint256.
        if not SizeLimits.in_bounds('uint256', orig.value):
            raise InvalidLiteralException("Number out of range: " + str(orig.value), pos)
        return LLLnode(orig.value, orig.args, typ=to, add_gas_estimate=orig.add_gas_estimate)
    else:
        raise TypeMismatchException("Typecasting from base type %r to %r unavailable" % (frm, to), pos)
def compare(self):
    """Build LLL for a single binary comparison (or ``x in list``).

    int128 and decimal may be mixed; the int128 side is scaled by
    DECIMAL_DIVISOR. Ordering comparisons on uint256 operands use the unsigned
    opcodes; all other numeric types use the signed ones.

    Raises:
        TypeMismatchException: on unit mismatch or unsupported operand types.
        StructureException: for chained comparisons.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

    if isinstance(self.expr.ops[0], ast.In) and \
            isinstance(right.typ, ListType):
        if not are_units_compatible(
                left.typ, right.typ.subtype) and not are_units_compatible(
                    right.typ.subtype, left.typ):
            raise TypeMismatchException(
                "Can't use IN comparison with different types!", self.expr)
        return self.build_in_comparator()
    else:
        if not are_units_compatible(
                left.typ, right.typ) and not are_units_compatible(
                    right.typ, left.typ):
            raise TypeMismatchException(
                "Can't compare values with different units!", self.expr)

    if len(self.expr.ops) != 1:
        raise StructureException(
            "Cannot have a comparison with more than two elements", self.expr)
    if isinstance(self.expr.ops[0], ast.Gt):
        op = 'sgt'
    elif isinstance(self.expr.ops[0], ast.GtE):
        op = 'sge'
    elif isinstance(self.expr.ops[0], ast.LtE):
        op = 'sle'
    elif isinstance(self.expr.ops[0], ast.Lt):
        op = 'slt'
    elif isinstance(self.expr.ops[0], ast.Eq):
        op = 'eq'
    elif isinstance(self.expr.ops[0], ast.NotEq):
        op = 'ne'
    else:
        raise Exception("Unsupported comparison operator")
    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        if op not in ('eq', 'ne'):
            raise TypeMismatchException("Invalid type for comparison op", self.expr)
    ltyp, rtyp = left.typ.typ, right.typ.typ
    if ltyp == rtyp:
        if ltyp == 'uint256':
            # uint256 values >= 2**255 read as negative under signed compare,
            # so ordering comparisons must use the unsigned opcodes.
            unsigned_ops = {'sgt': 'gt', 'sge': 'ge', 'sle': 'le', 'slt': 'lt'}
            op = unsigned_ops.get(op, op)
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
    elif ltyp == 'decimal' and rtyp == 'int128':
        return LLLnode.from_list(
            [op, left, ['mul', right, DECIMAL_DIVISOR]],
            typ='bool',
            pos=getpos(self.expr))
    elif ltyp == 'int128' and rtyp == 'decimal':
        return LLLnode.from_list(
            [op, ['mul', left, DECIMAL_DIVISOR], right],
            typ='bool',
            pos=getpos(self.expr))
    else:
        raise TypeMismatchException(
            "Unsupported types for comparison: %r %r" % (ltyp, rtyp), self.expr)
def compare(self):
    """Build LLL for a single binary comparison (or ``x in list``).

    Implicit int128/decimal mixing is rejected; operands must share a base type.
    Ordering comparisons on uint256 operands use the unsigned opcodes; all
    other numeric types use the signed ones.

    Raises:
        TypeMismatchException: on unit mismatch, implicit conversions, or
            unsupported operand types.
        StructureException: for chained comparisons.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

    if isinstance(self.expr.ops[0], ast.In) and \
            isinstance(right.typ, ListType):
        if not are_units_compatible(
                left.typ, right.typ.subtype) and not are_units_compatible(
                    right.typ.subtype, left.typ):
            raise TypeMismatchException(
                "Can't use IN comparison with different types!", self.expr)
        return self.build_in_comparator()
    else:
        if not are_units_compatible(
                left.typ, right.typ) and not are_units_compatible(
                    right.typ, left.typ):
            raise TypeMismatchException(
                "Can't compare values with different units!", self.expr)

    if len(self.expr.ops) != 1:
        raise StructureException(
            "Cannot have a comparison with more than two elements", self.expr)
    if isinstance(self.expr.ops[0], ast.Gt):
        op = 'sgt'
    elif isinstance(self.expr.ops[0], ast.GtE):
        op = 'sge'
    elif isinstance(self.expr.ops[0], ast.LtE):
        op = 'sle'
    elif isinstance(self.expr.ops[0], ast.Lt):
        op = 'slt'
    elif isinstance(self.expr.ops[0], ast.Eq):
        op = 'eq'
    elif isinstance(self.expr.ops[0], ast.NotEq):
        op = 'ne'
    else:
        raise Exception("Unsupported comparison operator")
    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        if op not in ('eq', 'ne'):
            raise TypeMismatchException("Invalid type for comparison op", self.expr)
    left_type, right_type = left.typ.typ, right.typ.typ
    if (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:
        raise TypeMismatchException(
            'Implicit conversion from {} to {} disallowed, please convert.'
            .format(left_type, right_type), self.expr)
    if left_type == right_type:
        if left_type == 'uint256':
            # uint256 values >= 2**255 read as negative under signed compare,
            # so ordering comparisons must use the unsigned opcodes.
            unsigned_ops = {'sgt': 'gt', 'sge': 'ge', 'sle': 'le', 'slt': 'lt'}
            op = unsigned_ops.get(op, op)
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
    else:
        raise TypeMismatchException(
            "Unsupported types for comparison: %r %r" % (left_type, right_type), self.expr)
def _slice(expr, args, kwargs, context):
    """Build LLL for ``slice(b, start=..., len=...)`` on bytes, string or bytes32.

    The slice is copied into a fresh memory placeholder. The result type is
    ``ByteArrayType`` for bytes/bytes32 sources and ``StringType`` otherwise.
    Its max length is ``len`` when that is a literal, otherwise the source's
    max length (32 for bytes32).

    Raises:
        TypeMismatchException: if ``start``/``len`` is not a unitless int128.
        InvalidLiteralException: for literal start/len outside a bytes32.
    """
    sub, start, length = args[0], kwargs['start'], kwargs['len']
    if not are_units_compatible(start.typ, BaseType('int128')):
        raise TypeMismatchException("Type for slice start index must be a unitless number", expr)
    # Expression representing the length of the slice
    if not are_units_compatible(length.typ, BaseType('int128')):
        raise TypeMismatchException("Type for slice length must be a unitless number", expr)

    if is_base_type(sub.typ, 'bytes32'):
        if (start.typ.is_literal and length.typ.is_literal) and \
                not (0 <= start.value + length.value <= 32):
            raise InvalidLiteralException(
                'Invalid start / length values needs to be between 0 and 32.',
                expr,
            )
        sub_typ_maxlen = 32
    else:
        sub_typ_maxlen = sub.typ.maxlen

    # Get returntype string or bytes
    if isinstance(args[0].typ, ByteArrayType) or is_base_type(sub.typ, 'bytes32'):
        ReturnType = ByteArrayType
    else:
        ReturnType = StringType

    # Node representing the position of the output in memory
    np = context.new_placeholder(ReturnType(maxlen=sub_typ_maxlen + 32))
    placeholder_node = LLLnode.from_list(np, typ=sub.typ, location='memory')
    placeholder_plus_32_node = LLLnode.from_list(np + 32, typ=sub.typ, location='memory')
    # Copies over bytearray data
    if sub.location == 'storage':
        # Storage layout: length word at sha3_32(sub), data words from sha3_32(sub) + 1.
        adj_sub = LLLnode.from_list(
            ['add', ['sha3_32', sub], ['add', ['div', '_start', 32], 1]],
            typ=sub.typ,
            location=sub.location,
        )
    else:
        adj_sub = LLLnode.from_list(
            ['add', sub, ['add', ['sub', '_start', ['mod', '_start', 32]], 32]],
            typ=sub.typ,
            location=sub.location,
        )

    if is_base_type(sub.typ, 'bytes32'):
        # NOTE(review): assumes a bytes32 sub is an ['mload', <addr>] node -- confirm.
        adj_sub = LLLnode.from_list(
            sub.args[0], typ=sub.typ, location="memory"
        )

    copier = make_byte_slice_copier(
        placeholder_plus_32_node,
        adj_sub,
        ['add', '_length', 32],
        sub_typ_maxlen,
        pos=getpos(expr),
    )
    # New maximum length in the type of the result
    newmaxlen = length.value if not len(length.args) else sub_typ_maxlen
    # Runtime length of the source. Storage arrays keep their length in the
    # slot at sha3_32(sub) (same layout adj_sub uses above); an mload there
    # would read unrelated memory at the slot number.
    if is_base_type(sub.typ, 'bytes32'):
        maxlen = 32
    elif sub.location == 'storage':
        maxlen = ['sload', ['sha3_32', sub]]
    else:
        maxlen = ['mload', Expr(sub, context=context).lll_node]

    out = [
        'with', '_start', start, [
            'with', '_length', length, [
                'with', '_opos', ['add', placeholder_node, ['mod', '_start', 32]], [
                    'with', '_maxlen', maxlen, [
                        'seq',
                        # start/length are int128 and may be negative at runtime;
                        # bounding each one keeps start + length from wrapping
                        # around 2**256 and slipping past the final check.
                        ['assert', ['le', '_start', '_maxlen']],
                        ['assert', ['le', '_length', '_maxlen']],
                        ['assert', ['le', ['add', '_start', '_length'], '_maxlen']],
                        copier,
                        ['mstore', '_opos', '_length'],
                        '_opos',
                    ],
                ],
            ],
        ],
    ]
    return LLLnode.from_list(out, typ=ReturnType(newmaxlen), location='memory', pos=getpos(expr))
def parse_return(self):
    """Build LLL for a ``return`` statement in the current function context.

    Handles returning nothing, a base-type value, a byte array, a list, or a
    tuple. The value is written to memory and emitted via
    ``self.make_return_stmt``.

    Raises:
        TypeMismatchException: when the returned value does not match the
            declared return type (or its presence/absence is wrong).
        InvalidLiteralException: for literal numbers outside the return type's range.
        StructureException: for tuple length/member-type mismatches.
    """
    if self.context.return_type is None:
        if self.stmt.value:
            raise TypeMismatchException("Not expecting to return a value", self.stmt)
        return LLLnode.from_list(self.make_return_stmt(0, 0), typ=None, pos=getpos(self.stmt), valency=0)
    if not self.stmt.value:
        raise TypeMismatchException("Expecting to return a value", self.stmt)

    def zero_pad(bytez_placeholder, maxlen):
        # Emit a loop that writes zero bytes after the byte array's data,
        # starting at its current length (mload of bytez_placeholder) and
        # stopping once the iterator exceeds maxlen.
        zero_padder = LLLnode.from_list(['pass'])
        if maxlen > 0:
            zero_pad_i = self.context.new_placeholder(
                BaseType('uint256'))  # Iterator used to zero pad memory.
            zero_padder = LLLnode.from_list(
                [
                    'repeat', zero_pad_i, ['mload', bytez_placeholder], maxlen,
                    [
                        'seq',
                        [
                            'if', ['gt', ['mload', zero_pad_i], maxlen],
                            'break'
                        ],  # stay within allocated bounds
                        [
                            'mstore8',
                            [
                                'add', ['add', 32, bytez_placeholder],
                                ['mload', zero_pad_i]
                            ], 0
                        ]
                    ]
                ],
                annotation="Zero pad")
        return zero_padder

    sub = Expr(self.stmt.value, self.context).lll_node
    self.context.increment_return_counter()
    # Returning a value (most common case)
    if isinstance(sub.typ, BaseType):
        if not isinstance(self.context.return_type, BaseType):
            raise TypeMismatchException(
                "Trying to return base type %r, output expecting %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        sub = unwrap_location(sub)
        if not are_units_compatible(sub.typ, self.context.return_type):
            raise TypeMismatchException(
                "Return type units mismatch %r %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        # NOTE(review): `return_type.typ == sub.typ` compares a type-name string
        # with a type object; probably intended `sub.typ.typ` -- confirm.
        elif sub.typ.is_literal and (
                self.context.return_type.typ == sub.typ or
                'int' in self.context.return_type.typ and
                'int' in sub.typ.typ):
            # Integer literal returned from an integer-typed function: range-check it.
            if not SizeLimits.in_bounds(self.context.return_type.typ, sub.value):
                raise InvalidLiteralException(
                    "Number out of range: " + str(sub.value), self.stmt)
            else:
                return LLLnode.from_list([
                    'seq', ['mstore', 0, sub],
                    self.make_return_stmt(0, 32)
                ], typ=None, pos=getpos(self.stmt), valency=0)
        elif is_base_type(sub.typ, self.context.return_type.typ) or \
                (is_base_type(sub.typ, 'int128') and
                 is_base_type(self.context.return_type, 'int256')):
            # Single 32-byte word: store at memory 0 and return it.
            return LLLnode.from_list(
                ['seq', ['mstore', 0, sub], self.make_return_stmt(0, 32)],
                typ=None, pos=getpos(self.stmt), valency=0)
        else:
            raise TypeMismatchException(
                "Unsupported type conversion: %r to %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
    # Returning a byte array
    elif isinstance(sub.typ, ByteArrayType):
        # NOTE(review): this message says "base type" although the value is a byte array.
        if not isinstance(self.context.return_type, ByteArrayType):
            raise TypeMismatchException(
                "Trying to return base type %r, output expecting %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        if sub.typ.maxlen > self.context.return_type.maxlen:
            raise TypeMismatchException(
                "Cannot cast from greater max-length %d to shorter max-length %d" %
                (sub.typ.maxlen, self.context.return_type.maxlen),
                self.stmt.value)

        loop_memory_position = self.context.new_placeholder(
            typ=BaseType(
                'uint256'))  # loop memory has to be allocated first.
        len_placeholder = self.context.new_placeholder(
            typ=BaseType('uint256')
        )  # len & bytez placeholder have to be declared after each other at all times.
        bytez_placeholder = self.context.new_placeholder(typ=sub.typ)

        if sub.location in ('storage', 'memory'):
            # Copy into the placeholder, zero-pad, write the ABI offset word (32)
            # into len_placeholder, then return offset + length + ceil32(data).
            return LLLnode.from_list([
                'seq',
                make_byte_array_copier(LLLnode(
                    bytez_placeholder, location='memory', typ=sub.typ),
                    sub,
                    pos=getpos(self.stmt)),
                zero_pad(bytez_placeholder, sub.typ.maxlen),
                ['mstore', len_placeholder, 32],
                self.make_return_stmt(
                    len_placeholder,
                    ['ceil32', ['add', ['mload', bytez_placeholder], 64]],
                    loop_memory_position=loop_memory_position)
            ], typ=None, pos=getpos(self.stmt), valency=0)
        else:
            raise Exception("Invalid location: %s" % sub.location)

    elif isinstance(sub.typ, ListType):
        # Compare element base type names (text before any '(' or '[' in the type repr).
        # NOTE(review): return_type.subtype is read without checking that the
        # declared return type is a list -- confirm callers guarantee it.
        sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
        ret_base_type = re.split(r'\(|\[', str(self.context.return_type.subtype))[0]
        loop_memory_position = self.context.new_placeholder(
            typ=BaseType('uint256'))
        if sub_base_type != ret_base_type:
            raise TypeMismatchException(
                "List return type %r does not match specified return type, expecting %r" %
                (sub_base_type, ret_base_type), self.stmt)
        elif sub.location == "memory" and sub.value != "multi":
            # Already laid out in memory: return it directly.
            return LLLnode.from_list(
                self.make_return_stmt(
                    sub,
                    get_size_of_type(self.context.return_type) * 32,
                    loop_memory_position=loop_memory_position),
                typ=None, pos=getpos(self.stmt), valency=0)
        else:
            # Otherwise materialize into a fresh memory placeholder first.
            new_sub = LLLnode.from_list(
                self.context.new_placeholder(self.context.return_type),
                typ=self.context.return_type, location='memory')
            setter = make_setter(new_sub, sub, 'memory', pos=getpos(self.stmt))
            return LLLnode.from_list([
                'seq', setter,
                self.make_return_stmt(
                    new_sub,
                    get_size_of_type(self.context.return_type) * 32,
                    loop_memory_position=loop_memory_position)
            ], typ=None, pos=getpos(self.stmt))

    # Returning a tuple.
    elif isinstance(sub.typ, TupleType):
        if not isinstance(self.context.return_type, TupleType):
            raise TypeMismatchException(
                "Trying to return tuple type %r, output expecting %r" %
                (sub.typ, self.context.return_type), self.stmt.value)

        if len(self.context.return_type.members) != len(sub.typ.members):
            raise StructureException("Tuple lengths don't match!", self.stmt)

        # check return type matches, sub type.
        for i, ret_x in enumerate(self.context.return_type.members):
            s_member = sub.typ.members[i]
            sub_type = s_member if isinstance(s_member, NodeType) else s_member.typ
            if type(sub_type) is not type(ret_x):
                raise StructureException(
                    "Tuple return type does not match annotated return. {} != {}"
                    .format(type(sub_type), type(ret_x)), self.stmt)

        # Is from a call expression.
        if len(sub.args[0].args) > 0 and sub.args[0].args[0].value == 'call':  # self-call to public.
            # Return the callee's output buffer as-is.
            mem_pos = sub.args[0].args[-1]
            mem_size = get_size_of_type(sub.typ) * 32
            return LLLnode.from_list(['return', mem_pos, mem_size], typ=sub.typ)

        elif (sub.annotation and 'Internal Call' in sub.annotation):
            mem_pos = sub.args[-1].value if sub.value == 'seq_unchecked' else sub.args[0].args[-1]
            mem_size = get_size_of_type(sub.typ) * 32
            # Add zero padder if bytes are present in output.
            # Only the last byte-array member is padded.
            zero_padder = ['pass']
            byte_arrays = [(i, x) for i, x in enumerate(sub.typ.members) if isinstance(x, ByteArrayType)]
            if byte_arrays:
                i, x = byte_arrays[-1]
                zero_padder = zero_pad(bytez_placeholder=[
                    'add', mem_pos, ['mload', mem_pos + i * 32]
                ], maxlen=x.maxlen)
            return LLLnode.from_list(
                ['seq'] + [sub] + [zero_padder] + [self.make_return_stmt(mem_pos, mem_size)],
                typ=sub.typ, pos=getpos(self.stmt), valency=0)

        subs = []
        # Pre-allocate loop_memory_position if required for private function returning.
        loop_memory_position = self.context.new_placeholder(
            typ=BaseType('uint256')) if self.context.is_private else None
        # Allocate dynamic off set counter, to keep track of the total packed dynamic data size.
        dynamic_offset_counter_placeholder = self.context.new_placeholder(
            typ=BaseType('uint256'))
        dynamic_offset_counter = LLLnode(
            dynamic_offset_counter_placeholder,
            typ=None,
            annotation="dynamic_offset_counter"  # dynamic offset position counter.
        )
        new_sub = LLLnode.from_list(
            self.context.new_placeholder(typ=BaseType('uint256')),
            typ=self.context.return_type,
            location='memory',
            annotation='new_sub')
        keyz = list(range(len(sub.typ.members)))
        dynamic_offset_start = 32 * len(sub.args)  # The static list of args end.
        left_token = LLLnode.from_list('_loc', typ=new_sub.typ, location="memory")

        def get_dynamic_offset_value():
            # Get value of dynamic offset counter.
            return ['mload', dynamic_offset_counter]

        def increment_dynamic_offset(dynamic_spot):
            # Increment dynamic offset counter in memory by 32 + ceil32(length at dynamic_spot).
            return [
                'mstore', dynamic_offset_counter,
                [
                    'add',
                    ['add', ['ceil32', ['mload', dynamic_spot]], 32],
                    ['mload', dynamic_offset_counter]
                ]
            ]

        for i, typ in enumerate(keyz):
            arg = sub.args[i]
            variable_offset = LLLnode.from_list(
                ['add', 32 * i, left_token],
                typ=arg.typ,
                annotation='variable_offset')
            if isinstance(arg.typ, ByteArrayType):
                # Store offset pointer value.
                subs.append(['mstore', variable_offset, get_dynamic_offset_value()])

                # Store dynamic data, from offset pointer onwards.
                dynamic_spot = LLLnode.from_list(
                    ['add', left_token, get_dynamic_offset_value()],
                    location="memory",
                    typ=arg.typ,
                    annotation='dynamic_spot')
                subs.append(make_setter(dynamic_spot, arg, location="memory", pos=getpos(self.stmt)))
                subs.append(increment_dynamic_offset(dynamic_spot))

            elif isinstance(arg.typ, BaseType):
                subs.append(make_setter(variable_offset, arg, "memory", pos=getpos(self.stmt)))
            else:
                # NOTE(review): the type is passed as a second Exception arg rather
                # than %-formatted into the message.
                raise Exception("Can't return type %s as part of tuple", type(arg.typ))

        setter = LLLnode.from_list(
            ['seq',
                ['mstore', dynamic_offset_counter, dynamic_offset_start],
                ['with', '_loc', new_sub, ['seq'] + subs]],
            typ=None
        )

        # Returned size is the final dynamic offset (static head plus packed dynamic data).
        return LLLnode.from_list(
            ['seq',
                setter,
                self.make_return_stmt(new_sub, get_dynamic_offset_value(), loop_memory_position)],
            typ=None, pos=getpos(self.stmt), valency=0
        )
    else:
        raise TypeMismatchException("Can only return base type!", self.stmt)
def parse_return(self):
    """Build LLL for a ``return`` statement (older code path using raw ``return``).

    Handles returning nothing, a base-type value, a byte array (memory or
    storage), a list, or a tuple. The value is written to memory and
    emitted with an LLL ``return``.

    Raises:
        TypeMismatchException: when the returned value does not match the
            declared return type (or its presence/absence is wrong).
        StructureException: for tuple length mismatches.
    """
    from .parser import (make_setter)
    if self.context.return_type is None:
        if self.stmt.value:
            raise TypeMismatchException("Not expecting to return a value", self.stmt)
        return LLLnode.from_list(['return', 0, 0], typ=None, pos=getpos(self.stmt))
    if not self.stmt.value:
        raise TypeMismatchException("Expecting to return a value", self.stmt)

    sub = Expr(self.stmt.value, self.context).lll_node
    self.context.increment_return_counter()
    # Returning a value (most common case)
    if isinstance(sub.typ, BaseType):
        if not isinstance(self.context.return_type, BaseType):
            raise TypeMismatchException(
                "Trying to return base type %r, output expecting %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        sub = unwrap_location(sub)
        if not are_units_compatible(sub.typ, self.context.return_type):
            raise TypeMismatchException(
                "Return type units mismatch %r %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        elif is_base_type(sub.typ, self.context.return_type.typ) or \
                (is_base_type(sub.typ, 'int128') and
                 is_base_type(self.context.return_type, 'int256')):
            # Single 32-byte word: store at memory 0 and return it.
            return LLLnode.from_list(
                ['seq', ['mstore', 0, sub], ['return', 0, 32]],
                typ=None, pos=getpos(self.stmt))
        # A literal of a different base type is accepted when it is in range
        # for the declared return type.
        if sub.typ.is_literal and SizeLimits.in_bounds(
                self.context.return_type.typ, sub.value):
            return LLLnode.from_list(
                ['seq', ['mstore', 0, sub], ['return', 0, 32]],
                typ=None, pos=getpos(self.stmt))
        else:
            raise TypeMismatchException(
                "Unsupported type conversion: %r to %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
    # Returning a byte array
    elif isinstance(sub.typ, ByteArrayType):
        # NOTE(review): this message says "base type" although the value is a byte array.
        if not isinstance(self.context.return_type, ByteArrayType):
            raise TypeMismatchException(
                "Trying to return base type %r, output expecting %r" %
                (sub.typ, self.context.return_type), self.stmt.value)
        if sub.typ.maxlen > self.context.return_type.maxlen:
            raise TypeMismatchException(
                "Cannot cast from greater max-length %d to shorter max-length %d" %
                (sub.typ.maxlen, self.context.return_type.maxlen),
                self.stmt.value)

        # Zero-pad loop over '_loc' (bound by the enclosing 'with' below),
        # writing zero bytes after the array's data up to maxlen.
        zero_padder = LLLnode.from_list(['pass'])
        if sub.typ.maxlen > 0:
            zero_pad_i = self.context.new_placeholder(
                BaseType('uint256'))  # Iterator used to zero pad memory.
            zero_padder = LLLnode.from_list(
                [
                    'repeat', zero_pad_i, ['mload', '_loc'], sub.typ.maxlen,
                    [
                        'seq',
                        [
                            'if', ['gt', ['mload', zero_pad_i], sub.typ.maxlen],
                            'break'
                        ],  # stay within allocated bounds
                        [
                            'mstore8',
                            [
                                'add', ['add', 32, '_loc'],
                                ['mload', zero_pad_i]
                            ], 0
                        ]
                    ]
                ],
                annotation="Zero pad")

        # Returning something already in memory
        if sub.location == 'memory':
            # NOTE(review): writes the ABI offset word (32) into the 32 bytes just
            # before the array, overwriting whatever occupies them -- confirm safe.
            return LLLnode.from_list([
                'with', '_loc', sub,
                [
                    'seq',
                    ['mstore', ['sub', '_loc', 32], 32],
                    zero_padder,
                    [
                        'return', ['sub', '_loc', 32],
                        ['ceil32', ['add', ['mload', '_loc'], 64]]
                    ]
                ]
            ], typ=None, pos=getpos(self.stmt))

        # Copying from storage
        elif sub.location == 'storage':
            # Instantiate a byte array at some index
            fake_byte_array = LLLnode(self.context.get_next_mem() + 32,
                                      typ=sub.typ,
                                      location='memory',
                                      pos=getpos(self.stmt))
            o = [
                'with', '_loc', self.context.get_next_mem() + 32,
                [
                    'seq',
                    # Copy the data to this byte array
                    make_byte_array_copier(fake_byte_array, sub),
                    # Store the number 32 before it for ABI formatting purposes
                    ['mstore', self.context.get_next_mem(), 32],
                    zero_padder,
                    # Return it
                    [
                        'return', self.context.get_next_mem(),
                        [
                            'add',
                            [
                                'ceil32',
                                [
                                    'mload',
                                    self.context.get_next_mem() + 32
                                ]
                            ], 64
                        ]
                    ]
                ]
            ]
            return LLLnode.from_list(o, typ=None, pos=getpos(self.stmt))
        else:
            raise Exception("Invalid location: %s" % sub.location)

    elif isinstance(sub.typ, ListType):
        # Compare element base type names (text before any '(' or '[' in the type repr).
        # NOTE(review): return_type.subtype is read without checking that the
        # declared return type is a list -- confirm callers guarantee it.
        sub_base_type = re.split(r'\(|\[', str(sub.typ.subtype))[0]
        ret_base_type = re.split(r'\(|\[', str(self.context.return_type.subtype))[0]
        if sub_base_type != ret_base_type:
            raise TypeMismatchException(
                "List return type %r does not match specified return type, expecting %r" %
                (sub_base_type, ret_base_type), self.stmt)
        elif sub.location == "memory" and sub.value != "multi":
            # Already laid out in memory: return it directly.
            return LLLnode.from_list([
                'return', sub,
                get_size_of_type(self.context.return_type) * 32
            ], typ=None, pos=getpos(self.stmt))
        else:
            # Otherwise materialize into a fresh memory placeholder first.
            new_sub = LLLnode.from_list(self.context.new_placeholder(
                self.context.return_type), typ=self.context.return_type, location='memory')
            setter = make_setter(new_sub, sub, 'memory', pos=getpos(self.stmt))
            return LLLnode.from_list([
                'seq', setter,
                [
                    'return', new_sub,
                    get_size_of_type(self.context.return_type) * 32
                ]
            ], typ=None, pos=getpos(self.stmt))

    # Returning a tuple.
    elif isinstance(sub.typ, TupleType):
        # NOTE(review): return_type.members is read without checking that the
        # declared return type is a tuple.
        if len(self.context.return_type.members) != len(sub.typ.members):
            raise StructureException("Tuple lengths don't match!", self.stmt)

        subs = []
        dynamic_offset_counter = LLLnode(
            self.context.get_next_mem(), typ=None, annotation="dynamic_offset_counter"
        )  # dynamic offset position counter.
        new_sub = LLLnode.from_list(self.context.get_next_mem() + 32, typ=self.context.return_type, location='memory', annotation='new_sub')
        keyz = list(range(len(sub.typ.members)))
        dynamic_offset_start = 32 * len(sub.args)  # The static list of args end.
        left_token = LLLnode.from_list('_loc', typ=new_sub.typ, location="memory")

        def get_dynamic_offset_value():
            # Get value of dynamic offset counter.
            return ['mload', dynamic_offset_counter]

        def increment_dynamic_offset(dynamic_spot):
            # Increment dynamic offset counter in memory by 32 + ceil32(length at dynamic_spot).
            return [
                'mstore', dynamic_offset_counter,
                [
                    'add',
                    ['add', ['ceil32', ['mload', dynamic_spot]], 32],
                    ['mload', dynamic_offset_counter]
                ]
            ]

        for i, typ in enumerate(keyz):
            arg = sub.args[i]
            variable_offset = LLLnode.from_list(
                ['add', 32 * i, left_token],
                typ=arg.typ,
                annotation='variable_offset')
            if isinstance(arg.typ, ByteArrayType):
                # Store offset pointer value.
                subs.append(['mstore', variable_offset, get_dynamic_offset_value()])

                # Store dynamic data, from offset pointer onwards.
                dynamic_spot = LLLnode.from_list(
                    ['add', left_token, get_dynamic_offset_value()],
                    location="memory",
                    typ=arg.typ,
                    annotation='dynamic_spot')
                subs.append(make_setter(dynamic_spot, arg, location="memory", pos=getpos(self.stmt)))
                subs.append(increment_dynamic_offset(dynamic_spot))

            elif isinstance(arg.typ, BaseType):
                subs.append(make_setter(variable_offset, arg, "memory", pos=getpos(self.stmt)))
            else:
                # NOTE(review): the type is passed as a second Exception arg rather
                # than %-formatted into the message.
                raise Exception("Can't return type %s as part of tuple", type(arg.typ))

        setter = LLLnode.from_list(
            ['seq',
                ['mstore', dynamic_offset_counter, dynamic_offset_start],
                ['with', '_loc', new_sub, ['seq'] + subs]],
            typ=None
        )

        # Returned size is the final dynamic offset (static head plus packed dynamic data).
        return LLLnode.from_list(
            ['seq',
                setter,
                ['return', new_sub, get_dynamic_offset_value()]],
            typ=None, pos=getpos(self.stmt)
        )
    else:
        raise TypeMismatchException("Can only return base type!", self.stmt)
def compare(self):
    """Compile a single binary comparison expression to an LLL boolean.

    Handles byte-array comparisons, ``x in <list>`` membership, and
    numeric comparisons (with signed/unsigned opcode selection and the
    int128-literal-vs-uint256 special case).

    Returns:
        LLLnode of type ``'bool'``.

    Raises:
        InvalidLiteralException: the right-hand side is ``None``.
        TypeMismatchException: incompatible operand types or units, or an
            ordering operator used with non-numeric operands.
        StructureException: more than one comparison operator (chained).
        ParserException: ordering comparison of byte arrays whose maximum
            lengths differ or exceed 32 bytes, or a byte array stored in a
            location this routine cannot load from.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

    if isinstance(right.typ, NullType):
        raise InvalidLiteralException(
            'Comparison to None is not allowed, compare against a default value.',
            self.expr,
        )

    if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
        # TODO: Can this if branch be removed ^
        pass

    elif isinstance(self.expr.ops[0], vy_ast.In) and isinstance(right.typ, ListType):
        if left.typ != right.typ.subtype:
            raise TypeMismatchException(
                "Can't use IN comparison with different types!",
                self.expr,
            )
        return self.build_in_comparator()
    else:
        if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ):  # noqa: E501
            raise TypeMismatchException("Can't compare values with different units!", self.expr)

    if len(self.expr.ops) != 1:
        raise StructureException(
            "Cannot have a comparison with more than two elements",
            self.expr,
        )

    # Signed opcodes by default; the numeric section below may switch them
    # to unsigned variants for uint256 operands.
    if isinstance(self.expr.ops[0], vy_ast.Gt):
        op = 'sgt'
    elif isinstance(self.expr.ops[0], vy_ast.GtE):
        op = 'sge'
    elif isinstance(self.expr.ops[0], vy_ast.LtE):
        op = 'sle'
    elif isinstance(self.expr.ops[0], vy_ast.Lt):
        op = 'slt'
    elif isinstance(self.expr.ops[0], vy_ast.Eq):
        op = 'eq'
    elif isinstance(self.expr.ops[0], vy_ast.NotEq):
        op = 'ne'
    else:
        raise Exception("Unsupported comparison operator")

    # Compare (limited to 32) byte arrays.
    if isinstance(left.typ, ByteArrayLike) and isinstance(right.typ, ByteArrayLike):
        left = Expr(self.expr.left, self.context).lll_node
        right = Expr(self.expr.comparators[0], self.context).lll_node

        length_mismatch = (left.typ.maxlen != right.typ.maxlen)
        left_over_32 = left.typ.maxlen > 32
        right_over_32 = right.typ.maxlen > 32
        if length_mismatch or left_over_32 or right_over_32:
            # Equality of arbitrary-length data is decided by comparing hashes.
            left_keccak = keccak256_helper(self.expr, [left], None, self.context)
            right_keccak = keccak256_helper(self.expr, [right], None, self.context)

            if op == 'eq' or op == 'ne':
                return LLLnode.from_list(
                    [op, left_keccak, right_keccak],
                    typ='bool',
                    pos=getpos(self.expr),
                )

            else:
                # One message string (implicit concatenation); previously the
                # comma split it across two positional arguments, pushing the
                # second half into the exception's `item` slot.
                raise ParserException(
                    "Can only compare strings/bytes of length shorter"
                    " than 32 bytes other than equality comparisons",
                    self.expr,
                )

        else:
            def load_bytearray(side):
                # Load the first 32-byte data word (the word after the length).
                if side.location == 'memory':
                    return ['mload', ['add', 32, side]]
                elif side.location == 'storage':
                    return ['sload', ['add', 1, ['sha3_32', side]]]
                raise ParserException(
                    f"Unsupported location for byte array comparison: {side.location}",
                    self.expr,
                )

            return LLLnode.from_list(
                [op, load_bytearray(left), load_bytearray(right)],
                typ='bool',
                pos=getpos(self.expr),
            )

    # Compare other types.
    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        if op not in ('eq', 'ne'):
            raise TypeMismatchException("Invalid type for comparison op", self.expr)
    left_type, right_type = left.typ.typ, right.typ.typ

    # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
    if {left_type, right_type} == {'int128', 'uint256'} and {left.typ.is_literal, right.typ.is_literal} == {True, False}:  # noqa: E501

        comparison_allowed = False
        if left.typ.is_literal and SizeLimits.in_bounds(right_type, left.value):
            comparison_allowed = True
        elif right.typ.is_literal and SizeLimits.in_bounds(left_type, right.value):
            comparison_allowed = True
        op = self._signed_to_unsigned_comparision_op(op)

        if comparison_allowed:
            return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))

    elif {left_type, right_type} == {'uint256', 'uint256'}:
        op = self._signed_to_unsigned_comparision_op(op)
    elif (left_type in ('decimal', 'int128') or right_type in ('decimal', 'int128')) and left_type != right_type:  # noqa: E501
        raise TypeMismatchException(
            f'Implicit conversion from {left_type} to {right_type} disallowed, please convert.',
            self.expr,
        )

    if left_type == right_type:
        return LLLnode.from_list([op, left, right], typ='bool', pos=getpos(self.expr))
    else:
        raise TypeMismatchException(
            f"Unsupported types for comparison: {left_type} {right_type}",
            self.expr,
        )
def arithmetic(self):
    """Compile a binary arithmetic expression (``+ - * / % **``) to LLL.

    Integer literal-with-literal operations are folded at compile time. All
    other operations bind the operands to ``l``/``r`` and emit the opcode
    plus the overflow / division-by-zero checks used for the operand type.
    The int128 and decimal results are then clamped to their type bounds.

    Returns:
        LLLnode typed with the resulting ``BaseType`` (units combined per op).

    Raises:
        TypeMismatchException: non-numeric operands, implicit type conversion,
            unit/positional misuse, or unsupported exponent types.
        ZeroDivisionException: division or modulo by a literal zero.
        ParserException: unsupported operator.
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.right, self.context)

    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        raise TypeMismatchException(
            f"Unsupported types for arithmetic op: {left.typ} {right.typ}",
            self.expr,
        )

    arithmetic_pair = {left.typ.typ, right.typ.typ}

    # Special Case: Simplify any literal to literal arithmetic at compile time.
    if (
        left.typ.is_literal and right.typ.is_literal and
        isinstance(right.value, int) and isinstance(left.value, int) and
        arithmetic_pair.issubset({'uint256', 'int128'})
    ):

        if isinstance(self.expr.op, vy_ast.Add):
            val = left.value + right.value
        elif isinstance(self.expr.op, vy_ast.Sub):
            val = left.value - right.value
        elif isinstance(self.expr.op, vy_ast.Mult):
            val = left.value * right.value
        elif isinstance(self.expr.op, vy_ast.Pow):
            val = left.value ** right.value
        elif isinstance(self.expr.op, (vy_ast.Div, vy_ast.Mod)):
            if right.value == 0:
                raise ZeroDivisionException(
                    "integer division or modulo by zero",
                    self.expr,
                )
            # EVM sdiv truncates toward zero and smod takes the sign of the
            # dividend; Python's // and % floor instead. Compute on the
            # magnitudes and reapply the sign so folded literals agree with
            # the runtime opcodes.
            if isinstance(self.expr.op, vy_ast.Div):
                val = abs(left.value) // abs(right.value)
                if (left.value < 0) != (right.value < 0):
                    val = -val
            else:
                val = abs(left.value) % abs(right.value)
                if left.value < 0:
                    val = -val
        else:
            raise ParserException(
                f'Unsupported literal operator: {type(self.expr.op)}',
                self.expr,
            )

        num = vy_ast.Int(n=val)
        num.full_source_code = self.expr.full_source_code
        num.node_source_code = self.expr.node_source_code
        num.lineno = self.expr.lineno
        num.col_offset = self.expr.col_offset
        num.end_lineno = self.expr.end_lineno
        num.end_col_offset = self.expr.end_col_offset

        return Expr.parse_value_expr(num, self.context)

    pos = getpos(self.expr)

    # Special case with uint256 where an int128 literal may be cast.
    if arithmetic_pair == {'uint256', 'int128'}:
        # Check right side literal.
        if right.typ.is_literal and SizeLimits.in_bounds('uint256', right.value):
            right = LLLnode.from_list(
                right.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=pos,
            )

        # Check left side literal.
        elif left.typ.is_literal and SizeLimits.in_bounds('uint256', left.value):
            left = LLLnode.from_list(
                left.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=pos,
            )

    if left.typ.typ == "decimal" and isinstance(self.expr.op, vy_ast.Pow):
        raise TypeMismatchException(
            "Cannot perform exponentiation on decimal values.",
            self.expr,
        )

    # Only allow explicit conversions to occur.
    if left.typ.typ != right.typ.typ:
        raise TypeMismatchException(
            f"Cannot implicitly convert {left.typ.typ} to {right.typ.typ}.",
            self.expr,
        )

    ltyp, rtyp = left.typ.typ, right.typ.typ
    if isinstance(self.expr.op, (vy_ast.Add, vy_ast.Sub)):
        if left.typ.unit != right.typ.unit and left.typ.unit != {} and right.typ.unit != {}:
            raise TypeMismatchException(
                f"Unit mismatch: {left.typ.unit} {right.typ.unit}",
                self.expr,
            )
        if (
            left.typ.positional and
            right.typ.positional and
            isinstance(self.expr.op, vy_ast.Add)
        ):
            raise TypeMismatchException(
                "Cannot add two positional units!",
                self.expr,
            )
        new_unit = left.typ.unit or right.typ.unit
        # xor, as subtracting two positionals gives a delta
        new_positional = left.typ.positional ^ right.typ.positional
        new_typ = BaseType(ltyp, new_unit, new_positional)

        op = 'add' if isinstance(self.expr.op, vy_ast.Add) else 'sub'

        if ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Add):
            # safeadd
            arith = ['seq',
                     ['assert', ['ge', ['add', 'l', 'r'], 'l']],
                     ['add', 'l', 'r']]

        elif ltyp == 'uint256' and isinstance(self.expr.op, vy_ast.Sub):
            # safesub
            arith = ['seq',
                     ['assert', ['ge', 'l', 'r']],
                     ['sub', 'l', 'r']]

        elif ltyp == rtyp:
            arith = [op, 'l', 'r']

        else:
            raise Exception(f"Unsupported Operation '{op}({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, vy_ast.Mult):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException("Cannot multiply positional values!", self.expr)
        new_unit = combine_units(left.typ.unit, right.typ.unit)
        new_typ = BaseType(ltyp, new_unit)

        if ltyp == rtyp == 'uint256':
            # Overflow check: (l * r) / l == r unless l == 0.
            arith = ['with', 'ans', ['mul', 'l', 'r'],
                     ['seq',
                      ['assert',
                       ['or',
                        ['eq', ['div', 'ans', 'l'], 'r'],
                        ['iszero', 'l']]],
                      'ans']]

        elif ltyp == rtyp == 'int128':
            # TODO should this be 'smul' (note edge cases in YP for smul)
            arith = ['mul', 'l', 'r']

        elif ltyp == rtyp == 'decimal':
            # TODO should this be smul
            arith = ['with', 'ans', ['mul', 'l', 'r'],
                     ['seq',
                      ['assert',
                       ['or',
                        ['eq', ['sdiv', 'ans', 'l'], 'r'],
                        ['iszero', 'l']]],
                      ['sdiv', 'ans', DECIMAL_DIVISOR]]]
        else:
            raise Exception(f"Unsupported Operation 'mul({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, vy_ast.Div):
        if right.typ.is_literal and right.value == 0:
            raise ZeroDivisionException("Cannot divide by 0.", self.expr)
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException("Cannot divide positional values!", self.expr)
        new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
        new_typ = BaseType(ltyp, new_unit)
        if ltyp == rtyp == 'uint256':
            arith = ['div', 'l', ['clamp_nonzero', 'r']]

        elif ltyp == rtyp == 'int128':
            arith = ['sdiv', 'l', ['clamp_nonzero', 'r']]

        elif ltyp == rtyp == 'decimal':
            arith = ['sdiv',
                     # TODO check overflow cases, also should it be smul
                     ['mul', 'l', DECIMAL_DIVISOR],
                     ['clamp_nonzero', 'r']]

        else:
            raise Exception(f"Unsupported Operation 'div({ltyp}, {rtyp})'")

    elif isinstance(self.expr.op, vy_ast.Mod):
        if right.typ.is_literal and right.value == 0:
            raise ZeroDivisionException("Cannot calculate modulus of 0.", self.expr)
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException(
                "Cannot use positional values as modulus arguments!",
                self.expr,
            )
        # NOTE(review): this only raises when neither side has a unit, which
        # looks like it may never trigger; confirm the intended condition.
        if not are_units_compatible(left.typ, right.typ) and not (left.typ.unit or right.typ.unit):  # noqa: E501
            raise TypeMismatchException("Modulus arguments must have same unit", self.expr)
        new_unit = left.typ.unit or right.typ.unit
        new_typ = BaseType(ltyp, new_unit)

        if ltyp == rtyp == 'uint256':
            arith = ['mod', 'l', ['clamp_nonzero', 'r']]
        elif ltyp == rtyp:
            # TODO should this be regular mod
            arith = ['smod', 'l', ['clamp_nonzero', 'r']]

        else:
            raise Exception(f"Unsupported Operation 'mod({ltyp}, {rtyp})'")
    elif isinstance(self.expr.op, vy_ast.Pow):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException(
                "Cannot use positional values as exponential arguments!",
                self.expr,
            )
        if right.typ.unit:
            raise TypeMismatchException(
                "Cannot use unit values as exponents",
                self.expr,
            )
        if ltyp != 'int128' and ltyp != 'uint256' and isinstance(self.expr.right, vy_ast.Name):
            raise TypeMismatchException(
                "Cannot use dynamic values as exponents, for unit base types",
                self.expr,
            )
        new_unit = left.typ.unit
        if left.typ.unit and not isinstance(self.expr.right, vy_ast.Name):
            new_unit = {left.typ.unit.copy().popitem()[0]: self.expr.right.n}
        new_typ = BaseType(ltyp, new_unit)

        if ltyp == rtyp == 'uint256':
            # A base of 0 or 1, or an exponent of 0 or 1, cannot overflow;
            # otherwise require l < l ** r. Without the base check, 0 ** r
            # and 1 ** r (r >= 2) reverted because l < l ** r is false there.
            # NOTE(review): l < l ** r does not detect every wrap-around.
            arith = ['seq',
                     ['assert',
                      ['or',
                       ['or',
                        ['lt', 'l', 2],
                        ['or', ['eq', 'r', 1], ['iszero', 'r']]],
                       ['lt', 'l', ['exp', 'l', 'r']]]],
                     ['exp', 'l', 'r']]
        elif ltyp == rtyp == 'int128':
            arith = ['exp', 'l', 'r']

        else:
            raise TypeMismatchException('Only whole number exponents are supported', self.expr)
    else:
        raise ParserException(f"Unsupported binary operator: {self.expr.op}", self.expr)

    # Clamp signed results to their type's bounds; uint256 ops carry their
    # own overflow checks above.
    p = ['seq']

    if new_typ.typ == 'int128':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINNUM],
            arith,
            ['mload', MemoryPositions.MAXNUM],
        ])
    elif new_typ.typ == 'decimal':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINDECIMAL],
            arith,
            ['mload', MemoryPositions.MAXDECIMAL],
        ])
    elif new_typ.typ == 'uint256':
        p.append(arith)
    else:
        raise Exception(f"{arith} {new_typ}")

    p = ['with', 'l', left, ['with', 'r', right, p]]
    return LLLnode.from_list(p, typ=new_typ, pos=pos)
def enforce_units(typ, obj, expected):
    """Check that ``typ`` is unit-compatible with ``expected``.

    ``obj`` is the AST node reported in the error.

    Raises:
        TypeMismatchException: when ``are_units_compatible(typ, expected)``
            is false.
    """
    if are_units_compatible(typ, expected):
        return
    raise TypeMismatchException("Invalid units", obj)
def compare(self):
    """Compile a single binary comparison expression to an LLL boolean.

    Supports byte arrays of equal maximum length up to 32 bytes,
    ``x in <list>`` membership, and numeric comparisons (with signed/unsigned
    opcode selection and the int128-literal-vs-uint256 special case).

    Returns:
        LLLnode of type ``'bool'``.

    Raises:
        TypeMismatchException: byte arrays of differing maximum length,
            incompatible units or types, or an ordering operator used with
            non-numeric operands.
        ParserException: byte arrays longer than 32 bytes, or a byte array in
            a location this routine cannot load from.
        StructureException: more than one comparison operator (chained).
    """
    left = Expr.parse_value_expr(self.expr.left, self.context)
    right = Expr.parse_value_expr(self.expr.comparators[0], self.context)

    if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
        if left.typ.maxlen != right.typ.maxlen:
            raise TypeMismatchException(
                'Can only compare bytes of the same length', self.expr)
        if left.typ.maxlen > 32 or right.typ.maxlen > 32:
            raise ParserException(
                'Can only compare bytes of length shorter than 32 bytes',
                self.expr)
    elif isinstance(self.expr.ops[0], ast.In) and isinstance(right.typ, ListType):
        if not are_units_compatible(
                left.typ, right.typ.subtype) and not are_units_compatible(
                    right.typ.subtype, left.typ):
            raise TypeMismatchException(
                "Can't use IN comparison with different types!", self.expr)
        return self.build_in_comparator()
    else:
        if not are_units_compatible(
                left.typ, right.typ) and not are_units_compatible(
                    right.typ, left.typ):
            raise TypeMismatchException(
                "Can't compare values with different units!", self.expr)

    if len(self.expr.ops) != 1:
        raise StructureException(
            "Cannot have a comparison with more than two elements",
            self.expr)

    # Signed opcodes by default; switched to unsigned for uint256 below.
    if isinstance(self.expr.ops[0], ast.Gt):
        op = 'sgt'
    elif isinstance(self.expr.ops[0], ast.GtE):
        op = 'sge'
    elif isinstance(self.expr.ops[0], ast.LtE):
        op = 'sle'
    elif isinstance(self.expr.ops[0], ast.Lt):
        op = 'slt'
    elif isinstance(self.expr.ops[0], ast.Eq):
        op = 'eq'
    elif isinstance(self.expr.ops[0], ast.NotEq):
        op = 'ne'
    else:
        raise Exception("Unsupported comparison operator")

    # Compare (limited to 32) byte arrays. Both operands must be byte arrays;
    # previously `left.typ` was tested twice, so a non-bytes right operand
    # could be loaded as if it were a byte array.
    if isinstance(left.typ, ByteArrayType) and isinstance(right.typ, ByteArrayType):
        left = Expr(self.expr.left, self.context).lll_node
        right = Expr(self.expr.comparators[0], self.context).lll_node

        def load_bytearray(side):
            # Load the first 32-byte data word (the word after the length).
            if side.location == 'memory':
                return ['mload', ['add', 32, side]]
            elif side.location == 'storage':
                return ['sload', ['add', 1, ['sha3_32', side]]]
            raise ParserException(
                'Unsupported location for byte array comparison: %s' % side.location,
                self.expr)

        return LLLnode.from_list(
            [op, load_bytearray(left), load_bytearray(right)],
            typ='bool',
            pos=getpos(self.expr))

    # Compare other types.
    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        if op not in ('eq', 'ne'):
            raise TypeMismatchException("Invalid type for comparison op",
                                        self.expr)
    left_type, right_type = left.typ.typ, right.typ.typ

    # Special Case: comparison of a literal integer. If in valid range allow it to be compared.
    if {left_type, right_type} == {'int128', 'uint256'} and {
            left.typ.is_literal, right.typ.is_literal
    } == {True, False}:

        comparison_allowed = False
        if left.typ.is_literal and SizeLimits.in_bounds(
                right_type, left.value):
            comparison_allowed = True
        elif right.typ.is_literal and SizeLimits.in_bounds(
                left_type, right.value):
            comparison_allowed = True
        op = self._signed_to_unsigned_comparision_op(op)

        if comparison_allowed:
            return LLLnode.from_list([op, left, right],
                                     typ='bool',
                                     pos=getpos(self.expr))

    elif {left_type, right_type} == {'uint256', 'uint256'}:
        op = self._signed_to_unsigned_comparision_op(op)
    elif (left_type in ('decimal', 'int128')
          or right_type in ('decimal', 'int128')) and left_type != right_type:
        raise TypeMismatchException(
            'Implicit conversion from {} to {} disallowed, please convert.'
            .format(left_type, right_type), self.expr)

    if left_type == right_type:
        return LLLnode.from_list([op, left, right],
                                 typ='bool',
                                 pos=getpos(self.expr))
    else:
        raise TypeMismatchException(
            "Unsupported types for comparison: %r %r" %
            (left_type, right_type), self.expr)
def arithmetic(self):
    """Compile a binary arithmetic expression (``+ - * / % **``) to LLL.

    Integer literal-with-literal operations are folded at compile time. The
    folded ``/`` and ``%`` follow EVM ``sdiv``/``smod`` semantics (truncate
    toward zero; the remainder takes the dividend's sign). Other operations
    emit LLL with type-specific overflow / division-by-zero checks. The
    int128 and decimal results are clamped to their type bounds.

    Returns:
        LLLnode typed with the resulting ``BaseType``.

    Raises:
        TypeMismatchException: non-numeric operands, implicit type conversion,
            unit/positional misuse, or unsupported exponent types.
        ParserException: literal division/modulo by zero, or an unsupported
            operator.
    """
    pre_alloc_left, left = self.arithmetic_get_reference(self.expr.left)
    pre_alloc_right, right = self.arithmetic_get_reference(self.expr.right)

    if not is_numeric_type(left.typ) or not is_numeric_type(right.typ):
        raise TypeMismatchException(
            "Unsupported types for arithmetic op: %r %r" %
            (left.typ, right.typ),
            self.expr,
        )

    arithmetic_pair = {left.typ.typ, right.typ.typ}

    # Special Case: Simplify any literal to literal arithmetic at compile time.
    if (left.typ.is_literal and right.typ.is_literal and
            isinstance(right.value, int) and isinstance(left.value, int)):

        if isinstance(self.expr.op, ast.Add):
            val = left.value + right.value
        elif isinstance(self.expr.op, ast.Sub):
            val = left.value - right.value
        elif isinstance(self.expr.op, ast.Mult):
            val = left.value * right.value
        elif isinstance(self.expr.op, (ast.Div, ast.Mod)):
            # Report a compile error instead of crashing the compiler with
            # Python's ZeroDivisionError.
            if right.value == 0:
                raise ParserException(
                    "integer division or modulo by zero", self.expr)
            # Python's // and % floor; EVM sdiv/smod truncate toward zero.
            # Compute on magnitudes and reapply the sign so the folded value
            # matches what the runtime opcode would produce.
            if isinstance(self.expr.op, ast.Div):
                val = abs(left.value) // abs(right.value)
                if (left.value < 0) != (right.value < 0):
                    val = -val
            else:
                val = abs(left.value) % abs(right.value)
                if left.value < 0:
                    val = -val
        elif isinstance(self.expr.op, ast.Pow):
            val = left.value**right.value
        else:
            raise ParserException(
                'Unsupported literal operator: %s' % str(type(self.expr.op)),
                self.expr,
            )

        num = ast.Num(val)
        num.source_code = self.expr.source_code
        num.lineno = self.expr.lineno
        num.col_offset = self.expr.col_offset

        return Expr.parse_value_expr(num, self.context)

    # Special case with uint256 where an int128 literal may be cast.
    if arithmetic_pair == {'uint256', 'int128'}:
        # Check right side literal.
        if right.typ.is_literal and SizeLimits.in_bounds(
                'uint256', right.value):
            right = LLLnode.from_list(
                right.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=getpos(self.expr),
            )

        # Check left side literal.
        elif left.typ.is_literal and SizeLimits.in_bounds(
                'uint256', left.value):
            left = LLLnode.from_list(
                left.value,
                typ=BaseType('uint256', None, is_literal=True),
                pos=getpos(self.expr),
            )

    # Only allow explicit conversions to occur.
    if left.typ.typ != right.typ.typ:
        raise TypeMismatchException(
            "Cannot implicitly convert {} to {}.".format(
                left.typ.typ, right.typ.typ),
            self.expr,
        )

    ltyp, rtyp = left.typ.typ, right.typ.typ
    if isinstance(self.expr.op, (ast.Add, ast.Sub)):
        if left.typ.unit != right.typ.unit and left.typ.unit != {} and right.typ.unit != {}:
            raise TypeMismatchException(
                "Unit mismatch: %r %r" % (left.typ.unit, right.typ.unit),
                self.expr,
            )
        if left.typ.positional and right.typ.positional and isinstance(
                self.expr.op, ast.Add):
            raise TypeMismatchException(
                "Cannot add two positional units!",
                self.expr,
            )
        new_unit = left.typ.unit or right.typ.unit
        # xor, as subtracting two positionals gives a delta
        new_positional = left.typ.positional ^ right.typ.positional

        op = 'add' if isinstance(self.expr.op, ast.Add) else 'sub'
        if ltyp == 'uint256' and isinstance(self.expr.op, ast.Add):
            o = LLLnode.from_list(
                [
                    'seq',
                    # Checks that: a + b >= a
                    ['assert', ['ge', ['add', left, right], left]],
                    ['add', left, right],
                ],
                typ=BaseType('uint256', new_unit, new_positional),
                pos=getpos(self.expr))
        elif ltyp == 'uint256' and isinstance(self.expr.op, ast.Sub):
            o = LLLnode.from_list(
                [
                    'seq',
                    # Checks that: a >= b
                    ['assert', ['ge', left, right]],
                    ['sub', left, right]
                ],
                typ=BaseType('uint256', new_unit, new_positional),
                pos=getpos(self.expr))
        elif ltyp == rtyp:
            o = LLLnode.from_list(
                [op, left, right],
                typ=BaseType(ltyp, new_unit, new_positional),
                pos=getpos(self.expr),
            )
        else:
            raise Exception("Unsupported Operation '%r(%r, %r)'" %
                            (op, ltyp, rtyp))
    elif isinstance(self.expr.op, ast.Mult):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException(
                "Cannot multiply positional values!", self.expr)
        new_unit = combine_units(left.typ.unit, right.typ.unit)
        if ltyp == rtyp == 'uint256':
            # Overflow check: (a * b) / a == b, skipped when a == 0.
            o = LLLnode.from_list([
                'if',
                ['eq', left, 0],
                [0],
                [
                    'seq',
                    [
                        'assert',
                        ['eq', ['div', ['mul', left, right], left], right]
                    ],
                    ['mul', left, right]
                ],
            ],
                typ=BaseType('uint256', new_unit),
                pos=getpos(self.expr))
        elif ltyp == rtyp == 'int128':
            o = LLLnode.from_list(
                ['mul', left, right],
                typ=BaseType('int128', new_unit),
                pos=getpos(self.expr),
            )
        elif ltyp == rtyp == 'decimal':
            o = LLLnode.from_list([
                'with', 'r', right,
                [
                    'with', 'l', left,
                    [
                        'with', 'ans', ['mul', 'l', 'r'],
                        [
                            'seq',
                            [
                                'assert',
                                [
                                    'or',
                                    ['eq', ['sdiv', 'ans', 'l'], 'r'],
                                    ['iszero', 'l']
                                ]
                            ],
                            ['sdiv', 'ans', DECIMAL_DIVISOR],
                        ],
                    ],
                ],
            ],
                typ=BaseType('decimal', new_unit),
                pos=getpos(self.expr))
        else:
            raise Exception("Unsupported Operation 'mul(%r, %r)'" %
                            (ltyp, rtyp))
    elif isinstance(self.expr.op, ast.Div):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException("Cannot divide positional values!",
                                        self.expr)
        new_unit = combine_units(left.typ.unit, right.typ.unit, div=True)
        if ltyp == rtyp == 'uint256':
            o = LLLnode.from_list(
                [
                    'seq',
                    # Checks that: b != 0
                    ['assert', right],
                    ['div', left, right],
                ],
                typ=BaseType('uint256', new_unit),
                pos=getpos(self.expr))
        elif ltyp == rtyp == 'int128':
            o = LLLnode.from_list(
                ['sdiv', left, ['clamp_nonzero', right]],
                typ=BaseType('int128', new_unit),
                pos=getpos(self.expr),
            )
        elif ltyp == rtyp == 'decimal':
            o = LLLnode.from_list([
                'with', 'l', left,
                [
                    'with', 'r', ['clamp_nonzero', right],
                    [
                        'sdiv',
                        ['mul', 'l', DECIMAL_DIVISOR],
                        'r',
                    ],
                ]
            ],
                typ=BaseType('decimal', new_unit),
                pos=getpos(self.expr))
        else:
            raise Exception("Unsupported Operation 'div(%r, %r)'" %
                            (ltyp, rtyp))
    elif isinstance(self.expr.op, ast.Mod):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException(
                "Cannot use positional values as modulus arguments!",
                self.expr,
            )
        # NOTE(review): this only raises when neither side has a unit, which
        # looks like it may never trigger; confirm the intended condition.
        if not are_units_compatible(left.typ, right.typ) and not (
                left.typ.unit or right.typ.unit):  # noqa: E501
            raise TypeMismatchException(
                "Modulus arguments must have same unit", self.expr)
        new_unit = left.typ.unit or right.typ.unit
        if ltyp == rtyp == 'uint256':
            o = LLLnode.from_list(
                ['seq', ['assert', right], ['mod', left, right]],
                typ=BaseType('uint256', new_unit),
                pos=getpos(self.expr))
        elif ltyp == rtyp:
            o = LLLnode.from_list(
                ['smod', left, ['clamp_nonzero', right]],
                typ=BaseType(ltyp, new_unit),
                pos=getpos(self.expr),
            )
        else:
            raise Exception("Unsupported Operation 'mod(%r, %r)'" %
                            (ltyp, rtyp))
    elif isinstance(self.expr.op, ast.Pow):
        if left.typ.positional or right.typ.positional:
            raise TypeMismatchException(
                "Cannot use positional values as exponential arguments!",
                self.expr,
            )
        if right.typ.unit:
            raise TypeMismatchException(
                "Cannot use unit values as exponents",
                self.expr,
            )
        if ltyp != 'int128' and ltyp != 'uint256' and isinstance(
                self.expr.right, ast.Name):
            raise TypeMismatchException(
                "Cannot use dynamic values as exponents, for unit base types",
                self.expr,
            )
        if ltyp == rtyp == 'uint256':
            # A base of 0 or 1, or an exponent of 0 or 1, cannot overflow;
            # otherwise require a < a ** b. Without the base check, 0 ** b
            # and 1 ** b (b >= 2) reverted because a < a ** b is false there.
            # NOTE(review): a < a ** b does not detect every wrap-around.
            o = LLLnode.from_list([
                'seq',
                [
                    'assert',
                    [
                        'or',
                        [
                            'or',
                            ['lt', left, 2],
                            ['or', ['eq', right, 1], ['iszero', right]],
                        ],
                        ['lt', left, ['exp', left, right]]
                    ],
                ],
                ['exp', left, right],
            ],
                typ=BaseType('uint256'),
                pos=getpos(self.expr))
        elif ltyp == rtyp == 'int128':
            new_unit = left.typ.unit
            if left.typ.unit and not isinstance(self.expr.right, ast.Name):
                new_unit = {
                    left.typ.unit.copy().popitem()[0]: self.expr.right.n
                }
            o = LLLnode.from_list(
                ['exp', left, right],
                typ=BaseType('int128', new_unit),
                pos=getpos(self.expr),
            )
        else:
            raise TypeMismatchException(
                'Only whole number exponents are supported', self.expr)
    else:
        raise ParserException(
            "Unsupported binary operator: %r" % self.expr.op, self.expr)

    # Emit any operand pre-allocation first, then the (clamped) result.
    p = ['seq']

    if pre_alloc_left:
        p.append(pre_alloc_left)
    if pre_alloc_right:
        p.append(pre_alloc_right)

    if o.typ.typ == 'int128':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINNUM],
            o,
            ['mload', MemoryPositions.MAXNUM],
        ])
        return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
    elif o.typ.typ == 'decimal':
        p.append([
            'clamp',
            ['mload', MemoryPositions.MINDECIMAL],
            o,
            ['mload', MemoryPositions.MAXDECIMAL],
        ])
        return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
    if o.typ.typ == 'uint256':
        p.append(o)
        return LLLnode.from_list(p, typ=o.typ, pos=getpos(self.expr))
    else:
        raise Exception("%r %r" % (o, o.typ))