Esempio n. 1
0
File: test.py Progetto: lsrcz/SyGuS
def my_test(cmd, outputfile, testname, timeout=300):
    """Run one solver command on a benchmark and log the verdict.

    A single line is appended to ``outputfile``: the benchmark name followed
    by one of ``timeout after N``, ``Invalid format(t)``, ``Passed(t)`` or
    ``Wrong Answer(t)``, where ``t`` is the running time reported by
    ``run_command``.

    :param cmd: command handed to ``run_command``.
    :param outputfile: writable text stream receiving the verdict line.
    :param testname: path of the benchmark file; also used as the label.
    :param timeout: time limit forwarded to ``run_command``.
    """
    outputfile.write('\t%s:' % (testname))
    print(cmd)
    try:
        result, rtime = run_command(cmd, timeout)
    except TimeoutError:
        outputfile.write('timeout after %i \n' % (timeout))
        return

    print(result)
    # Close the benchmark file deterministically instead of leaking the handle.
    with open(testname) as benchmarkFile:
        bm = stripComments(benchmarkFile)
    # Parse string to python list
    bmExpr = sexp.sexp.parseString(bm, parseAll=True).asList()[0]
    checker = translator.ReadQuery(bmExpr)
    try:
        checkresult = checker.check(str(result, encoding='utf-8'))
    except Exception:
        # Deliberately broad: any failure while checking the solver output is
        # reported as a malformed answer rather than aborting the test run.
        outputfile.write('Invalid format(%f)\n' % (rtime))
        return

    # The checker yields None when no counterexample exists.
    if checkresult is None:
        outputfile.write('Passed(%f)\n' % (rtime))
    else:
        outputfile.write('Wrong Answer(%f)\n' % (rtime))
Esempio n. 2
0
File: task.py Progetto: lsrcz/SyGuS
    def __init__(self, filename, instrument=True):
        '''
        Load a SyGuS benchmark file and build the solving context for it.

        :param filename: path of the benchmark file to read.
        :param instrument: when true, a second ``TaskData`` (``self.ins``)
            is created and filled in by ``_initialize_instrument``.

        There should be only one instance active: the constructor stores its
        productions on the class-level attribute ``Expr.productions``.
        '''
        # Use a context manager so the benchmark file handle is not leaked.
        with open(filename) as benchmarkFile:
            bm = self._stripComments(benchmarkFile)
        self.ori = TaskData()
        self.ins = None
        self.ori.bmExpr = sexp.sexp.parseString(bm, parseAll=True).asList()[0]
        self.ori.bmLogic = None
        self.ori.bmSyn = None
        self.ori.bmDeclvar = []
        self.ori.bmConstraint = []
        self.ori.bmDefineFun = []
        # Sort the top-level commands of the benchmark by their head symbol.
        for expr in self.ori.bmExpr:
            if len(expr) == 0:
                continue
            elif expr[0] == 'set-logic':
                self.ori.bmLogic = expr
            elif expr[0] == 'synth-fun':
                self.ori.bmSyn = expr
            elif expr[0] == 'declare-var':
                self.ori.bmDeclvar.append(expr)
            elif expr[0] == 'constraint':
                self.ori.bmConstraint.append(expr)
            elif expr[0] == 'define-fun':
                self.ori.bmDefineFun.append(expr)
        # TODO: defining a function
        assert len(self.ori.bmDefineFun) == 0

        self._initialize_production()
        # Class-level state: shared by every user of Expr.
        Expr.productions = self.productions
        self.ori.z3checker = translator.ReadQuery(self.ori.bmExpr)
        self.ori.inputlist, self.ori.inputtylist, self.ori.inputmap = _get_inputlists(
            self.ori.bmDeclvar)
        self.ori.constraintlist = _get_constraintlist(self.ori.bmConstraint,
                                                      self.functionProto)
        self.ori.constraint = _constraintlist2conjunction(
            self.ori.constraintlist)
        self.ori.vartab = _getvartab(self.ori.inputlist, self.ori.inputtylist)

        self.ori.semchecker = SemChecker(self.functionProto,
                                         self.ori.constraint,
                                         self.ori.inputlist,
                                         self.ori.inputtylist)
        if instrument:
            # The instrumented task shares logic, synth-fun and define-fun
            # entries with the original; the rest is derived afterwards.
            self.ins = TaskData()
            self.ins.bmDefineFun = self.ori.bmDefineFun
            self.ins.bmLogic = self.ori.bmLogic
            self.ins.bmSyn = self.ori.bmSyn
            self._initialize_instrument()
Esempio n. 3
0
def getPossibleValue(Operators, Expr, Terminals):
    """Collect candidate result terms for the function being synthesized.

    The query in ``Expr`` is read with ``translator.ReadQuery``; calls to the
    target function inside the constraints are replaced via
    ``replaceFunctionCall``. Thirty random samples of the function arguments
    seed a ``ValueSet``. For a ``Bool`` return type the candidates are simply
    ``'true'`` and ``'false'``; otherwise values are pulled from the
    ``ValueSet`` depth by depth until ``isValueSetFull`` accepts them.

    :param Operators: operator descriptions forwarded to ``ValueSet``.
    :param Expr: parsed benchmark expression.
    :param Terminals: mapping with at least an ``'Int'`` list of terminals.
    :return: ``(resultSet, Value)`` -- the simplified candidate list and the
        ``ValueSet`` that produced it.
    """
    SynFunExpr, VarTable, FunDefMap, Constraints = translator.ReadQuery(Expr)
    funName = SynFunExpr[1]
    funArgs = SynFunExpr[2]
    returnType = SynFunExpr[3]

    functionCallDic = {}
    ReplacedConsInfo = [
        replaceFunctionCall(constraint, functionCallDic, funName, returnType,
                            VarTable)
        for constraint in Constraints
    ]
    ReplacedConsSet = getConsSet(ReplacedConsInfo)

    # Table holding one declared variable per formal argument.
    argVarTable = {}
    for argName, argType in ((a[0], a[1]) for a in funArgs):
        declareVar(argType, argName, argVarTable)

    # Random argument assignments: booleans are a coin flip, everything else
    # an integer drawn from [0, 100].
    sampleNum = 30
    Samples = []
    for _ in range(sampleNum):
        Samples.append({
            arg[0]: (random.randint(0, 1) == 0
                     if arg[1] == 'Bool' else random.randint(0, 100))
            for arg in funArgs
        })
    Value = ValueSet(argVarTable, Samples, Operators)

    if returnType == 'Bool':
        resultSet = ['true', 'false']
    else:
        resultSet = []
        argVarTable["__result"] = Int("__result")
        for terminal in Terminals['Int']:
            Value.addNewValue(terminal, 0)
        depth = 0
        # Keep widening the candidate set until it is judged complete.
        while True:
            resultSet += Value.get(depth)
            if isValueSetFull(resultSet, functionCallDic, funArgs, VarTable,
                              ReplacedConsSet):
                break
            depth += 1

    resultSet = simplifyResultSet(resultSet, [], functionCallDic, funArgs,
                                  VarTable, ReplacedConsSet)
    return resultSet, Value
