def genLNot(resultLoc, srcLoc):
    """Emit code for logical NOT (``!``) of a single-byte unsigned operand.

    Args:
        resultLoc: destination location; an unknown type is resolved to the
            operand's type before the type checks.
        srcLoc: operand, a constant (indirection level 0) or a variable
            (indirection level 1).

    Returns:
        A ``(value, asm)`` pair.  For a constant operand ``value`` is a folded
        ``BoolType`` constant and ``asm`` contains only the comment header;
        otherwise ``value`` is a register value and ``asm`` is the generated
        code.

    Raises:
        SemanticError: if result and operand types differ, or the operand is
            not an unsigned one-byte type.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(srcLoc.getType())
    if resultLoc.getType() != srcLoc.getType():
        raise SemanticError(
            srcLoc.getPosition(),
            "Incompatible types: {} and {}".format(resultLoc.getType(), srcLoc.getType()))
    operandIsU8 = srcLoc.getType().getSize() == 1 and not srcLoc.getType().getSign()
    if not operandIsU8:
        raise SemanticError(srcLoc.getPosition(),
                            "Argument for `!' should be of type u8")
    assert resultLoc.getIndirLevel() == 1
    assert srcLoc.getIndirLevel() in (1, 0)

    code = f'; {resultLoc} = !{srcLoc}\n'

    if srcLoc.getIndirLevel() == 0:
        # Constant operand: fold now for numbers, otherwise hand a symbolic
        # expression on to a later stage.
        const = srcLoc.getSource()
        if const.isNumber():
            folded = int(not bool(const))
        else:
            folded = f'int(not bool({const}))'  # Warning
        return Value(srcLoc.getPosition(), BoolType(), 0, folded, True), code

    # Variable operand: the result is produced in register a.
    code += loadByte('a', srcLoc, 0)
    code += '''
        dec a ; c = a == 0
        mov a, 0
        adc a, 0
    '''
    return Value.register(srcLoc.getPosition(), BoolType()), code
def genNe(resultLoc, src1Loc, src2Loc):
    """Emit code for the inequality comparison ``src1 != src2``.

    Args:
        resultLoc: destination location; it is retyped to ``BoolType`` and
            must have indirection level 1.
        src1Loc: left operand.
        src2Loc: right operand; must have the same type as ``src1Loc``.

    Returns:
        A ``(value, asm)`` pair.  When both operands are constants the
        comparison is folded (numerically, or as a symbolic expression string)
        and ``asm`` holds only the comment header.  Otherwise ``value`` is a
        register value of ``BoolType``.

    Raises:
        SemanticError: if the operand types differ.
    """
    resultLoc = resultLoc.withType(BoolType())
    assert (resultLoc.getIndirLevel() == 1)
    if src1Loc.getType() != src2Loc.getType():
        raise SemanticError(
            src1Loc.getPosition(),
            "Incompatible types: {} and {}".format(src1Loc.getType(),
                                                   src2Loc.getType()))
    s1 = src1Loc.getSource()
    s2 = src2Loc.getSource()
    l1 = src1Loc.getIndirLevel()
    l2 = src2Loc.getIndirLevel()
    # The header used to say "==", which mislabelled the generated listing.
    result = '; {} != {}\n'.format(src1Loc, src2Loc)
    if l1 == 0 and l2 == 0:
        # const != const
        pos = src1Loc.getPosition() - src2Loc.getPosition()
        if s1.isNumber() and s2.isNumber():
            return Value(pos, BoolType(), 0, int(int(s1) != int(s2)), True), result
        return Value(pos, BoolType(), 0,
                     "int(({}) != ({}))".format(s1, s2), True), result

    result += _genEqNeCmp(src1Loc, src2Loc)
    # Turn the value left in b by _genEqNeCmp into a 0/1 result in a
    # (presumably b == 0 means "equal" - confirm against _genEqNeCmp).
    result += '''
        dec b
        ldi a, 1
        sbb a, 0
    '''
    return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                          BoolType()), result
def genNeg(resultLoc, srcLoc):
    """Emit code for arithmetic negation (unary ``-``) of a signed operand.

    Args:
        resultLoc: destination location (indirection level 1); an unknown
            type is resolved to the operand's type.
        srcLoc: operand, a constant (level 0) or a variable (level 1).

    Returns:
        A ``(value, asm)`` pair: a folded constant for constant operands, a
        register value for one-byte variables, or ``resultLoc`` for two-byte
        variables (the result is stored to memory).

    Raises:
        SemanticError: on a type mismatch or an unsigned operand.
        NatrixNotImplementedError: for operands wider than two bytes.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(srcLoc.getType())
    if resultLoc.getType() != srcLoc.getType():
        raise SemanticError(
            srcLoc.getPosition(),
            "Incompatible types: {} and {}".format(resultLoc.getType(),
                                                   srcLoc.getType()))
    if not srcLoc.getType().getSign():
        raise SemanticError(
            srcLoc.getPosition(),
            "Argument for unary `-' should be of a signed type")
    assert resultLoc.getIndirLevel() == 1
    assert srcLoc.getIndirLevel() in (1, 0)
    operandType = srcLoc.getType()
    if operandType.getSize() > 2:
        raise NatrixNotImplementedError(
            srcLoc.getPosition(),
            "Negation of ints wider than s16 is not implemented")

    code = f'; {resultLoc} = -{srcLoc}\n'

    if srcLoc.getIndirLevel() == 0:
        # Constant: fold numbers modulo the operand width, otherwise emit a
        # symbolic expression.
        const = srcLoc.getSource()
        if const.isNumber():
            mask = 0xff if operandType.getSize() == 1 else 0xffff
            const = -int(const) & mask
        else:
            const = f'-({const})'  # Warning
        return Value(srcLoc.getPosition(), operandType, 0, const, True), code

    if operandType.getSize() == 1:
        code += loadByte('a', srcLoc, 0)
        code += 'neg a\n'
        return Value.register(srcLoc.getPosition(), operandType), code

    # Two-byte variable: invert both bytes, then add one with carry.
    srcName = srcLoc.getSource()
    dstName = resultLoc.getSource()
    code += f'''
        ldi pl, lo({srcName})
        ldi ph, hi({srcName})
        ld b
    '''
    code += incP(srcLoc.isAligned())
    code += '''
        ld a
        not a
        not b
        inc b
        adc a, 0
    '''
    code += f'''
        ldi pl, lo({dstName})
        ldi ph, hi({dstName})
        st b
    '''
    code += incP(resultLoc.isAligned())
    code += '''
        st a
    '''
    return resultLoc, code
def genBNot(resultLoc, srcLoc):
    """Emit code for bitwise NOT (``~``).

    Args:
        resultLoc: destination location (indirection level 1); an unknown
            type is resolved to the operand's type.
        srcLoc: operand, a constant (level 0) or a variable (level 1).

    Returns:
        A ``(value, asm)`` pair: a folded constant for constant operands, a
        register value for one-byte variables, or ``resultLoc`` for wider
        variables (the result is stored to memory).

    Raises:
        SemanticError: if result and operand types differ.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(srcLoc.getType())
    if resultLoc.getType() != srcLoc.getType():
        raise SemanticError(
            srcLoc.getPosition(),
            "Incompatible types: {} and {}".format(resultLoc.getType(),
                                                   srcLoc.getType()))
    assert (resultLoc.getIndirLevel() == 1)
    assert (srcLoc.getIndirLevel() == 1 or srcLoc.getIndirLevel() == 0)
    t = srcLoc.getType()
    result = '; {} = ~{}\n'.format(resultLoc, srcLoc)
    if srcLoc.getIndirLevel() == 0:
        # constant
        # NOTE(review): unlike genNeg there is no isNumber() check or width
        # mask here; this relies on the source object supporting `~`.
        c = srcLoc.getSource()
        c = ~c
        return Value(srcLoc.getPosition(), t, 0, c, True), result

    # var
    s = srcLoc.getSource()
    rs = resultLoc.getSource()
    if t.getSize() == 1:
        result += loadByte('a', srcLoc, 0)
        result += 'not a\n'
        return Value.register(srcLoc.getPosition(), t), result

    # size > 1: process the value two bytes at a time
    for offset in range(0, t.getSize(), 2):
        rest = t.getSize() - offset
        result += f'''
            ldi pl, lo({s} + {offset})
            ldi ph, hi({s} + {offset})
            ld b
        '''
        if rest > 1:
            result += incP(srcLoc.isAligned())
            result += 'ld a\n'
        result += f'''
            ldi pl, lo({rs} + {offset})
            ldi ph, hi({rs} + {offset})
            not b
            st b
        '''
        if rest > 1:
            # isAligned must be *called*: the bare bound method is always
            # truthy, which made the carry-propagating path unreachable.
            if resultLoc.isAligned():
                result += '''
                    inc pl
                    not a
                    st a
                '''
            else:
                # Unaligned result: propagate the carry of `inc pl` into ph.
                result += '''
                    mov b, a
                    inc pl
                    mov a, 0
                    adc ph, a
                    not b
                    st b
                '''
    return resultLoc, result
def genDeref(resultLoc, srcLoc, offset=0):
    """Emit code that loads the value a pointer refers to.

    Args:
        resultLoc: destination location; an unknown type is resolved to the
            pointee type of ``srcLoc``.
        srcLoc: pointer operand with indirection level 0 or 1.
        offset: byte offset added to the pointer before loading.

    Returns:
        A ``(value, asm)`` pair.  A level-0 pointer needs no code and yields a
        level-1 ``Value`` built with ``Value.withOffset``.  Otherwise a
        one-byte pointee is returned as a register value and wider pointees
        are copied into ``resultLoc``.

    Raises:
        SemanticError: if the pointee type differs from the result type and
            the pointee is not a struct.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(srcLoc.getType().deref())
    if srcLoc.getType().deref() != resultLoc.getType() \
            and not srcLoc.getType().deref().isStruct():
        raise SemanticError(
            srcLoc.getPosition(),
            "Incompatible types for deref: {} and {}".format(
                srcLoc.getType().deref(), resultLoc.getType()))
    assert (srcLoc.getIndirLevel() <= 1)
    t = resultLoc.getType()
    if srcLoc.getIndirLevel() == 0:
        return Value.withOffset(srcLoc.getPosition(), resultLoc.getType(), 1,
                                srcLoc.getSource(), True, offset), ""

    result = '; {} = deref {} + {}\n'.format(resultLoc, srcLoc, offset)
    # Terminate the comment line explicitly so that whatever loadP() returns
    # can never be appended to (and swallowed by) this `;` comment.
    result += '; result is {}aligned, srcLoc is {}aligned\n'.format(
        "" if resultLoc.isAligned() else "not ",
        "" if srcLoc.isAligned() else "not ")
    rs = resultLoc.getSource()
    if t.getSize() == 1:
        result += loadP(srcLoc, offset)
        result += 'ld a\n'
        return Value.register(srcLoc.getPosition(), t), result

    # t.getSize() > 1: copy two bytes at a time, highest pair first
    for byteOffset in reversed(range(0, t.getSize(), 2)):
        rest = min(2, t.getSize() - byteOffset)
        result += loadP(srcLoc, byteOffset + offset)
        result += 'ld b\n'
        if rest > 1:
            # Advance P by one with carry into ph, then load the second byte.
            result += '''
                mov a, 0
                inc pl
                adc ph, a
                ld a
            '''
        result += f'''
            ldi pl, lo({rs} + {byteOffset})
            ldi ph, hi({rs} + {byteOffset})
            st b
        '''
        if rest > 1:
            if resultLoc.isAligned():
                result += '''
                    inc pl
                    st a
                '''
            else:
                result += f'''
                    ldi pl, lo({rs} + {byteOffset + 1})
                    ldi ph, hi({rs} + {byteOffset + 1})
                    st a
                '''
    return resultLoc, result
def genMod(resultLoc, src1Loc, src2Loc):
    """Emit code for the remainder operation.

    The shared division/modulo sequence comes from ``_genDMCommon``; this
    function only fetches the remainder from ``__cc_r_remainder``.

    Returns:
        A ``(value, asm)`` pair: ``resultLoc`` when the result is two bytes
        wide (the remainder is copied into it), otherwise a register value
        with the remainder loaded into ``a``.
    """
    resultLoc, code = _genDMCommon(resultLoc, src1Loc, src2Loc)

    if resultLoc.getType().getSize() != 2:
        code += '''
            ldi pl, lo(__cc_r_remainder)
            ldi ph, hi(__cc_r_remainder)
            ld a
        '''
        span = src1Loc.getPosition() - src2Loc.getPosition()
        return Value.register(span, resultLoc.getType()), code

    code += copyW("__cc_r_remainder", resultLoc.getSource(), True, False)
    return resultLoc, code