Esempio n. 4
0
File: task.py Progetto: lsrcz/SyGuS
 def _initialize_instrument(self):
     """Fill ``self.ins`` with an instrumented version of the original task.

     The original constraint is turned into CNF clauses. If any clause fails
     ``_check_single_call`` the instrumented task simply aliases the original
     one (``self.ins = self.ori``). Otherwise each clause is instrumented,
     the results are merged, and the declare-var / constraint lists, the
     benchmark expression, the z3 checker, the variable table and the
     semantic checker are rebuilt on ``self.ins``.
     """
     cnf = buildCNF(self.ori.constraint, self.ori.vartab,
                    self.functionProto)
     cnfsuffix = _instrument_add_suffix(cnf, self.ori.inputmap)
     cnfconstrains = getCNFclause(cnfsuffix)
     # Fall back to the un-instrumented task when some clause is rejected.
     if not all(map(_check_single_call, cnfconstrains)):
         self.ins = self.ori
         return
     # NOTE(review): if cnfconstrains is empty, all(...) is True and the
     # zip(*...) unpacking below has nothing to unpack -- confirm that an
     # empty clause list cannot reach this point.
     constraintlist, eqmaplist = zip(*list(
         map(lambda x: _instrument_constraint(x, self.functionProto),
             cnfconstrains)))
     self.ins.constraintlist = _merge_eqclass(constraintlist, eqmaplist)
     self.ins.constraint = _assembly_clauses_with('and',
                                                  self.ins.constraintlist)
     self.ins.inputlist, self.ins.inputtylist, self.ins.inputmap = \
         _get_instrumented_varlist(self.ins.constraint, self.functionProto)
     # Rebuild the declare-var commands from the instrumented input map.
     self.ins.bmDeclvar = []
     for m in self.ins.inputmap:
         self.ins.bmDeclvar.append(['declare-var', m, self.ins.inputmap[m]])
     # Rebuild the constraint commands from the instrumented clauses.
     self.ins.bmConstraint = []
     for c in self.ins.constraintlist:
         self.ins.bmConstraint.append(['constraint', exprToList(c)])
     # NOTE(review): self.ins.bmEnd is read here but is not assigned anywhere
     # in the visible code -- presumably provided by TaskData; confirm.
     self.ins.bmExpr = \
         [self.ins.bmLogic] + \
         self.ins.bmDeclvar + \
         self.ins.bmDefineFun + \
         [self.ins.bmSyn] + \
         self.ins.bmConstraint + \
         [self.ins.bmEnd]
     self.ins.z3checker = translator.ReadQuery(self.ins.bmExpr)
     self.ins.vartab = _getvartab(self.ins.inputlist, self.ins.inputtylist)
     self.ins.semchecker = SemChecker(self.functionProto,
                                      self.ins.constraint,
                                      self.ins.inputlist,
                                      self.ins.inputtylist)
     # NOTE(review): `i` is not used in the visible part of this method.
     i = 1
Esempio n. 5
0
File: main.py Progetto: thwfhk/SyGuS
def stripComments(bmFile):
    """Return the benchmark text without ``;`` comments, wrapped in parens.

    Every line yielded by ``bmFile`` is cut at its first ``;`` (the rest of
    the line, including its line break, is discarded); the remaining pieces
    are concatenated and enclosed in one outer pair of parentheses.
    """
    kept = (line.partition(';')[0] for line in bmFile)
    return '(' + ''.join(kept) + ')'


# Script entry point: read the benchmark named by the first command-line
# argument and set up the initial state of the top-down search.
if __name__ == '__main__':
    # NOTE(review): the file handle is never closed in the visible code.
    benchmarkFile = open(sys.argv[1])
    bm = stripComments(benchmarkFile)
    # Parse string to python list
    bmExpr = sexp.sexp.parseString(bm, parseAll=True).asList()[0] 
    # pprint.pprint(bmExpr)
    checker = translator.ReadQuery(bmExpr)
    SynFunExpr = []
    StartSym = 'My-Start-Symbol' #virtual starting symbol
    # Keep the last top-level 'synth-fun' command found in the benchmark.
    for expr in bmExpr:
        if len(expr)==0:
            continue
        elif expr[0]=='synth-fun':
            SynFunExpr=expr
    print("Function to Synthesize: ", SynFunExpr)
    # NOTE(review): if no 'synth-fun' was found SynFunExpr is still [] and
    # the SynFunExpr[3] lookup below raises IndexError.
    FuncDefine = ['define-fun']+SynFunExpr[1:4] #copy function signature
    FuncDefineStr = translator.toString(FuncDefine,ForceBracket = True)
    # use Force Bracket = True on function definition. MAGIC CODE. DO NOT MODIFY THE ARGUMENT ForceBracket = True.
    #print(FuncDefine)
    BfsQueue = [[StartSym]] #Top-down
    Productions = {StartSym:[]}
    Type = {StartSym:SynFunExpr[3]} # set starting symbol's return type
Esempio n. 6
0
File: main.py Progetto: lsrcz/SyGuS
def getTermCondition(Expr, TermInfo, currentTerm, RemainTerms, ConsTable,
                     VarTable):
    """Build a condition under which ``currentTerm`` is selected.

    Two solvers are used. ``exampleGenerator`` produces input examples for
    which earlier conditions in ``TermInfo`` are false, the function call
    equals ``currentTerm`` and differs from every term in ``RemainTerms``,
    and the benchmark constraints hold. ``checkSolver`` carries the negated
    constraints and is passed to ``checkValid`` / ``reduceCons`` to validate
    candidate conditions taken from ``ConsTable``.

    :param Expr: parsed benchmark expression, re-read with
        ``translator.ReadQuery``.
    :param TermInfo: iterable of ``(condition, term)`` pairs already decided.
    :param currentTerm: the term whose guard is being computed.
    :param RemainTerms: terms not yet assigned a guard.
    :param ConsTable: object providing ``filter(depth, Examples)``.
    :param VarTable: NOTE(review): this argument is overwritten by the first
        statement and therefore never read.
    :return: a nested list condition (disjuncts joined with ``"or"``), or
        ``[]`` when the reduced condition turns out empty.
    """
    SynFunExpr, VarTable, FunDefMap, Constraints = translator.ReadQuery(Expr)
    # Snapshot taken before replaceFunctionCall receives VarTable below.
    inputVarTable = VarTable.copy()

    functionCallDic = {}
    ReplacedConsInfo = []
    for i in range(len(Constraints)):
        ReplacedConsInfo.append(
            replaceFunctionCall(Constraints[i], functionCallDic, SynFunExpr[1],
                                SynFunExpr[3], VarTable))
    ReplacedConsSet = getConsSet(ReplacedConsInfo)
    # Only a single constraint group with a single entry is supported here.
    assert len(ReplacedConsSet) == 1 and len(ReplacedConsSet[0][0]) == 1

    ReplacedCons = ReplacedConsSet[0][1]
    # print(functionCallDic)
    functionCallVar = None
    functionArgs = None
    # Keeps the entry of the last key iterated.
    for functionCallId in functionCallDic:
        functionCallVar, functionArgs = functionCallDic[functionCallId]
    # print(functionCallVar, functionArgs)

    exampleGenerator = Solver()
    checkSolver = Solver()
    # Conditions of earlier terms must be false in both solvers.
    for condition, term in TermInfo:
        spec = "(assert (not %s))" % (translator.toString(
            replaceArgs(condition, SynFunExpr[2], functionArgs)))
        spec = parse_smt2_string(spec, decls=VarTable)
        exampleGenerator.add(spec)
        checkSolver.add(spec)
    # Examples must not be explainable by any of the remaining terms.
    for term in RemainTerms:
        spec = "(assert (not (= %s %s)))" % (
            str(functionCallVar),
            translator.toString(replaceArgs(term, SynFunExpr[2],
                                            functionArgs)))
        # print(spec)
        exampleGenerator.add(parse_smt2_string(spec, decls=VarTable))
    # The function call is fixed to the current term in both solvers.
    spec = "(assert (= %s %s))" % (str(functionCallVar),
                                   translator.toString(
                                       replaceArgs(currentTerm, SynFunExpr[2],
                                                   functionArgs)))
    spec = parse_smt2_string(spec, decls=VarTable)
    exampleGenerator.add(spec)
    checkSolver.add(spec)
    spec = "\n".join(
        list(
            map(lambda x: "(assert %s)" % (translator.toString(x[1:])),
                ReplacedCons)))
    spec = parse_smt2_string(spec, decls=VarTable)
    # The generator satisfies the constraints; the checker negates them.
    exampleGenerator.add(spec)
    checkSolver.add(Not(And(spec)))
    # print(checkSolver)

    depth = 0
    result = []
    qualityConsNum = 3
    inputVars = []
    for var in inputVarTable:
        inputVars.append(inputVarTable[var])

    Examples = []
    currentCondition = []
    while True:
        # First push: scope for the temporary "not currentCondition" assert.
        exampleGenerator.push()
        if len(currentCondition) > 0:
            spec = "(assert (not %s))" % (translator.toString(
                replaceArgs(currentCondition, SynFunExpr[2], functionArgs)))
            exampleGenerator.add(parse_smt2_string(spec, decls=VarTable))

        exampleResult = exampleGenerator.check()
        if exampleResult == unsat:
            break
        # Second push: scope for the random ordering constraints below.
        exampleGenerator.push()
        for __ in range(1, qualityConsNum):
            lVar = inputVars[random.randint(0, len(inputVars) - 1)]
            rVar = inputVars[random.randint(0, len(inputVars) - 1)]
            # Try the extra constraint in its own scope; keep it only if the
            # solver stays satisfiable.
            exampleGenerator.push()
            exampleGenerator.add(lVar > rVar + 5)
            if exampleGenerator.check() == sat:
                exampleGenerator.pop()
                exampleGenerator.add(lVar > rVar + 5)
            else:
                exampleGenerator.pop()
        exampleGenerator.check()
        example = exampleGenerator.model()
        # Undo both scopes opened in this iteration.
        exampleGenerator.pop()
        exampleGenerator.pop()
        Examples.append(example)

        BestCons = None
        isChange = False
        # Drop the oldest examples until the filtered constraints are valid.
        while len(Examples) > 0:
            suitableCons = ConsTable.filter(depth, Examples)
            if checkValid(checkSolver, suitableCons, VarTable, SynFunExpr[2],
                          functionArgs):
                BestCons = suitableCons
                break
            Examples = Examples[1:]
            isChange = True
        # When examples were dropped, commit the current condition as one
        # disjunct and exclude it permanently from the generator.
        if isChange and len(currentCondition) > 0:
            if len(result) == 0:
                result = currentCondition
            else:
                result = ["or", result, currentCondition]
            spec = "(assert (not %s))" % (translator.toString(
                replaceArgs(currentCondition, SynFunExpr[2], functionArgs)))
            exampleGenerator.add(parse_smt2_string(spec, decls=VarTable))
            currentCondition = []
        # input()
        # print(Examples)
        if BestCons is None:
            # No usable constraints at this depth: retry one level deeper.
            depth += 1
            continue

        reducedCondition = reduceCons(checkSolver, BestCons, [], VarTable,
                                      SynFunExpr[2], functionArgs, True)
        currentCondition = reformatListCons(reducedCondition)
        if len(currentCondition) == 0:
            return []
    # Fold the last pending condition into the result.
    if len(result) == 0:
        result = currentCondition
    else:
        result = ["or", result, currentCondition]
    return result