def _genMulVC(resultLoc, v, c): if c == 0: return Value(resultLoc.getPosition(), v.getType(), 0, 0, True), "" elif c == 1: return v, "" t = resultLoc.getType() if t.getSize() == 1: return Value.register(v.getPosition(), t), _genMulVCByte(resultLoc.getSource(), v, c) elif t.getSize() == 2: return resultLoc, _genMulVCWord(resultLoc, v, c) elif t.getSize() == 4: return resultLoc, _genMulVCDword(resultLoc, v, c)
def _genShByteByVar(resultLoc, src1Loc, src2Loc, labelProvider, op):
    """Emit code shifting a byte by a run-time (variable) amount.

    Args:
        resultLoc: destination location (used for the returned value's
            position and type).
        src1Loc: the byte being shifted; must not live in a register.
        src2Loc: the shift amount; must be one byte wide.
        labelProvider: object whose ``allocLabel`` hands out unique labels.
        op: shift mnemonic placed in the loop body (e.g. ``'shl'``).

    Returns:
        ``(register value, asm)``; the generated code ends with ``mov a, b``.

    Raises:
        NatrixNotImplementedError: if the shift amount is wider than a byte.
        RegisterNotSupportedError: if ``src1Loc`` is a register.
    """
    if src2Loc.getType().getSize() > 1:
        raise NatrixNotImplementedError(src2Loc.getPosition(),
                                        "Shift by variables over 8 bits")
    # Allocation order is kept: begin, loop, end, inf.
    beginLabel, loopLabel, endLabel, infLabel = (
        labelProvider.allocLabel(f"shift_{tag}")
        for tag in ("begin", "loop", "end", "inf"))
    if src1Loc.getSource().isRegister():
        raise RegisterNotSupportedError(0)

    parts = [f'; {resultLoc} = {op} {src1Loc}, {src2Loc} (byte)\n']
    parts.append(loadByte('a', src2Loc, 0))
    parts.append(f'''
        ldi b, 7
        sub b, a
        ldi pl, lo({beginLabel})
        ldi ph, hi({beginLabel})
        jnc ; a <= 7
    {infLabel}:
    ''')

    # Shift amount above 7: the result is all sign bits for signed
    # non-left shifts, zero otherwise.
    fillWithSign = src1Loc.getType().getSign() and op != 'shl'
    if fillWithSign:
        parts.append(loadByte('b', src1Loc, 0))
        parts.append('''
            shl b
            exp b
        ''')
    else:
        parts.append('ldi b, 0\n')
    parts.append(f'''
        ldi pl, lo({endLabel})
        ldi ph, hi({endLabel})
        jmp
    {beginLabel}:
    ''')

    # Shift amount 0..7: loop `a` times, skipping the loop when a == 0.
    parts.append(loadByte('b', src1Loc, 0))
    parts.append(f'''
        ldi pl, lo({endLabel})
        ldi ph, hi({endLabel})
        add a, 0
        jz ; a == 0
    {loopLabel}:
        {op} b
        dec a
        ldi pl, lo({loopLabel})
        ldi ph, hi({loopLabel})
        jnz
    {endLabel}:
        mov a, b
    ''')
    return (Value.register(resultLoc.getPosition(), resultLoc.getType()),
            ''.join(parts))