Esempio n. 7
0
File: main.py Progetto: lsrcz/SyGuS
        NTType = NonTerm[1]
        assert NTType in ['Int', 'Bool']
        if NTType == Type[StartSym]:
            Productions[StartSym].append(NTName)
        Type[NTName] = NTType
        # Productions[NTName] = NonTerm[2]
        Productions[NTName] = []
        for NT in NonTerm[2]:
            if type(NT) == tuple:
                Productions[NTName].append(
                    str(NT[1])
                )  # deal with ('Int',0). You can also utilize type information, but you will suffer from these tuples.
            else:
                Productions[NTName].append(NT)

    SynFunExpr, VarTable, _, Constraints, checker = translator.ReadQuery(
        bmExpr, True)
    previousSynFunExpr, _, _, previousConstraints = translator.ReadQuery(
        task.ori.bmExpr)

    for NonTerm in SynFunExpr[4]:
        for NT in NonTerm[2]:
            current = NT
            if type(NT) == tuple:
                current = str(NT[1])
            if type(current) == str:
                if current not in Type and current not in Terminals[
                        NonTerm[1]]:
                    Terminals[NonTerm[1]].append(current)

    Operators = []
    for operatorType in defaultOperators:
Esempio n. 8
0
def trySolve(Terminals, Operators, ReturnType, bmExpr):
    """Search for a program built from ``Operators`` that meets the spec.

    Three solvers are used:

    * ``s1`` first looks for inputs satisfying the constraints together with
      as many "quality" ordering constraints as possible; it is then replaced
      by a fresh solver that only completes partial counterexample models.
    * ``s3`` encodes the program layout: one integer position per operator
      output and per operator input, one integer selecting the concrete
      operator name, plus -- for every stored counterexample model -- the
      values flowing through that layout.
    * ``s2`` checks a decoded candidate against the original constraints and
      yields a counterexample when the candidate is wrong.

    :param Terminals: dict with ``'Int'`` and ``'Bool'`` lists of terminal
        strings.
    :param Operators: list of operator descriptions; ``operator[0]`` is the
        list of interchangeable operator names, ``operator[1]`` the output
        type, ``operator[2:]`` the arguments (a list marks a typed input).
    :param ReturnType: ``'Int'`` or ``'Bool'``, the type of the result.
    :param bmExpr: parsed benchmark expression.
    :return: the ``define-fun`` string of a verified candidate, or ``None``
        when the layout solver becomes unsatisfiable.
    """
    InputVar = []
    OutputVar = []
    OperatorTypeVar = []
    TypeNumber = {'Bool': 0, 'Int': 0}
    # Positions below StartNumber denote terminals; operator outputs occupy
    # [StartNumber, MidNumber).
    MidNumber = max(len(Terminals['Bool']), len(Terminals['Int']))
    StartNumber = MidNumber
    SynFunExpr, VarTable, FunDefMap, Constraints = translator.ReadQuery(bmExpr)
    inputVars = list(VarTable.keys())
    inputVarTable = VarTable.copy()

    # Candidate constraints that push the inputs apart from each other.
    qualityCons = []
    for i in range(len(inputVars)):
        x = VarTable[inputVars[i]]
        qualityCons.append(x > 4)
        for j in range(len(inputVars)):
            if i != j:
                y = VarTable[inputVars[j]]
                qualityCons.append(x > y + 4)
    random.shuffle(qualityCons)

    # One position variable per operator output and per typed input.
    for operator in Operators:
        outputType = operator[1]
        OutputVar.append(
            declareVar('Int', getId(outputType, TypeNumber[outputType]),
                       VarTable))
        TypeNumber[outputType] += 1
        MidNumber += 1
        currentInputVar = []
        for arg in operator[2:]:
            if type(arg) == list:
                inputType = arg[0]
                currentInputVar.append(
                    declareVar('Int', getId(inputType, TypeNumber[inputType]),
                               VarTable))
                TypeNumber[inputType] += 1
        InputVar.append(currentInputVar)
    resultVar = declareVar('Int', getId(ReturnType, TypeNumber[ReturnType]),
                           VarTable)
    TypeNumber[ReturnType] += 1

    s1 = Solver()
    s2 = Solver()
    s3 = Solver()

    functionCallDic = {}
    ReplacedCons = []
    for i in range(len(Constraints)):
        ReplacedCons.append(
            replaceFunctionCall(Constraints[i], functionCallDic, SynFunExpr[1],
                                SynFunExpr[3], VarTable))
    spec = "\n".join(
        list(
            map(lambda x: "(assert %s)" % translator.toString(x[1:]),
                ReplacedCons)))
    spec = parse_smt2_string(spec, decls=VarTable)
    s1.add(spec)

    # Greedily keep every quality constraint that stays satisfiable.
    inputQualityCons = []
    for constraint in qualityCons:
        inputQualityCons.append(constraint)
        s1.push()
        s1.add(And(inputQualityCons))
        currentRes = s1.check()
        s1.pop()
        if currentRes == unsat:
            inputQualityCons = inputQualityCons[:-1]

    s1.push()
    s1.add(inputQualityCons)
    s1.check()
    currentModel = s1.model()
    s1.pop()
    Models = []
    # From here on s1 carries no constraints of its own; it is only used to
    # complete partial counterexamples.
    s1 = Solver()

    # Map each formal argument name to its position in the signature.
    ArgumentDict = {}
    argId = -1
    for arg in SynFunExpr[2]:
        argId += 1
        ArgumentDict[arg[0]] = argId

    # Fix: the original tests read `'+' or '-' in operator[0] and ...`, which
    # is always true because the literal '+' is truthy. Test membership of
    # each name explicitly.
    hasZero = "0" in Terminals['Int']
    hasOne = "1" in Terminals['Int']
    SimplifyOption = False
    for operator in Operators:
        operatorNames = operator[0]
        if ('+' in operatorNames or '-' in operatorNames) and hasZero:
            SimplifyOption = True
        if ('*' in operatorNames or '/' in operatorNames
                or 'div' in operatorNames) and hasOne:
            SimplifyOption = True

    for i in range(len(Operators)):
        OperatorTypeVar.append(
            declareVar('Int', getId("operatorType", i), VarTable))
        outputVar = OutputVar[i]
        operator = Operators[i]
        operatorTypeVar = OperatorTypeVar[i]
        s3.add(outputVar >= StartNumber)
        s3.add(outputVar < MidNumber)
        s3.add(operatorTypeVar >= 0)
        s3.add(operatorTypeVar < len(operator[0]))
        for inputVar in InputVar[i]:
            # An operator may only consume values placed before its output.
            s3.add(outputVar > inputVar)
            if 'Bool' in str(inputVar):
                currentInputType = 'Bool'
            else:
                currentInputType = 'Int'
            if SimplifyOption and currentInputType == 'Int' and operator[
                    1] == 'Bool':
                # Only terminals that are not integer literals may feed this
                # input directly.
                currentCons = []
                terminalId = -1
                for terminal in Terminals['Int']:
                    terminalId += 1
                    try:
                        int(terminal)
                    except (TypeError, ValueError):
                        currentCons += [inputVar == terminalId]
            else:
                currentCons = [
                    And(inputVar >= 0,
                        inputVar < len(Terminals[currentInputType]))
                ]
            for j in range(len(Operators)):
                if Operators[j][1] == currentInputType:
                    currentCons.append(inputVar == OutputVar[j])
            s3.add(Or(currentCons))
    # The result must be the output of an operator of the return type.
    currentCons = []
    for i in range(len(Operators)):
        if Operators[i][1] == ReturnType:
            currentCons.append(resultVar == OutputVar[i])
    s3.add(Or(currentCons))
    # Every operator output gets its own position.
    for i in range(len(Operators)):
        for j in range(i + 1, len(Operators)):
            s3.add(OutputVar[i] != OutputVar[j])

    while True:
        s3.push()
        callId = -1
        # Encode, for each stored counterexample, the values computed by the
        # layout on that input.
        for currentModel in Models:
            newVarTable = VarTable.copy()
            currentOuterCons = ReplacedCons.copy()
            for functionCallName in functionCallDic:
                returnValueVar, CurrentArguments = functionCallDic[
                    functionCallName]
                InputValueVar = []
                OutputValueVar = []
                callId += 1
                ValueTypeNumber = {
                    'Bool': len(Terminals['Bool']),
                    'Int': len(Terminals['Int'])
                }
                for i in range(len(Operators)):
                    operator = Operators[i]
                    outputType = operator[1]
                    outputValueVar = declareVar(
                        outputType,
                        getId(outputType + str(callId) + "-",
                              ValueTypeNumber[outputType]), newVarTable)
                    OutputValueVar.append(outputValueVar)
                    ValueTypeNumber[outputType] += 1
                    currentInputValue = []
                    operatorTypeVar = OperatorTypeVar[i]
                    InputValueVarTable = []
                    for arg in operator[2:]:
                        if type(arg) == list:
                            inputType = arg[0]
                            inputValueVar = declareVar(
                                inputType,
                                getId(inputType + str(callId) + "-",
                                      ValueTypeNumber[inputType]), newVarTable)
                            InputValueVarTable.append(inputValueVar)
                            ValueTypeNumber[inputType] += 1
                            currentInputValue.append(inputValueVar)
                    InputValueVar.append(currentInputValue)
                    # Selecting operator name `typeId` forces the output
                    # value to equal that operator applied to the inputs.
                    for typeId in range(len(operator[0])):
                        currentCons = [operator[0][typeId]]
                        inputValueVarId = -1
                        for arg in operator[2:]:
                            if type(arg) == list:
                                inputValueVarId += 1
                                currentCons.append(
                                    str(InputValueVarTable[inputValueVarId]))
                            else:
                                currentCons.append(arg)
                        currentCons = ["=", str(outputValueVar), currentCons]
                        currentCons = [
                            "=>", ["=", str(typeId),
                                   str(operatorTypeVar)], currentCons
                        ]
                        spec = '(assert %s)' % (
                            translator.toString(currentCons))
                        spec = parse_smt2_string(spec, decls=dict(newVarTable))
                        s3.add(spec[0])
                for i in range(len(Operators)):
                    for j in range(len(InputVar[i])):
                        inputVar = InputVar[i][j]
                        inputValue = InputValueVar[i][j]
                        # Wiring: an input placed on another operator's
                        # output carries that operator's value.
                        for k in range(len(Operators)):
                            if i == k:
                                continue
                            outputVar = OutputVar[k]
                            outputValue = OutputValueVar[k]
                            outputType = Operators[k][1]
                            if outputType in str(inputVar):
                                s3.add(
                                    Implies(outputVar == inputVar,
                                            outputValue == inputValue))
                        if "Bool" in str(inputVar):
                            currentType = "Bool"
                        else:
                            currentType = "Int"
                        # Wiring: an input placed on terminal k carries the
                        # terminal's value (the actual call argument when the
                        # terminal names a formal argument).
                        for k in range(len(Terminals[currentType])):
                            terminal = Terminals[currentType][k]
                            if terminal in ArgumentDict:
                                argId = ArgumentDict[terminal]
                                currentCons = [
                                    "=>", ["=", str(inputVar),
                                           str(k)],
                                    [
                                        "=", CurrentArguments[argId],
                                        str(inputValue)
                                    ]
                                ]
                            else:
                                currentCons = [
                                    "=>", ["=", str(inputVar),
                                           str(k)],
                                    ["=", terminal,
                                     str(inputValue)]
                                ]
                            currentCons = '(assert %s)' % (
                                translator.toString(currentCons))
                            currentCons = parse_smt2_string(
                                currentCons, decls=dict(newVarTable))
                            currentCons = currentModel.eval(currentCons[0])
                            s3.add(currentCons)
                newReturnValueVar = declareVar(ReturnType,
                                               getId("returnValueVar", callId),
                                               newVarTable)
                currentOuterCons = replaceCons(currentOuterCons,
                                               str(returnValueVar),
                                               str(newReturnValueVar))
                for k in range(len(Operators)):
                    outputVar = OutputVar[k]
                    outputValue = OutputValueVar[k]
                    outputType = Operators[k][1]
                    if outputType == ReturnType:
                        s3.add(
                            Implies(resultVar == outputVar,
                                    newReturnValueVar == outputValue))
            # The benchmark constraints, evaluated on this counterexample,
            # must hold for the per-call return values.
            spec = "\n".join(
                list(
                    map(lambda x: "(assert %s)" % translator.toString(x[1:]),
                        currentOuterCons)))
            spec = parse_smt2_string(spec, decls=newVarTable)
            spec = list(map(lambda x: currentModel.eval(x), spec))
            s3.add(spec)

        resS3 = s3.check()
        if resS3 == unsat:
            return None
        currentCodeModel = s3.model()
        s3.pop()

        # Decode the layout model into program text.
        OutputTable = {}
        for i in range(len(Operators)):
            outputId = currentCodeModel[OutputVar[i]].as_long()
            OutputTable[outputId] = i
        resultId = currentCodeModel[resultVar].as_long()
        if resultId < len(Terminals[ReturnType]):
            resultCode = Terminals[ReturnType][resultId]
        else:
            resultCode = getCode(OutputTable[resultId], currentCodeModel,
                                 Operators, InputVar, OutputTable, Terminals,
                                 OperatorTypeVar)

        s2.push()
        FuncDefineStr = '(define-fun'
        for i in range(1, 4):
            currentStr = translator.toString(SynFunExpr[i])
            # A single-argument list needs its surrounding brackets restored.
            if i == 2 and len(SynFunExpr[i]) == 1:
                currentStr = "(%s)" % (currentStr)
            FuncDefineStr += " " + currentStr
        FuncDefineStr += ")"
        fullResultCode = FuncDefineStr[:-1] + ' ' + translator.toString(
            resultCode) + FuncDefineStr[-1]
        spec_smt2 = [fullResultCode]
        for constraint in Constraints:
            spec_smt2.append('(assert %s)' %
                             (translator.toString(constraint[1:])))
        spec_smt2 = '\n'.join(spec_smt2)
        spec = parse_smt2_string(spec_smt2, decls=dict(VarTable))
        # A model of the negated constraints is a counterexample.
        s2.add(Not(And(spec)))

        # Prefer counterexamples meeting the quality constraints; relax them
        # one at a time when none exists. No counterexample at all means the
        # candidate is correct.
        while True:
            s2.push()
            s2.add(And(inputQualityCons))
            resS2 = s2.check()
            if resS2 == unsat:
                if len(inputQualityCons) == 0:
                    return fullResultCode
                else:
                    s2.pop()
                    inputQualityCons = inputQualityCons[:-1]
                    continue
            newInput = s2.model()
            s2.pop()
            break

        s2.pop()

        # Complete the counterexample so that every input has a value.
        s1.push()
        for var in inputVarTable:
            newInputValue = newInput[inputVarTable[var]]
            if newInputValue is not None:
                s1.add(inputVarTable[var] == newInputValue)
        s1.check()
        newFullInput = s1.model()
        s1.pop()
        Models.append(newFullInput)
Esempio n. 9
0
        NTType = NonTerm[1]
        assert NTType in ['Int', 'Bool']
        if NTType == Type[StartSym]:
            Productions[StartSym].append(NTName)
        Type[NTName] = NTType
        #Productions[NTName] = NonTerm[2]
        Productions[NTName] = []
        for NT in NonTerm[2]:
            if type(NT) == tuple:
                Productions[NTName].append(
                    str(NT[1])
                )  # deal with ('Int',0). You can also utilize type information, but you will suffer from these tuples.
            else:
                Productions[NTName].append(NT)

    _, _, _, Constraints = translator.ReadQuery(bmExpr)

    operatorTable = {}
    for NonTerm in SynFunExpr[4]:
        for NT in NonTerm[2]:
            current = NT
            if type(NT) == tuple:
                current = str(NT[1])
            if type(current) == str:
                if current not in Type and current not in Terminals[
                        NonTerm[1]]:
                    Terminals[NonTerm[1]].append(current)
            else:
                operatorArgs = []
                for i in NT[1:]:
                    if i in Type:
Esempio n. 10
0
File: main.py Progetto: lsrcz/SyGuS
def trySolve(Terminals, Operators, ReturnType, bmExpr):
    '''
    Search for a body of the synth-fun in ``bmExpr`` with a
    counterexample-guided loop over three Z3 solvers.

    Every entry of ``Operators`` is treated as one component. It receives an
    integer "location" variable for its output and one for each list-typed
    argument. ``s3`` looks for an assignment of those locations that
    reproduces the recorded return value for every model in ``Models``; the
    expression read back from that assignment is then checked against the
    original constraints with ``s2``. A failed check yields a counterexample,
    which is completed into a full model through ``s1`` and appended to
    ``Models`` before the next round.

    :param Terminals: dict with the keys 'Bool' and 'Int', each mapping to a
        list of terminals; a terminal's list index is its location number.
    :param Operators: sequence of components; ``op[0]`` is the operator
        symbol, ``op[1]`` its output type ('Bool' or 'Int') and ``op[2:]``
        its arguments. A list argument denotes an input slot whose type is
        ``arg[0]``; any other argument is copied into the expression as is.
    :param ReturnType: 'Bool' or 'Int' (used as a key of ``Terminals``).
    :param bmExpr: parsed benchmark, handed to ``translator.ReadQuery``.
    :return: the ``(define-fun ...)`` string once ``s2`` reports ``unsat``
        for the negated spec, or ``None`` when ``s3`` is ``unsat``.
    '''
    # print(Operators)
    InputVar = []
    OutputVar = []
    # Per-type counters passed to getId when naming location variables.
    TypeNumber = {'Bool': 0, 'Int': 0}
    # Locations below StartNumber are reserved for terminals; MidNumber is
    # bumped once per operator below and ends as StartNumber + len(Operators).
    MidNumber = max(len(Terminals['Bool']), len(Terminals['Int']))
    StartNumber = MidNumber
    SynFunExpr, VarTable, FunDefMap, Constraints = translator.ReadQuery(bmExpr)
    # Declare the location variables: one output per operator and one input
    # per list-typed argument. They are declared with sort 'Int' regardless
    # of the operator's own type; the type only goes into the getId name.
    for operator in Operators:
        outputType = operator[1]
        OutputVar.append(declareVar('Int', getId(outputType, TypeNumber[outputType]), VarTable))
        TypeNumber[outputType] += 1
        MidNumber += 1
        currentInputVar = []
        for arg in operator[2:]:
            if type(arg) == list:
                inputType = arg[0]
                currentInputVar.append(declareVar('Int', getId(inputType, TypeNumber[inputType]), VarTable))
                TypeNumber[inputType] += 1
        InputVar.append(currentInputVar)
    # Location of the value the synthesized function returns.
    resultVar = declareVar('Int', getId(ReturnType, TypeNumber[ReturnType]), VarTable)
    TypeNumber[ReturnType] += 1

    # s1: spec with function calls replaced, used to produce full models.
    # s2: verifies a candidate against the original constraints.
    # s3: searches over the location variables.
    s1 = Solver()
    s2 = Solver()
    s3 = Solver()

    # replaceFunctionCall is handed functionCallDic; the main loop later reads
    # each entry as (return-value variable, call arguments).
    functionCallDic = {}
    ReplacedCons = []
    for i in range(len(Constraints)):
        ReplacedCons.append(replaceFunctionCall(Constraints[i], functionCallDic, SynFunExpr[1], SynFunExpr[3], VarTable))
    # x[1:] drops the head element of each constraint before it is wrapped in
    # (assert ...) -- presumably the 'constraint' keyword; confirm in translator.
    spec = "\n".join(list(map(lambda x: "(assert %s)"%translator.toString(x[1:]), ReplacedCons)))
    spec = parse_smt2_string(spec, decls=VarTable)
    # print(spec)
    s1.add(spec)
    # NOTE(review): the result of check() is not inspected before model() is
    # called -- confirm the spec is always satisfiable at this point.
    s1.check()
    currentModel = s1.model()
    # Seed the example set with one model of the spec.
    Models = [currentModel]
    # Snapshot of the declarations used later to pin counterexample values.
    s1VarTable = VarTable.copy()

    # Map each synth-fun parameter name to its position in the parameter list.
    ArgumentDict = {}
    # print(Terminals)
    argId = -1
    for arg in SynFunExpr[2]:
        argId += 1
        ArgumentDict[arg[0]] = argId

    # Model-independent constraints on the locations, added to s3 once.
    for i in range(len(Operators)):
        outputVar = OutputVar[i]
        operator = Operators[i]
        # Operator outputs occupy the range [StartNumber, MidNumber).
        s3.add(outputVar >= StartNumber)
        s3.add(outputVar < MidNumber)
        for inputVar in InputVar[i]:
            # An input must come from a strictly smaller location.
            s3.add(outputVar > inputVar)
            # The slot type is recovered from the variable's printed name;
            # presumably getId embeds the type name -- confirm.
            if 'Bool' in str(inputVar):
                currentInputType = 'Bool'
            else:
                currentInputType = 'Int'
            # The input is either a terminal of its type or the output of an
            # operator whose output type matches.
            currentCons = [And(inputVar >= 0, inputVar < len(Terminals[currentInputType]))]
            for j in range(len(Operators)):
                if Operators[j][1] == currentInputType:
                    currentCons.append(inputVar == OutputVar[j])
            s3.add(Or(currentCons))
    # The result must be the output of some operator of type ReturnType.
    currentCons = []
    for i in range(len(Operators)):
        if Operators[i][1] == ReturnType:
            currentCons.append(resultVar == OutputVar[i])
    # print "fin ", currentCons
    s3.add(Or(currentCons))
    # No two operators share an output location.
    for i in range(len(Operators)):
        for j in range(i+1, len(Operators)):
            s3.add(OutputVar[i] != OutputVar[j])
    # print(Operators)
    # print(Terminals)
    # print(MidNumber)
    # print len(Terminals['Int']), len(Terminals['Bool'])
    while True:
        # The per-model constraints are rebuilt from scratch every round and
        # discarded by the matching pop() after the model is extracted.
        s3.push()
        callId = -1
        for currentModel in Models:
            for functionCallName in functionCallDic:
                # Value variables are declared in a copy so that VarTable
                # itself is not extended.
                newVarTable = VarTable.copy()
                returnValueVar, CurrentArguments = functionCallDic[functionCallName]
                InputValueVar = []
                OutputValueVar = []
                # callId is part of the variable names, so each
                # (model, function call) pair gets its own value variables.
                callId += 1
                ValueTypeNumber = {'Bool': len(Terminals['Bool']), 'Int': len(Terminals['Int'])}
                for operator in Operators:
                    outputType = operator[1]
                    outputValueVar = declareVar(outputType,
                                                getId(outputType + str(callId) + "-", ValueTypeNumber[outputType]),
                                                newVarTable)
                    OutputValueVar.append(outputValueVar)
                    ValueTypeNumber[outputType] += 1
                    currentInputValue = []
                    # Build [op, in1, in2, ...] with value variables standing
                    # in for the list-typed argument slots.
                    currentCons = [operator[0]]
                    for arg in operator[2:]:
                        if type(arg) == list:
                            inputType = arg[0]
                            inputValueVar = declareVar(inputType,
                                                       getId(inputType + str(callId) + "-", ValueTypeNumber[inputType]),
                                                       newVarTable)
                            ValueTypeNumber[inputType] += 1
                            currentInputValue.append(inputValueVar)
                            currentCons.append(str(inputValueVar))
                        else:
                            currentCons.append(arg)
                    InputValueVar.append(currentInputValue)
                    # Semantics of the component: output value = op(input values).
                    currentCons = ["=", str(outputValueVar), currentCons]
                    spec = '(assert %s)' % (translator.toString(currentCons))
                    spec = parse_smt2_string(spec, decls=dict(newVarTable))
                    # print(spec[0])
                    s3.add(spec[0])
                # print "CurrentArg: ", CurrentArguments
                # print ArgumentDict
                for i in range(len(Operators)):
                    for j in range(len(InputVar[i])):
                        inputVar = InputVar[i][j]
                        inputValue = InputValueVar[i][j]
                        # If this input is wired to another operator's output
                        # location, the two values must agree.
                        for k in range(len(Operators)):
                            if i == k: continue
                            outputVar = OutputVar[k]
                            outputValue = OutputValueVar[k]
                            outputType = Operators[k][1]
                            # Type match is again decided from the variable name.
                            if outputType in str(inputVar):
                                s3.add(Implies(outputVar == inputVar, outputValue == inputValue))
                        if "Bool" in str(inputVar):
                            currentType = "Bool"
                        else:
                            currentType = "Int"
                        # If this input is wired to terminal k, its value is
                        # that terminal; a terminal naming a synth-fun
                        # parameter is replaced by this call's actual argument.
                        for k in range(len(Terminals[currentType])):
                            terminal = Terminals[currentType][k]
                            if terminal in ArgumentDict:
                                argId = ArgumentDict[terminal]
                                currentCons = ["=>", ["=", str(inputVar), str(k)],
                                               ["=", CurrentArguments[argId], str(inputValue)]]
                            else:
                                currentCons = ["=>", ["=", str(inputVar), str(k)], ["=", terminal, str(inputValue)]]
                            currentCons = '(assert %s)' % (translator.toString(currentCons))
                            currentCons = parse_smt2_string(currentCons, decls=dict(newVarTable))
                            # Evaluate under the current model before adding to s3.
                            currentCons = currentModel.eval(currentCons[0])
                            s3.add(currentCons)
                # Whichever operator supplies the result must produce the
                # return value this model recorded for the call.
                for k in range(len(Operators)):
                    outputVar = OutputVar[k]
                    outputValue = OutputValueVar[k]
                    outputType = Operators[k][1]
                    if outputType == ReturnType:
                        # print Implies(resultVar == outputVar, currentModel.eval(returnValueVar) == outputValue)
                        s3.add(Implies(resultVar == outputVar, currentModel.eval(returnValueVar) == outputValue))
        #print "start"
        resS3 = s3.check()
        # print "end"
        # print resS3
        # No wiring of the components fits all collected models: give up.
        if resS3 == unsat:
            return None
        currentCodeModel = s3.model()
        s3.pop()
        # Map each chosen output location back to its operator index.
        OutputTable = {}
        for i in range(len(Operators)):
            outputId = currentCodeModel[OutputVar[i]].as_long()
            OutputTable[outputId] = i
        '''print("Now")
        for i in range(len(Operators)):
            print("")
            print(Operators[i])
            print(currentCodeModel[OutputVar[i]])
            print(map(lambda x: currentCodeModel[x], InputVar[i]))'''
        resultId = currentCodeModel[resultVar].as_long()
        # print(currentCodeModel)
        # NOTE(review): s3 forces resultVar to equal some OutputVar, and every
        # OutputVar is >= StartNumber >= len(Terminals[ReturnType]), so the
        # terminal branch below looks unreachable -- confirm.
        if resultId < len(Terminals[ReturnType]):
            resultCode = Terminals[ReturnType][resultId]
        else:
            resultCode = getCode(OutputTable[resultId], currentCodeModel, Operators, InputVar, OutputTable, Terminals)

        #print translator.toString(resultCode)

        s2.push()
        # Rebuild the signature "(define-fun <name> <params> <type>)" from
        # SynFunExpr[1..3].
        FuncDefineStr = '(define-fun'
        for i in range(1, 4):
            currentStr = translator.toString(SynFunExpr[i])
            # A one-parameter list gets an extra pair of parentheses --
            # presumably toString drops them for single-element lists; confirm.
            if i == 2 and len(SynFunExpr[i]) == 1:
                currentStr = "(%s)"%(currentStr)
            FuncDefineStr += " " + currentStr
        FuncDefineStr += ")"
        #print FuncDefineStr
        # Insert the candidate body just before the closing parenthesis.
        fullResultCode = FuncDefineStr[:-1] + ' ' + translator.toString(resultCode) + FuncDefineStr[-1]
        spec_smt2=[fullResultCode]
        for constraint in Constraints:
            spec_smt2.append('(assert %s)'%(translator.toString(constraint[1:])))
        spec_smt2='\n'.join(spec_smt2)
        # print(spec_smt2)
        # print spec_smt2
        # print VarTable
        spec = parse_smt2_string(spec_smt2, decls=dict(VarTable))
        # print "End"
        # Ask for an input that violates at least one original constraint.
        s2.add(Not(And(spec)))

        resS2 = s2.check()
        # No violating input exists: the candidate is returned as the answer.
        if resS2 == unsat:
            return fullResultCode
        newInput = s2.model()
        s2.pop()

        # Pin every variable the counterexample assigns and ask s1 for a full
        # model of the call-replaced spec consistent with those values.
        s1.push()
        for var in s1VarTable:
            newInputValue = newInput[s1VarTable[var]]
            if newInputValue is not None:
                s1.add(s1VarTable[var] == newInputValue)
        # NOTE(review): check() result is not inspected before model() here
        # either -- confirm this cannot be unsat.
        s1.check()
        newFullInput = s1.model()
        s1.pop()
        # print(newFullInput)
        # print(fullResultCode)
        # raw_input()
        Models.append(newFullInput)
    # Not reached: the loop above has no break and only exits via return.
    return None