def genSHLVarByConst(resultLoc, srcLoc, n):
    """Emit code shifting a variable left by the compile-time constant ``n``.

    Args:
        resultLoc: destination location.  For a one-byte source it may be one
            or two bytes wide (the latter widens the value while shifting).
        srcLoc: the value being shifted.
        n: shift amount; ``0`` returns ``srcLoc`` unchanged.

    Returns:
        A ``(value, asm)`` pair.  One-byte results are returned as a register
        value; everything else is stored and ``resultLoc`` is returned.
    """
    rs = resultLoc.getSource()
    s = srcLoc.getSource()
    size = srcLoc.getType().getSize()
    result = f'; {resultLoc} := shl {srcLoc}, {n}\n'
    if n == 0:
        return srcLoc, result
    if size == 1:
        assert (resultLoc.getType().getSize() <= 2)
        expandToWord = resultLoc.getType().getSize() == 2
        signed = srcLoc.getType().getSign()
        if n >= 8:
            if expandToWord:
                # The low result byte is always zero; only n - 8 bits of
                # shifting remain for the high byte.
                n -= 8
                if n >= 8:
                    # Everything is shifted out: store a zero word.
                    result += f'''
                        mov a, 0
                        ldi pl, lo({rs})
                        ldi ph, hi({rs})
                        st a
                        inc pl
                    '''
                    if not resultLoc.isAligned():
                        # a is 0 here, so this only adds the carry of inc pl.
                        result += 'adc ph, a\n'
                    result += 'st a\n'
                else:
                    result += loadByte('b', srcLoc, 0)
                    for i in range(n):
                        result += 'shl b\n'
                    result += f'''
                        mov a, 0
                        ldi pl, lo({rs})
                        ldi ph, hi({rs})
                        st a
                        inc pl
                    '''
                    if not resultLoc.isAligned():
                        result += 'adc ph, a\n'
                    result += 'st b\n'
            else:
                # Byte shifted by 8 or more: the result is zero.
                result += f'''
                    mov a, 0
                '''
                return Value.register(srcLoc.getPosition(), resultLoc.getType()), result
        else:
            result += loadByte('b', srcLoc, 0)
            if expandToWord:
                if not signed:
                    result += 'mov a, 0\n'
            for i in range(n):
                if expandToWord:
                    # b holds the low byte, a accumulates the high byte.
                    if i > 0:
                        result += 'shl a\n'
                    if i == 0 and signed:
                        # First step of a signed widen: `exp a` initialises
                        # the high byte from the carry produced by `shl b`.
                        result += '''
                            shl b
                            exp a
                        '''
                    else:
                        result += '''
                            shl b
                            adc a, 0
                        '''
                else:
                    result += 'shl b\n'
            if expandToWord:
                result += f'''
                    ldi pl, lo({rs})
                    ldi ph, hi({rs})
                    st b
                '''
                if resultLoc.isAligned():
                    result += 'inc pl\n'
                else:
                    result += f'''
                        ldi pl, lo({rs} + 1)
                        ldi ph, hi({rs} + 1)
                    '''
                result += 'st a\n'
            else:
                result += 'mov a, b\n'
                return Value.register(srcLoc.getPosition(), resultLoc.getType()), result
    elif size == 2:
        assert (resultLoc.getType().getSize() == 2)
        assert (not s.isRegister())
        # NOTE(review): the stores below step between the two result bytes
        # with a bare `inc pl` / `dec pl`, i.e. they appear to assume the
        # result word does not straddle a page boundary - confirm.
        if n >= 16:
            result += '''
                mov a, 0
                ldi pl, lo({0})
                ldi ph, hi({0})
                st a
                inc pl
                st a
            '''.format(rs)
        elif n >= 8:
            # The low source byte becomes the high result byte; low is zero.
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                ld a
            '''.format(s)
            for i in range(n - 8):
                result += 'shl a\n'
            result += '''
                ldi pl, lo({0} + 1)
                ldi ph, hi({0} + 1)
                st a
                mov a, 0
                dec pl
                st a
            '''.format(rs)
        else:
            # 1..7
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                ld b
                ldi pl, lo({0} + 1)
                ldi ph, hi({0} + 1)
                ld a
            '''.format(s)
            # TODO optimize aligned
            for i in range(n):
                result += '''
                    shl a
                    shl b
                    adc a, 0
                '''
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                st b
                inc pl
                st a
            '''.format(rs)
    else:
        # Wider than a word: delegate, using the in-place variant when the
        # source and destination are the same location.
        if resultLoc != srcLoc:
            result += _genSHLVarByConstLarge(resultLoc, srcLoc, n)
        else:
            result += _genSHLVarByConstLargeInplace(resultLoc, n)
    return resultLoc, result
def genMul(resultLoc, src1Loc, src2Loc):
    """Emit code for multiplication of two equally-typed integers.

    A constant operand is handled by ``_genMulVC``.  Two variables are
    multiplied by copying them into ``__cc_r_a`` / ``__cc_r_b``, calling the
    width-specific runtime routine and fetching the product from
    ``__cc_r_r``.

    Args:
        resultLoc: destination location; it is retyped to the operand type.
        src1Loc: left operand.
        src2Loc: right operand.

    Returns:
        A ``(value, asm)`` pair; byte products are returned as a register
        value, wider products in ``resultLoc``.

    Raises:
        SemanticError: if the operand types differ.
        NatrixNotImplementedError: if the operand size is not 1, 2 or 4 bytes.
        NotImplementedError: if both operands are constants, or a constant
            operand is not a plain number.
    """
    if src1Loc.getType() != src2Loc.getType():
        raise SemanticError(src1Loc.getPosition() - src2Loc.getPosition(),
                            "multiplication types mismatch")
    t = src1Loc.getType()
    resultLoc = resultLoc.withType(t)
    if t.getSize() not in {1, 2, 4}:
        raise NatrixNotImplementedError(
            src1Loc.getPosition(),
            f"multiplication of {t.getSize() * 8}-bit integers is not implemented"
        )
    l1 = src1Loc.getIndirLevel()
    l2 = src2Loc.getIndirLevel()
    if l1 == 0 and l2 == 0:
        raise NotImplementedError("doing shit with pointers?")
    elif l1 == 0:
        s = src1Loc.getSource()
        if s.isNumber():
            return _genMulVC(resultLoc, src2Loc, int(s))
        else:
            raise NotImplementedError("doing shit with pointers?")
    elif l2 == 0:
        s = src2Loc.getSource()
        if s.isNumber():
            return _genMulVC(resultLoc, src1Loc, int(s))
        else:
            raise NotImplementedError("doing shit with pointers?")

    # variable * variable
    result = '; {} = {} * {}\n'.format(resultLoc, src1Loc, src2Loc)
    if t.getSize() == 2:
        result += copyW(src1Loc.getSource(), "__cc_r_a", src1Loc.isAligned(), True)
        result += copyW(src2Loc.getSource(), "__cc_r_b", src2Loc.isAligned(), True)
        result += call("__cc_mul_word")
        result += copyW("__cc_r_r", resultLoc.getSource(), True, resultLoc.isAligned())
        return resultLoc, result
    elif t.getSize() == 1:
        result += loadByte('b', src1Loc, 0)
        result += loadByte('a', src2Loc, 0)
        # Only pl is reloaded for __cc_r_b; this appears to assume it shares
        # its high address byte with __cc_r_a - confirm.
        result += '''
            ldi pl, lo(__cc_r_a)
            ldi ph, hi(__cc_r_a)
            st b
            ldi pl, lo(__cc_r_b)
            st a
        '''
        result += call("__cc_mul_byte")
        result += '''
            ldi pl, lo(__cc_r_r)
            ldi ph, hi(__cc_r_r)
            ld a
        '''
        return Value.register(
            src1Loc.getPosition() - src2Loc.getPosition(), t), result
    elif t.getSize() == 4:
        # Copy each dword as two words (offsets 0 and 2).
        for wordOffset in (0, 2):
            result += copyW(src1Loc.getSource(), "__cc_r_a",
                            src1Loc.isAligned(), True, wordOffset)
        for wordOffset in (0, 2):
            # The alignment of the *second* operand applies here; the
            # original passed src1Loc.isAligned() by mistake.
            result += copyW(src2Loc.getSource(), "__cc_r_b",
                            src2Loc.isAligned(), True, wordOffset)
        result += call("__cc_mul_dword")
        for wordOffset in (0, 2):
            result += copyW("__cc_r_r", resultLoc.getSource(), True,
                            resultLoc.isAligned(), wordOffset)
        return resultLoc, result
def _genCmpSignedLong(resultLoc, src1Loc, src2Loc, op, labelProvider):
    """Emit code for a signed ordering comparison of multi-byte integers.

    The most significant bytes are compared first, taking the overflow flag
    into account.  If they are equal, the remaining bytes are compared from
    high to low until a difference is found.

    Args:
        resultLoc: destination location (not used by the generated code).
        src1Loc: left operand; must not be a register.
        src2Loc: right operand; must not be a register.
        op: one of ``'lt'``, ``'le'``, ``'gt'``, ``'ge'``.
        labelProvider: object whose ``allocLabel`` hands out unique labels.

    Returns:
        ``(register value of BoolType, asm)``.
    """
    size = src1Loc.getType().getSize()
    assert (not src1Loc.getSource().isRegister())
    assert (not src2Loc.getSource().isRegister())
    result = '; compare signed {} and {} ({})\n'.format(src1Loc, src2Loc, op)
    result += loadByte('a', src1Loc, size - 1)
    result += loadByte('b', src2Loc, size - 1)
    result += 'sub a, b\n'
    labelEnd = labelProvider.allocLabel("cmp_end")
    labelCmpLo = labelProvider.allocLabel("cmp_lo")
    result += f'''
        ldi pl, lo({labelCmpLo})
        ldi ph, hi({labelCmpLo})
        jz
    '''
    # hi1 != hi2
    labelO = labelProvider.allocLabel("cmp_no")
    result += f'''
        ldi pl, lo({labelO})
        ldi ph, hi({labelO})
        jno
    '''
    # O is set
    result += 'shl a\n'  # S -> C
    if op == 'lt' or op == 'le':
        # return !C
        result += 'ldi a, 1\n'
        # The trailing newline was missing here, leaving this instruction
        # dependent on the next snippet starting with a line break.
        result += 'sbb a, 0\n'
    else:
        # return C
        result += 'mov a, 0\n'
        result += 'adc a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jmp
    '''
    result += f'''
    {labelO}:
        shl a
    '''
    # O is clear
    if op == 'gt' or op == 'ge':
        # return !C
        result += 'ldi a, 1\n'
        result += 'sbb a, 0\n'
    else:
        # return C
        result += 'mov a, 0\n'
        result += 'adc a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jmp
    '''
    # hi1 == hi2: compare the remaining bytes, most significant first
    result += f'{labelCmpLo}:\n'
    for offset in reversed(range(size - 1)):
        labelNext = labelProvider.allocLabel(f"cmp_{offset}")
        result += loadByte('a', src1Loc, offset)
        result += loadByte('b', src2Loc, offset)
        result += 'sub a, b\n'
        result += f'''
            ldi pl, lo({labelNext})
            ldi ph, hi({labelNext})
            jz
        '''
        if op[0] == 'l':
            # less, return C
            result += 'mov a, 0\n'
            result += 'adc a, 0\n'
        else:
            # greater, return !C
            result += 'ldi a, 1\n'
            result += 'sbb a, 0\n'
        result += f'''
            ldi pl, lo({labelEnd})
            ldi ph, hi({labelEnd})
            jmp
        '''
        result += f'{labelNext}:\n'
    if op == 'le' or op == 'ge':
        # if equal, return 1
        result += 'ldi a, 1\n'
    else:
        # if equal, return 0
        result += 'mov a, 0\n'
    result += f'{labelEnd}:\n'
    return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                          BoolType()), result
def _genIntBinary(resultLoc, src1Loc, src2Loc, opLo, opHi, pyPattern, constLambda, carryIrrelevant):
    """Emit code for a byte-wise binary integer operation.

    Args:
        resultLoc: destination location (indirection level 1); an unknown
            type is resolved to the type of ``src1Loc``.
        src1Loc: left operand, constant (level 0) or variable (level 1).
        src2Loc: right operand, constant or variable.
        opLo: mnemonic applied to the lowest byte.
        opHi: mnemonic applied to every following byte.
        pyPattern: not used in this function.
        constLambda: not used in this function.
        carryIrrelevant: true when flags may be clobbered between bytes
            (see the comment in the wide path).

    Returns:
        A ``(value, asm)`` pair: a register value for one-byte operands,
        otherwise ``resultLoc`` with the result stored to memory.

    Raises:
        SemanticError: on a result/source or source/source type mismatch.
        RuntimeError: if both operands are constants.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(src1Loc.getType())
    if resultLoc.getType() != src1Loc.getType():
        raise SemanticError(resultLoc.getPosition(),
            "Incompatible result and source types: {} and {}".format(resultLoc.getType(), src1Loc.getType()))
    if src1Loc.getType() != src2Loc.getType():
        raise SemanticError(resultLoc.getPosition(),
            "Incompatible source types: {} and {}".format(src1Loc.getType(), src2Loc.getType()))
    assert(resultLoc.getIndirLevel() == 1)
    l1 = src1Loc.getIndirLevel()
    l2 = src2Loc.getIndirLevel()
    assert(l1 == 0 or l1 == 1)
    assert(l2 == 0 or l2 == 1)
    if l1 == 0 and l2 == 0:
        raise RuntimeError("Case unhandled by ConstTransformer")
    t = resultLoc.getType()
    rs = resultLoc.getSource()
    s1 = src1Loc.getSource()
    s2 = src2Loc.getSource()
    result = "; {} = {} {}, {}\n".format(resultLoc, opLo, src1Loc, src2Loc)
    if t.getSize() <= 2:
        isWord = t.getSize() == 2
        if not isWord:
            # Byte operation: compute in b, hand the result back in a.
            assert(not s1.isRegister() or not s2.isRegister())
            result += loadByte('b', src1Loc, 0)
            result += loadByte('a', src2Loc, 0)
            result += f'''
                {opLo} b, a
                mov a, b
            '''
            return Value.register(resultLoc.getPosition(), t), result
        # now it's a word
        if l2 == 0:
            # var + const
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                ld b
            '''.format(src1Loc.getSource())
            # P is advanced to the high byte *before* opLo runs, so the
            # increment cannot disturb the flags opHi consumes.
            result += incP(src1Loc.isAligned())
            c = src2Loc.getSource()
            result += loadByte('a', src2Loc, 0)
            result += f'''
                {opLo} b, a
                ld a
            '''
            result += loadByte('pl', src2Loc, 1)
            result += f'{opHi} a, pl\n'
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                st b
            '''.format(rs)
            if resultLoc.isAligned():
                result += '''
                    inc pl
                    st a
                '''
            else:
                result += f'''
                    ldi pl, lo({rs} + 1)
                    ldi ph, hi({rs} + 1)
                    st a
                '''
        elif l1 == 0:
            # const + var
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                ld a
            '''.format(src2Loc.getSource())
            if src2Loc.isAligned():
                result += 'inc pl\n'
            else:
                result += '''
                    ldi pl, lo({0} + 1)
                    ldi ph, hi({0} + 1)
                '''.format(src2Loc.getSource())
            result += loadByte('b', src1Loc, 0)
            result += f'''
                {opLo} b, a
                ld pl
            '''
            result += loadByte('a', src1Loc, 1)
            # The trailing .format(rs) is a no-op on this f-string.
            result += f'''
                {opHi} a, pl
                ldi pl, lo({rs})
                ldi ph, hi({rs})
                st b
            '''.format(rs)
            if resultLoc.isAligned():
                result += '''
                    inc pl
                    st a
                '''
            else:
                result += f'''
                    ldi pl, lo({rs} + 1)
                    ldi ph, hi({rs} + 1)
                    st a
                '''
        else:
            # var + var
            s1 = src1Loc.getSource()
            s2 = src2Loc.getSource()
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                ld b
                ldi pl, lo({1})
                ldi ph, hi({1})
                ld a
                {2} b, a
            '''.format(s1, s2, opLo)
            # isWord is always true here: the byte case returned above.
            if isWord:
                result += '''
                    ldi pl, lo({0} + 1)
                    ldi ph, hi({0} + 1)
                    ld a
                    ldi pl, lo({1} + 1)
                    ldi ph, hi({1} + 1)
                    ld pl
                    {2} a, pl
                '''.format(s1, s2, opHi)
            result += '''
                ldi pl, lo({0})
                ldi ph, hi({0})
                st b
            '''.format(rs)
            if isWord:
                if resultLoc.isAligned():
                    result += '''
                        inc pl
                        st a
                    '''
                else:
                    result += '''
                        ldi pl, lo({0} + 1)
                        ldi ph, hi({0} + 1)
                        st a
                    '''.format(rs)
    else:
        # size > 2
        # This produces less size-efficient code than above, that's why it's not a general case.
        # Each iteration handles two bytes: the even one in a, the odd one in b.
        for offset in range(0, t.getSize(), 2):
            rest = t.getSize() - offset
            result += loadByte('a', src1Loc, offset)
            if rest > 1:
                if l1 == 0:
                    result += loadByte('b', src1Loc, offset + 1)
                else:
                    if (offset == 0 or carryIrrelevant) and src1Loc.isAligned():
                        # can trash flags if it's the first pair of bytes or if op is like and, or, xor
                        result += '''
                            inc pl
                            ld b
                        '''
                    else:
                        # must preserve flags
                        if src1Loc.isAligned():
                            result += f'''
                                ldi pl, lo({s1} + {offset + 1})
                                ld b
                            '''
                        else:
                            result += loadByte('b', src1Loc, offset + 1)
            result += loadByte('pl', src2Loc, offset)
            if offset == 0:
                result += f'{opLo} a, pl\n'
            else:
                result += f'{opHi} a, pl\n'
            result += f'''
                ldi pl, lo({rs} + {offset})
                ldi ph, hi({rs} + {offset})
                st a
            '''
            if rest > 1:
                result += loadByte('a', src2Loc, offset + 1)
                result += f'''
                    {opHi} b, a
                    ldi pl, lo({rs} + {offset + 1})
                    ldi ph, hi({rs} + {offset + 1})
                    st b
                '''
    return resultLoc, result
def _genCmpSignedByte(resultLoc, src1Loc, src2Loc, op, labelProvider):
    """Emit code for a signed ordering comparison of two bytes.

    The operands are subtracted; the answer for equality is preloaded into
    ``a``, and otherwise derived from the sign of the difference, inverted
    when the overflow flag is set.

    Args:
        resultLoc: destination location (not used by the generated code).
        src1Loc: left operand.
        src2Loc: right operand.
        op: one of ``'lt'``, ``'le'``, ``'gt'``, ``'ge'``.
        labelProvider: object whose ``allocLabel`` hands out unique labels.

    Returns:
        ``(register value of BoolType, asm)``.
    """
    result = '; compare signed bytes {} and {} ({})\n'.format(
        src1Loc, src2Loc, op)
    result += loadByte('b', src1Loc, 0)
    result += loadByte('a', src2Loc, 0)
    result += 'sub b, a\n'
    labelEnd = labelProvider.allocLabel("cmp_end")
    if op == 'lt' or op == 'gt':
        # if equal, return 0
        result += 'mov a, 0\n'
    else:
        # if equal, return 1
        result += 'ldi a, 1\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jz
    '''
    labelO = labelProvider.allocLabel("cmp_no")
    result += f'''
        ldi pl, lo({labelO})
        ldi ph, hi({labelO})
        jno
    '''
    # O is set
    result += 'shl b\n'  # S -> C
    if op == 'lt' or op == 'le':
        # return !C
        if op == 'lt':
            result += 'ldi a, 1\n'
        # else a is already 1
        # The trailing newline was missing here, leaving this instruction
        # dependent on the next snippet starting with a line break.
        result += 'sbb a, 0\n'
    else:
        # return C
        if op == 'ge':
            result += 'mov a, 0\n'
        # else a is already 0
        result += 'adc a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jmp
    '''
    result += f'''
    {labelO}:
        shl b
    '''
    # O is clear
    if op == 'gt' or op == 'ge':
        # return !C
        if op == 'gt':
            result += 'ldi a, 1\n'
        # else a is already 1
        result += 'sbb a, 0\n'
    else:
        # return C
        if op == 'le':
            result += 'mov a, 0\n'
        # else a is already 0
        result += 'adc a, 0\n'
    result += f'{labelEnd}:\n'
    return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                          BoolType()), result
def _genCmpUnsigned(resultLoc, src1Loc, src2Loc, op, labelProvider):
    """Emit code for an unsigned ordering comparison.

    Args:
        resultLoc: destination location (not used by the generated code).
        src1Loc: left operand.
        src2Loc: right operand.
        op: one of ``'lt'``, ``'le'``, ``'gt'``, ``'ge'``.
        labelProvider: object whose ``allocLabel`` hands out unique labels
            (only needed for wide ``le`` / ``ge``).

    Returns:
        A ``(value, asm)`` pair: a folded ``BoolType`` constant when both
        operands are constants, otherwise a register value of ``BoolType``.
    """
    s1 = src1Loc.getSource()
    s2 = src2Loc.getSource()
    l1 = src1Loc.getIndirLevel()
    l2 = src2Loc.getIndirLevel()
    t = src1Loc.getType()
    result = '; compare unsigned {} and {} ({})\n'.format(src1Loc, src2Loc, op)
    if l1 == 0 and l2 == 0:
        # const and const
        pos = src1Loc.getPosition() - src2Loc.getPosition()
        if s1.isNumber() and s2.isNumber():
            pyop = {
                "gt": operator.gt,
                "lt": operator.lt,
                "ge": operator.ge,
                "le": operator.le
            }[op]
            return Value(pos, BoolType(), 0, int(pyop(int(s1), int(s2))), True), result
        else:
            pyop = {"gt": ">", "lt": "<", "ge": ">=", "le": "<="}[op]
            return Value(pos, BoolType(), 0,
                         "int(({}) {} ({}))".format(s1, pyop, s2), True), result

    if t.getSize() <= 2:
        if op == 'lt' or op == 'ge':
            # Append rather than assign: plain `=` discarded the comment
            # header that every other path keeps.
            result += _genCmpSubFlags(src1Loc, src2Loc)
        else:
            result += _genCmpSub(src1Loc, src2Loc, op)
        # C = carry flag
        # Z = (b | pl) == 0
        if op == 'lt':
            # 1 if C
            result += '''
                mov a, 0
                adc a, 0
            '''
        elif op == 'ge':
            # 1 if !C
            result += '''
                ldi a, 1
                sbb a, 0
            '''
        elif op == 'le':
            # 1 if C or !(b | pl)
            result += '''
                exp ph ; ph = C (0xff or 0x00)
                mov a, pl
                or a, b
                dec a ; C = Z
                exp a ; a = Z (0xff or 0x00)
                or a, ph ; a = C | Z (0xff or 0x00)
                ldi b, 1
                and a, b
            '''
        elif op == 'gt':
            # 1 if !(C or Z)
            result += '''
                exp ph ; ph = C (0xff or 0x00)
                mov a, pl
                or a, b
                dec a ; C = Z
                exp a ; a = Z (0xff or 0x00)
                or a, ph ; a = C | Z (0xff or 0x00)
                not a
                ldi b, 1
                and a, b
            '''
        return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                              BoolType()), result

    # Operands wider than a word.  gt/ge are reduced to lt/le by swapping
    # the operands.
    if op[0] == 'g':
        src1Loc, src2Loc = src2Loc, src1Loc
    result += _genCmpSubFlags(src1Loc, src2Loc)
    # C if s1 < s2
    if op[1] == 't':
        # strict comparison: the carry is the answer
        result += '''
            mov a, 0
            adc a, 0
        '''
        return Value.register(
            src1Loc.getPosition() - src2Loc.getPosition(), BoolType()), result

    # non-strict comparison: 1 if carry, otherwise 1 only if all bytes match
    labelEnd = labelProvider.allocLabel("cmp_end")
    result += f'''
        ldi b, 1
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jc
        dec b
    '''
    for offset in range(t.getSize()):
        result += loadByte('a', src1Loc, offset)
        result += loadByte('pl', src2Loc, offset)
        result += f'''
            sub a, pl
            ldi pl, lo({labelEnd})
            ldi ph, hi({labelEnd})
            jnz
        '''
    # invert b
    result += f'''
        inc b
    {labelEnd}:
        mov a, b
    '''
    return Value.register(
        src1Loc.getPosition() - src2Loc.getPosition(), BoolType()), result
def _genSARVarByConst(resultLoc, srcLoc, n):
    """Emit code for an arithmetic right shift by the constant ``n``.

    Args:
        resultLoc: destination location.
        srcLoc: value being shifted; for sizes above one byte it must not be
            a register.
        n: shift amount; ``0`` returns ``srcLoc`` unchanged.

    Returns:
        A ``(value, asm)`` pair: a register value for one-byte operands,
        ``resultLoc`` for two-byte operands, and whatever
        ``_genSARVarByConstLarge`` returns for wider ones.
    """
    dst = resultLoc.getSource()
    src = srcLoc.getSource()
    width = srcLoc.getType().getSize()
    code = f'; sar {srcLoc}, {n}\n'
    if n == 0:
        return srcLoc, code

    if width == 1:
        code += loadByte('a', srcLoc, 0)
        if n >= 8:
            # Only the sign survives: a becomes 0x00 or 0xff.
            code += '''
                shl a
                mov a, 0
                sbb a, 0
            '''
        else:
            code += 'sar a\n' * n
        return Value.register(srcLoc.getPosition(), resultLoc.getType()), code

    if width != 2:
        assert not src.isRegister()
        return _genSARVarByConstLarge(resultLoc, srcLoc, n)

    assert not src.isRegister()
    if n >= 16:
        # Both result bytes are filled with the sign.
        code += f'''
            ldi pl, lo({src} + 1)
            ldi ph, hi({src} + 1)
            ld a
            shl a
            mov a, 0
            sbb a, 0
            ldi pl, lo({dst})
            ldi ph, hi({dst})
            st a
            inc pl
            st a
        '''
    elif n >= 8:
        # High byte moves to the low byte; the high byte becomes the sign.
        code += f'''
            ldi pl, lo({src} + 1)
            ldi ph, hi({src} + 1)
            ld a
            ld b
            shl a
            mov a, 0
            sbb a, 0
        '''
        code += 'sar b\n' * (n - 8)
        code += f'''
            ldi pl, lo({dst})
            ldi ph, hi({dst})
            st b
            inc pl
            st a
        '''
    else:
        # 1..7
        code += f'''
            ldi pl, lo({src})
            ldi ph, hi({src})
            ld b
            ldi pl, lo({src} + 1)
            ldi ph, hi({src} + 1)
            ld a
            mov pl, a
        '''
        # TODO optimize aligned
        code += '''
            sar a
            shr b
        ''' * n
        # The bits leaving the high byte are rebuilt in pl and merged into b.
        code += 'shl pl\n' * (8 - n)
        code += f'''
            mov ph, a
            mov a, pl
            or b, a
            mov a, ph
            ldi pl, lo({dst})
            ldi ph, hi({dst})
            st b
            inc pl
            st a
        '''
    return resultLoc, code
def _genCmpSignedWord(resultLoc, src1Loc, src2Loc, op, labelProvider):
    """Emit code for a signed ordering comparison of two 16-bit words.

    The high bytes are compared first, taking the overflow flag into
    account.  Only when they are equal are the low bytes compared, as
    unsigned values.

    Args:
        resultLoc: destination location (not used by the generated code).
        src1Loc: left operand, constant or variable; must not be a register.
        src2Loc: right operand, constant or variable; must not be a register.
        op: one of ``'lt'``, ``'le'``, ``'gt'``, ``'ge'``.
        labelProvider: object whose ``allocLabel`` hands out unique labels.

    Returns:
        ``(register value of BoolType, asm)``.
    """
    s1 = src1Loc.getSource()
    s2 = src2Loc.getSource()
    l1 = src1Loc.getIndirLevel()
    l2 = src2Loc.getIndirLevel()
    assert (not s1.isRegister())
    assert (not s2.isRegister())
    result = '; compare signed words {} and {} ({})\n'.format(
        src1Loc, src2Loc, op)
    # High bytes: b = hi(src1), a = hi(src2)
    if l1 == 0:
        result += f'ldi b, hi({s1})\n'
    else:
        result += f'''
            ldi pl, lo({s1} + 1)
            ldi ph, hi({s1} + 1)
            ld b
        '''
    if l2 == 0:
        result += f'ldi a, hi({s2})\n'
    else:
        result += f'''
            ldi pl, lo({s2} + 1)
            ldi ph, hi({s2} + 1)
            ld a
        '''
    result += 'sub b, a\n'
    labelEnd = labelProvider.allocLabel("cmp_end")
    labelCmpLo = labelProvider.allocLabel("cmp_lo")
    result += f'''
        ldi pl, lo({labelCmpLo})
        ldi ph, hi({labelCmpLo})
        jz
    '''
    # hi1 != hi2
    labelO = labelProvider.allocLabel("cmp_no")
    result += f'''
        ldi pl, lo({labelO})
        ldi ph, hi({labelO})
        jno
    '''
    # O is set
    result += 'shl b\n'  # S -> C
    if op == 'lt' or op == 'le':
        # return !C
        result += 'ldi a, 1\n'
        # The trailing newline was missing here, leaving this instruction
        # dependent on the next snippet starting with a line break.
        result += 'sbb a, 0\n'
    else:
        # return C
        result += 'mov a, 0\n'
        result += 'adc a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jmp
    '''
    result += f'''
    {labelO}:
        shl b
    '''
    # O is clear
    if op == 'gt' or op == 'ge':
        # return !C
        result += 'ldi a, 1\n'
        result += 'sbb a, 0\n'
    else:
        # return C
        result += 'mov a, 0\n'
        result += 'adc a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jmp
    '''
    # hi1 == hi2 (= a)
    result += f'{labelCmpLo}:\n'
    # if sign, compare low parts as unsigned
    if l1 == 0:
        result += f'ldi b, lo({s1})\n'
    else:
        result += f'''
            ldi pl, lo({s1})
            ldi ph, hi({s1})
            ld b
        '''
    if l2 == 0:
        result += f'ldi a, lo({s2})\n'
    else:
        result += f'''
            ldi pl, lo({s2})
            ldi ph, hi({s2})
            ld a
        '''
    result += 'sub b, a\n'
    if op == 'le' or op == 'ge':
        # if equal, return 1
        result += 'ldi a, 1\n'
    else:
        # if equal, return 0
        result += 'mov a, 0\n'
    result += f'''
        ldi pl, lo({labelEnd})
        ldi ph, hi({labelEnd})
        jz
    '''
    if op[0] == 'l':
        # less, return C
        if op == 'le':
            result += 'mov a, 0\n'
        # else a is already 0
        result += 'adc a, 0\n'
    else:
        # greater, return !C
        if op == 'gt':
            result += 'ldi a, 1\n'
        # else a is already 1
        result += 'sbb a, 0\n'
    result += f'{labelEnd}:\n'
    return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                          BoolType()), result
def _genBoolBinary(resultLoc, src1Loc, src2Loc, op, pyPattern, constLambda):
    """Emit code for a logical (boolean) binary operation.

    Args:
        resultLoc: destination location (indirection level 1); an unknown
            type is resolved to the type of ``src1Loc``.
        src1Loc: left operand, must be of ``BoolType``.
        src2Loc: right operand, must be of ``BoolType``.
        op: instruction mnemonic; ``'or'`` and ``'and'`` are the only values
            accepted when one operand is a numeric constant.
        pyPattern: format string used to fold two symbolic constants.
        constLambda: callable used to fold two numeric constants.

    Returns:
        A ``(value, asm)`` pair: a folded constant, one of the operands
        unchanged (when a numeric constant makes the operation trivial), or a
        register value of ``BoolType``.

    Raises:
        SemanticError: on a type mismatch or non-boolean operands.
        RuntimeError: for an unsupported ``op`` with a numeric constant.
    """
    if resultLoc.getType().isUnknown():
        resultLoc = resultLoc.removeUnknown(src1Loc.getType())
    if resultLoc.getType() != src1Loc.getType():
        raise SemanticError(
            resultLoc.getPosition(),
            "Incompatible result and source types: {} and {}".format(
                resultLoc.getType(), src1Loc.getType()))
    if src1Loc.getType() != src2Loc.getType():
        raise SemanticError(
            src1Loc.getPosition() - src2Loc.getPosition(),
            "Incompatible source types: {} and {}".format(
                src1Loc.getType(), src2Loc.getType()))
    if src1Loc.getType() != BoolType():
        raise SemanticError(src1Loc.getPosition(), "Bool type (u8) expected")
    assert resultLoc.getIndirLevel() == 1

    lhs = src1Loc.getSource()
    rhs = src2Loc.getSource()
    lhsLevel = src1Loc.getIndirLevel()
    rhsLevel = src2Loc.getIndirLevel()
    code = f'; {resultLoc} = bool {op} {src1Loc}, {src2Loc}\n'
    span = src1Loc.getPosition() - src2Loc.getPosition()

    if lhsLevel == 0 and rhsLevel == 0:
        # const and const
        if lhs.isNumber() and rhs.isNumber():
            folded = int(constLambda(bool(lhs), bool(rhs)))
        else:
            folded = pyPattern.format(lhs, rhs)
        return Value(span, BoolType(), 0, folded, True), code

    if lhsLevel == 0 or rhsLevel == 0:
        # var and const: make sure the constant is the second operand
        if lhsLevel == 0:
            lhs, rhs = rhs, lhs
            src1Loc, src2Loc = src2Loc, src1Loc
        if rhs.isNumber():
            constIsTrue = bool(rhs)
            if op == 'or':
                if constIsTrue:
                    return Value(span, BoolType(), 0, 1, True), ""
                return src1Loc, ""
            if op == 'and':
                if constIsTrue:
                    return src1Loc, code
                return Value(span, BoolType(), 0, 0, True), ""
            raise RuntimeError("Unhandled binary boolean op: {}".format(op))
        # Symbolic constant: normalise the variable, then combine it with
        # the constant expression.
        code += loadByte('a', src1Loc, 0)
        code += f'''
            dec a
            ldi a, 1
            sbb a, 0
            ldi b, int(bool({rhs}))
            {op} a, b
        '''
    else:
        # var and var
        assert not lhs.isRegister() or not rhs.isRegister()
        if rhs.isRegister():
            # Handle the register operand first.
            lhs, rhs = rhs, lhs
            src1Loc, src2Loc = src2Loc, src1Loc
        code += loadByte('a', src1Loc, 0)
        code += '''
            dec a
            ldi a, 1
            sbb a, 0
            mov b, a
        '''
        code += loadByte('a', src2Loc, 0)
        code += f'''
            dec a
            ldi a, 1
            sbb a, 0
            {op} a, b
        '''
    return Value.register(src1Loc.getPosition() - src2Loc.getPosition(),
                          BoolType()), code