Esempio n. 11
0
        noComments += line
    return noComments + ')'


if __name__ == '__main__':
    TE_set = set()
    begin_t = time.time()
    file_name = 'open_tests/' + default_file
    if (len(sys.argv) > 1):
        file_name = sys.argv[1]
    benchmarkFile = open(file_name)
    bm = stripComments(benchmarkFile)
    bmExpr = sexp.sexp.parseString(
        bm, parseAll=True).asList()[0]  # Parse string to python list
    # pprint.pprint(bmExpr)
    checker, is_ite_prior, is_cmp_prior, preCons = translator.ReadQuery(bmExpr)
    # print (checker.check('(define-fun f ((x Int)) Int (mod (* x 3) 10)  )'))
    # raw_input()
    SynFunExpr = []
    StartSym = 'My-Start-Symbol'  # virtual starting symbol
    for expr in bmExpr:
        if len(expr) == 0:
            continue
        elif expr[0] == 'synth-fun':
            SynFunExpr = expr
    FuncDefine = ['define-fun'] + SynFunExpr[1:4]  # copy function signature
    FuncDefineStr = translator.toString(FuncDefine, ForceBracket=True)
    # print(FuncDefine)
    BfsQueue = [[StartSym]]  # Top-down
    Productions = {StartSym: []}
    Type = {StartSym: SynFunExpr[3]}  # set starting symbol's return type