예제 #1
0
def run(eqfile, timeout, ploc, wd):
    """Run the ``woorpje`` solver on *eqfile* and wrap the outcome in a Result.

    Args:
        eqfile: Input file, passed as the last command-line argument.
        timeout: Limit given to ``subprocess.check_output``; also reported as
            the elapsed time when the limit is hit.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Directory in which the solver is told to write ``model.smt``.

    Returns:
        A ``utils.Result``. Its first argument is ``True`` when the process
        exits with status 0 (the contents of ``model.smt`` are attached),
        ``False`` when it exits with status 1 and ``None`` for every other
        exit status or on timeout. The third argument is ``True`` only on
        timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``woorpje`` executable.
    """
    # The original bound this to ``path`` but then used ``tool`` everywhere,
    # which failed with a NameError before the solver was ever started.
    tool = ploc.findProgram("woorpje")
    if not tool:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("woorpje Not in Path")

    clock = timer.Timer()
    smtmodel = os.path.join(wd, "model.smt")
    try:
        out = subprocess.check_output([
            tool, '--smtmodel', smtmodel, '--solver', '1', '-S', '1',
            '--smttimeout', '10', eqfile
        ], timeout=timeout)
    except subprocess.CalledProcessError as ex:
        clock.stop()
        if ex.returncode == 1:
            return utils.Result(False, clock.getTime(), False, 0, ex.output)
        # Exit codes 10 and 20, like any other non-zero status, give no verdict.
        return utils.Result(None, clock.getTime(), False, 0, ex.output)
    except subprocess.TimeoutExpired:
        return utils.Result(None, timeout, True, 0)

    clock.stop()
    with open(smtmodel) as f:
        model = f.read()
    return utils.Result(True, clock.getTime(), False, 0, out, model)
예제 #2
0
def run(params, eq, timeout, ploc, wd):
    """Run ``woorpjeSMT`` with extra *params* on *eq* and wrap the outcome.

    Args:
        params: List of additional command-line arguments, inserted before
            the input file.
        eq: Input file, passed as the last command-line argument.
        timeout: Limit given to ``subprocess.run``; also reported as the
            elapsed time when anything goes wrong.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Directory in which the solver is told to write ``smtmodel.smt``.

    Returns:
        A ``utils.Result``. Its first argument is ``True`` for exit status 0
        (the rewritten model file is attached), ``False`` for exit status 1
        and ``None`` otherwise. The fourth argument is the number parsed from
        the last ``SMTCalls:`` line of stdout (0 when there is none). Any
        exception raised while running or post-processing (including the
        subprocess timeout) yields a Result whose third argument is ``True``
        and whose fifth argument is the exception text.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``woorpjeSMT`` executable.
    """
    path = ploc.findProgram("woorpjeSMT")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("WoorpjeSMT Not in Path")

    # The original also created a temporary directory here that was never
    # used or removed; it has been dropped together with the unused locals.
    SMTSolverCalls = 0
    clock = timer.Timer()
    try:
        smtmodel = os.path.join(wd, "smtmodel.smt")
        p = subprocess.run([
            path, '--smtmodel', smtmodel, '--solver', '4', '--smttimeout', '15'
        ] + params + [eq],
                           stdout=subprocess.PIPE,
                           encoding='ascii',
                           universal_newlines=True,
                           timeout=timeout)
        clock.stop()

        output = p.stdout.splitlines()
        for line in output:
            if line.startswith("SMTCalls:"):
                numbers = [int(x) for x in line.split(" ") if x.isdigit()]
                # A counter line without a number keeps the previous value
                # instead of being reported as a timeout via IndexError.
                if numbers:
                    SMTSolverCalls = numbers[0]

        joined_output = "\n".join(output)
        if p.returncode == 0:
            with open(smtmodel) as f:
                # Strip the leading/trailing underscore the solver puts around
                # names; each line keeps its own newline and gets one more
                # appended, exactly as before.
                model = "".join(
                    line.replace("(define-fun _", "(define-fun ", 1)
                        .replace("_()", "()", 1) + "\n"
                    for line in f)
            return utils.Result(True, clock.getTime(), False, SMTSolverCalls,
                                joined_output, model)
        if p.returncode == 1:
            return utils.Result(False, clock.getTime(), False, SMTSolverCalls,
                                joined_output)
        # Exit codes 10 and 20, like any other status, give no verdict.
        return utils.Result(None, clock.getTime(), False, SMTSolverCalls,
                            joined_output)
    except Exception as e:
        clock.stop()
        return utils.Result(None, timeout, True, SMTSolverCalls, str(e))
예제 #3
0
def run(eqfile, timeout, heuristicNo, smtSolverNo, heuristic_param_name, param,
        ploc, wd):
    """Run ``woorpjeSMT`` with a Levi heuristic configuration on *eqfile*.

    Args:
        eqfile: Input file, passed as the last command-line argument.
        timeout: Limit given to ``subprocess.run``; also reported as the
            elapsed time when anything goes wrong.
        heuristicNo: Value passed after ``--levisheuristics``.
        smtSolverNo: Value passed after ``-S``.
        heuristic_param_name: Argument placed after the heuristic number.
        param: Argument placed after ``heuristic_param_name``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Directory in which the solver is told to write ``model.smt``.

    Returns:
        A ``utils.Result``. Its first argument is ``True`` for exit status 0
        (stdout and the contents of ``model.smt`` are attached), ``False``
        for exit status 1 and ``None`` otherwise. The fourth argument is the
        number parsed from the last ``SMTCalls:`` line of stdout. Any
        exception (including the subprocess timeout) yields a Result whose
        third argument is ``True``.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``woorpjeSMT`` executable.
    """
    tool = ploc.findProgram("woorpjeSMT")
    if not tool:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("woorpje Not in Path")

    SMTSolverCalls = 0
    # Created before the ``try`` so the handler below can always stop it.
    clock = timer.Timer()
    try:
        smtmodel = os.path.join(wd, "model.smt")
        p = subprocess.run([
            tool, '--smtmodel', smtmodel, '--solver', '4', '-S',
            str(smtSolverNo), '--smttimeout', '15', '--levisheuristics',
            str(heuristicNo),
            str(heuristic_param_name),
            str(param), eqfile
        ],
                           stdout=subprocess.PIPE,
                           encoding='ascii',
                           universal_newlines=True,
                           timeout=timeout)
        clock.stop()

        output = p.stdout.splitlines()
        for line in output:
            if line.startswith("SMTCalls:"):
                numbers = [int(x) for x in line.split(" ") if x.isdigit()]
                # A counter line without a number keeps the previous value
                # instead of being reported as a timeout via IndexError.
                if numbers:
                    SMTSolverCalls = numbers[0]

        if p.returncode == 0:
            with open(smtmodel) as f:
                model = f.read()
            return utils.Result(True, clock.getTime(), False, SMTSolverCalls,
                                "\n".join(output), model)
        if p.returncode == 1:
            return utils.Result(False, clock.getTime(), False, SMTSolverCalls)
        # Exit codes 10 and 20, like any other status, give no verdict.
        return utils.Result(None, clock.getTime(), False, SMTSolverCalls)
    except Exception:
        clock.stop()
        return utils.Result(None, timeout, True, SMTSolverCalls)
예제 #4
0
def test_data(model, dataset_dir):
    """Classify every file under *dataset_dir* and tabulate the outcome.

    Each non-hidden entry of *dataset_dir* is treated as a topic directory
    whose name is the expected label of the files it contains. Every file is
    loaded with ``utils.get_file_content`` and labelled with
    ``utils.find_topic(model, ...)``.

    Returns:
        ``utils.Result(confusion_matrix, correct, total)`` where
        ``confusion_matrix[expected][predicted]`` counts documents,
        ``correct`` is the number of documents whose prediction equals the
        directory name and ``total`` is the number of documents seen.
    """
    print("Testing Data...")
    print("Please wait...")

    matrix = defaultdict(Counter)
    total_docs = 0
    hits = 0

    # Entries starting with a dot (hidden files/directories) are ignored.
    labels = (entry for entry in os.listdir(dataset_dir)
              if not entry.startswith('.'))
    for label in labels:
        label_dir = dataset_dir + "/" + label
        for file_name in os.listdir(label_dir):
            total_docs += 1
            content = utils.get_file_content(label_dir + "/" + file_name)
            predicted = utils.find_topic(model, content)

            if predicted == label:
                hits += 1
            matrix[label][predicted] += 1

    return utils.Result(matrix, hits, total_docs)
예제 #5
0
def train_epoch(train_loader, models, optimizer, logr, epoch):
    """Run one training pass of the first model in *models* over *train_loader*.

    ``models`` unpacks into ``(model, bprox, mdn)``. For every batch the
    output of ``mdn`` and the result of ``bprox`` are added to the batch as
    ``'mdi'`` and ``'bproxi'`` before ``model`` is called and optimised with
    ``utils.loss``. ``args``, ``device`` and ``writer`` are not parameters;
    they are read from the enclosing scope.
    """
    net, bprox, mdn = models
    net = net.train()
    block_meter = utils.AverageMeter()
    epoch_meter = utils.AverageMeter()

    for step, raw_batch in enumerate(train_loader):
        # Move every non-None entry of the batch to the target device.
        batch_data = {
            name: tensor.to(device)
            for name, tensor in raw_batch.items() if tensor is not None
        }

        # The mdn output is normalised with utils.DepthNorm, clamped, divided
        # by 1000 and multiplied by a per-dataset factor (85 for kitti, 10
        # for nyu_v2). bprox is fed inputs scaled back by that factor to the
        # 1000 range and its output is rescaled the same way.
        if args.dataset == 'kitti':
            depth_scale = 85
            upsampled = F.interpolate(utils.DepthNorm(mdn(batch_data['rgb'])),
                                      scale_factor=2)
            mdi = torch.clamp(upsampled, 2, 1000) / 1000 * depth_scale
        elif args.dataset == 'nyu_v2':
            depth_scale = 10
            mdi = torch.clamp(utils.DepthNorm(mdn(batch_data['rgb'])), 10,
                              1000) / 1000 * depth_scale
        else:
            print("invalid dataset")
            exit()

        proxy_input = torch.cat(
            [mdi / depth_scale * 1000, batch_data['d'] / depth_scale * 1000],
            dim=1)
        inpainted = bprox(proxy_input) / 1000 * depth_scale

        batch_data['bproxi'] = inpainted
        batch_data['mdi'] = mdi

        pred = net(batch_data)
        loss = utils.loss(pred, batch_data, args.dataset)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('data/train_loss',
                          loss.cpu().data.numpy(),
                          epoch * len(train_loader) + step)

        result = utils.Result()
        result.evaluate(pred, batch_data['gt'])
        block_meter.update(result)
        epoch_meter.update(result)

        # Report progress every 20 batches.
        if (step + 1) % 20 == 0:
            logr.print(step, epoch, args.lr, len(train_loader), block_meter,
                       epoch_meter)
예제 #6
0
def run(eq, timeout, ploc, wd):
    """Run the Trau solver on the SMT file *eq* and classify its output.

    A copy of *eq* without the lines containing ``(get-model)`` is written to
    a temporary directory and passed to the solver. The temporary directory
    is removed on every exit path.

    Args:
        eq: Path of the input file.
        timeout: Limit given to ``subprocess.check_output`` and compared
            against ``Timer.getTime()``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the output
        contains ``unsat``, ``True`` when it contains ``sat`` (everything
        after the first output line is attached as the model) and ``None``
        otherwise. The third argument is ``True`` on timeout. Times are
        reported via ``getTime_ms()`` or as ``timeout * 1000``.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``Trau`` executable.
    """
    path = ploc.findProgram("Trau")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("Trau Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "out_trau.smt")

        # hack to get rid of (get-model), not needed for z3 and returns
        # 1 / Error if input is unsat
        with open(eq, "r") as source, open(smtfile, "w") as copy:
            copy.writelines(line for line in source
                            if "(get-model)" not in line)

        clock = timer.Timer()
        try:
            out = subprocess.check_output(
                [path, "smt.string_solver=trau", "dump_models=true", smtfile],
                timeout=timeout).decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout * 1000, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            out = "Error in " + eq + ": " + str(e)
            return utils.Result(None, clock.getTime_ms(), False, 1, out)
        clock.stop()
    finally:
        # Previously the directory leaked on the timeout and error returns.
        shutil.rmtree(tempd, ignore_errors=True)

    # The original compared the Timer object itself with ``timeout``; the
    # elapsed time is what the later branch compares.
    if "NOT IMPLEMENTED YET!" in out and not clock.getTime() >= timeout:
        out = "Error in " + eq + ": " + out

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime_ms(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime_ms(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, timeout * 1000, True, 1)
    if "unknown" in out:
        return utils.Result(None, clock.getTime_ms(), False, 1, out)
    # must be an error
    return utils.Result(None, clock.getTime_ms(), False, 1,
                        f"Error in {eq} # stdout: {out}")
예제 #7
0
def run(params, eq, timeout, ploc, wd):
    """Run the ``Z3Overlaps`` binary on the SMT file *eq* and classify its output.

    A copy of *eq* without the lines containing ``(get-model)`` is written to
    a temporary directory and passed to the solver together with *params*.
    The temporary directory is removed on every exit path.

    Args:
        params: List of additional command-line arguments, inserted before
            the input file.
        eq: Path of the input file.
        timeout: Limit given to ``subprocess.check_output`` and compared
            against ``Timer.getTime()``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the combined
        stdout/stderr contains ``unsat``, ``True`` when it contains ``sat``
        (everything after the first output line is attached as the model)
        and ``None`` otherwise. The third argument is ``True`` on timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``Z3Overlaps`` executable.
    """
    path = ploc.findProgram("Z3Overlaps")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("Z3BV Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "out.smt")

        # hack to get rid of (get-model), not needed for z3 and returns
        # 1 / Error if input is unsat
        with open(eq, "r") as source, open(smtfile, "w") as copy:
            copy.writelines(line for line in source
                            if "(get-model)" not in line)

        clock = timer.Timer()
        try:
            out = subprocess.check_output(
                [
                    path, "smt.string_solver=z3str3", "dump_models=true",
                    "smt.str.search_overlaps=true",
                    "smt.str.fixed_length_iterations=5",
                    "smt.str.pre_milliseconds=2500"
                ] + params + [smtfile],
                timeout=timeout,
                stderr=subprocess.STDOUT).decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            out = "Error in " + eq + ": " + str(e)
            return utils.Result(None, clock.getTime(), False, 1, out)
        clock.stop()
    finally:
        # Previously the directory leaked on the timeout and error returns.
        shutil.rmtree(tempd, ignore_errors=True)

    # The original compared the Timer object itself with ``timeout``; the
    # elapsed time is what the later branch compares.
    if "NOT IMPLEMENTED YET!" in out and not clock.getTime() >= timeout:
        out = "Error in " + eq + ": " + out

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, timeout, True, 1)

    return utils.Result(None, clock.getTime(), False, 1, out)
예제 #8
0
def run(eq, timeout, ploc, wd):
    """Run the Ostrich solver on the SMT file *eq* and classify its output.

    A copy of *eq* without the lines containing ``(get-model)`` is written to
    a temporary directory and passed to the solver. The temporary directory
    is removed on every exit path.

    Args:
        eq: Path of the input file.
        timeout: Limit given to ``subprocess.check_output``; it is also
            passed to the solver as ``-timeout=<timeout>000``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the output
        contains ``unsat``, ``True`` when it contains ``sat`` (everything
        after the first output line is attached as the model) and ``None``
        otherwise. The third argument is ``True`` on timeout. Times are
        reported via ``getTime_ms()`` or as ``timeout * 1000``.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``Ostrich`` executable.
    """
    path = ploc.findProgram("Ostrich")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("Ostrich Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "out.smt")

        # hack to get rid of (get-model), not needed for z3 and returns
        # 1 / Error if input is unsat
        with open(eq, "r") as source, open(smtfile, "w") as copy:
            copy.writelines(line for line in source
                            if "(get-model)" not in line)

        clock = timer.Timer()
        try:
            out = subprocess.check_output([
                path, "-logo", "-length=on", "+quiet", "-inputFormat=smtlib",
                "+model", "-timeout=" + str(timeout) + "000", smtfile
            ], timeout=timeout).decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout * 1000, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            out = "Error in " + eq + ": " + str(e)
            return utils.Result(None, clock.getTime_ms(), False, 1, out)
        clock.stop()
    finally:
        # Previously the directory leaked on the timeout and error returns.
        shutil.rmtree(tempd, ignore_errors=True)

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime_ms(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime_ms(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, timeout * 1000, True, 1)
    if "unknown" in out:
        return utils.Result(None, clock.getTime_ms(), False, 1, out)
    # must be an error
    return utils.Result(None, clock.getTime_ms(), False, 1,
                        f"Error in {eq} # stdout: {out}")
예제 #9
0
def runSpecific(tup):
    """Run one solver job in a fresh temporary directory and verify its result.

    Args:
        tup: ``(solvername, func, model, timeout, ploc, verifiers)``.
            ``func`` is called as ``func(model.filepath, timeout, ploc, tempd)``.

    Returns:
        The object returned by ``func``, replaced by the return value of
        ``Verifier.verifyModel`` when its ``result`` attribute equals
        ``True``. If anything raises, the error is printed and
        ``utils.Result(None, timeout, True, 0, <message>)`` is returned.
    """
    # Unpacked outside the ``try`` so the handler can always rely on
    # ``timeout`` being bound; a malformed job tuple raises immediately.
    solvername, func, model, timeout, ploc, verifiers = tup

    # ``None`` until the directory exists: the original handler removed
    # ``tempd`` unconditionally and failed with an unbound name whenever the
    # error happened before ``mkdtemp``.
    tempd = None
    try:
        progressMessage(model, solvername)
        tempd = tempfile.mkdtemp()
        result = func(model.filepath, timeout, ploc, tempd)

        # verification goes here
        v = verification.verifier.Verifier()
        if result.result == True:
            result = v.verifyModel(result, ploc, model.filepath, timeout,
                                   verifiers)
        return result
    except Exception as e:
        print(str(e))
        return utils.Result(None, timeout, True, 0, str(e))
    finally:
        if tempd is not None:
            shutil.rmtree(tempd, ignore_errors=True)
예제 #10
0
def runSpecific(tup):
    """Run one solver job in a fresh temporary directory and verify its result.

    Args:
        tup: ``(solvername, func, model, timeout, ploc, verifiers)``.
            ``func`` is called as ``func(model.filepath, timeout, ploc, tempd)``.

    Returns:
        The object returned by ``func``, replaced by the return value of
        ``Verifier.verifyModel`` when its ``result`` attribute equals
        ``True``. If anything raises, the error is logged with its traceback
        and ``utils.Result(None, timeout, True, 0, <message>)`` is returned.
    """
    logging.debug("Running job %s", tup)
    # Unpacked outside the ``try`` so the handler can always rely on
    # ``timeout`` being bound; a malformed job tuple raises immediately.
    solvername, func, model, timeout, ploc, verifiers = tup

    # ``None`` until the directory exists: the original handler removed
    # ``tempd`` unconditionally and failed with an unbound name whenever the
    # error happened before ``mkdtemp``.
    tempd = None
    try:
        progressMessage(model, solvername)
        tempd = tempfile.mkdtemp()
        result = func(model.filepath, timeout, ploc, tempd)

        # verification goes here
        v = verification.verifier.Verifier()
        if result.result == True:
            result = v.verifyModel(result, ploc, model.filepath, timeout,
                                   verifiers)
        return result
    except Exception as e:
        # The message needs a %s placeholder; passing ``e`` without one makes
        # the logging module report a formatting error instead of the message.
        logger.exception("Error while running experiments: %s", e)
        return utils.Result(None, timeout, True, 0, str(e))
    finally:
        if tempd is not None:
            shutil.rmtree(tempd, ignore_errors=True)
예제 #11
0
파일: tester.py 프로젝트: domoritz/asprin
 def run(self, dir, options):
     """Run every ``.lp`` test found under *dir*, recursing into subdirectories.

     Each ``.lp`` file is parsed into a ``utils.Test``; its ``command`` is
     executed through the shell with *dir* as working directory, and the
     captured stdout/stderr is compared against the test via
     ``utils.Result.compare``.

     Returns:
         ``True`` if any comparison (or any recursive call) returned a
         truthy value, ``False`` otherwise.
     """
     import tempfile

     errors = False
     for i in sorted(os.listdir(dir)):
         abs_i = os.path.join(dir, i)
         error = False
         if os.path.isdir(abs_i):
             error = self.run(abs_i, options)
         elif str(abs_i).endswith(".lp"):
             with open(abs_i, 'r') as f:
                 test = utils.Test(f.read(), options)
             print("Testing {}...".format(abs_i))
             # The temporary file is now closed deterministically; it used to
             # stay open until garbage collection.
             with tempfile.TemporaryFile() as tmp:
                 with cd(dir):
                     # NOTE(review): the command string comes from the test
                     # file and is executed with shell=True.
                     subprocess.call(test.command,
                                     stdout=tmp,
                                     stderr=subprocess.STDOUT,
                                     shell=True)
                     tmp.seek(0)
                     result = utils.Result(tmp.read())
             error = result.compare(test)
         if error:
             errors = True
     return errors
예제 #12
0
def run(eqfile, timeout, ploc, wd):
    """Run ``woorpjeSMT`` on *eqfile* and wrap the outcome in a Result.

    Args:
        eqfile: Input file, passed as the last command-line argument.
        timeout: Limit given to ``subprocess.check_output``; also reported as
            the elapsed time when the limit is hit.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Directory in which the solver is told to write ``smtmodel.smt``.

    Returns:
        A ``utils.Result``. Its first argument is ``True`` when the process
        exits with status 0 (the contents of ``smtmodel.smt`` are attached),
        ``False`` for exit status 1 and ``None`` for every other status or on
        timeout. Exit statuses 134 and 255 are reported with a time of 0.
        The third argument is ``True`` only on timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``woorpjeSMT`` executable.
    """
    tool = ploc.findProgram("woorpjeSMT")
    if not tool:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("woorpje Not in Path")

    smtmodel = os.path.join(wd, "smtmodel.smt")
    clock = timer.Timer()
    try:
        out = subprocess.check_output([
            tool, '--smtmodel', smtmodel, '--solver', '1', '-S', '1',
            '--smttimeout', '10', eqfile
        ], timeout=timeout)
    except subprocess.CalledProcessError as ex:
        clock.stop()
        # CalledProcessError is only raised for non-zero exit statuses, so
        # the former ``returncode == 0`` branch (which also referenced the
        # undefined name ``sfile``) could never run and has been removed.
        if ex.returncode == 1:
            return utils.Result(False, clock.getTime(), False, 0, ex.output)
        if ex.returncode == 134 or ex.returncode == 255:
            return utils.Result(None, 0, False, 0, ex.output)
        # Exit codes 10 and 20, like any other status, give no verdict.
        return utils.Result(None, clock.getTime(), False, 0, ex.output)
    except subprocess.TimeoutExpired:
        return utils.Result(None, timeout, True, 0)

    clock.stop()
    with open(smtmodel) as f:
        model = f.read()
    return utils.Result(True, clock.getTime(), False, 0, out, model)
예제 #13
0
def run(eq, timeout, ploc, wd):
    """Placeholder runner: ignores every argument and returns a fixed Result."""
    fixed_args = (None, 0, False, 1, "")
    return utils.Result(*fixed_args)
예제 #14
0
def run(eq, timeout, ploc, wd, solver="1", param="60"):
    """Run the ``Z3Bin`` binary (z3str3 string solver) on *eq* and classify its output.

    A normalised copy of *eq* is written to a temporary directory: every
    ``(get-model)``, ``(check-sat)``, ``(exit)`` and ``(set-info :status ...)``
    command is stripped, any ``(set-logic ...)`` is rewritten to
    ``(set-logic QF_S)`` (one is prepended when no uncommented line has it)
    and a single ``(check-sat)`` is appended. The temporary directory is
    removed on every exit path.

    Args:
        eq: Path of the input file.
        timeout: Limit; ``int(timeout)`` is given to
            ``subprocess.check_output`` and the raw value is compared
            against ``Timer.getTime()``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.
        solver: Unused by this runner.
        param: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the output
        contains ``unsat``, ``True`` when it contains ``sat`` (everything
        after the first output line is attached as the model) and ``None``
        otherwise. The third argument is ``True`` on timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``Z3Bin`` executable.
    """
    path = ploc.findProgram("Z3Bin")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("Z3 Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "z3_out.smt")

        with open(eq) as source:
            lines = source.readlines()

        # Is a logic declared on any line that does not start with ';'?
        setLogicPresent = any(
            not l.startswith(";") and '(set-logic' in l for l in lines)

        with open(smtfile, "w") as copy:
            if not setLogicPresent:
                copy.write("(set-logic QF_S)\n")

            for l in lines:
                # Plain-text removal; the original used non-raw regex
                # strings ("\(") that are invalid escape sequences.
                for command in ("(get-model)", "(check-sat)", "(exit)",
                                "(set-info :status sat)",
                                "(set-info :status unsat)"):
                    l = l.replace(command, '')

                if "(set-logic" in l:
                    l = re.sub(r'\(set-logic.*?\)', '(set-logic QF_S)', l)
                copy.write(l)

            copy.write("\n(check-sat)")

        clock = timer.Timer()
        try:
            out = subprocess.check_output(
                [path, "smt.string_solver=z3str3", "dump_models=true",
                 "smt.arith.solver=2", smtfile],
                timeout=int(timeout)).decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout * 1000, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            if clock.getTime() >= timeout:
                return utils.Result(None, clock.getTime_ms(), True, 1)
            out = "Error in " + eq + ": " + str(e.output)
            return utils.Result(None, clock.getTime_ms(), False, 1, out)
        clock.stop()
    finally:
        # Previously only the CalledProcessError path removed the directory.
        shutil.rmtree(tempd, ignore_errors=True)

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime_ms(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime_ms(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, clock.getTime_ms(), True, 1)
    if "unknown" in out:
        return utils.Result(None, clock.getTime_ms(), False, 1, out)
    # must be an error
    return utils.Result(None, clock.getTime_ms(), False, 1,
                        f"Error in {eq} # stdout: {out}")
0
def run(eq, timeout, ploc, wd, solver="1", param="60"):
    """Run ``cvc4`` on the SMT file *eq* and classify its output.

    A normalised copy of *eq* is written to a temporary directory: every
    ``(get-model)``, ``(check-sat)``, ``(exit)`` and ``(set-info :status ...)``
    command is stripped, any ``(set-logic ...)`` is rewritten to
    ``(set-logic QF_SLIA)`` (one is prepended when no uncommented line has
    it) and a single ``(check-sat)`` is appended. The temporary directory is
    removed on every exit path.

    Args:
        eq: Path of the input file.
        timeout: Limit; ``int(timeout)`` is given to
            ``subprocess.check_output``, ``<timeout>000`` is passed after
            ``--tlimit-per`` and the raw value is compared against
            ``Timer.getTime()``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.
        solver: Unused by this runner.
        param: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the output
        contains ``unsat``, ``True`` when it contains ``sat`` (everything
        after the first output line is attached as the model) and ``None``
        otherwise. The third argument is ``True`` on timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``cvc4`` executable.
    """
    path = ploc.findProgram("cvc4")
    if not path:
        # The original tried to raise a plain string (a TypeError in
        # Python 3) whose text named Z3 although cvc4 is the missing program.
        raise RuntimeError("cvc4 Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "cvc4_out.smt")

        with open(eq) as source:
            lines = source.readlines()

        # Is a logic declared on any line that does not start with ';'?
        setLogicPresent = any(
            not l.startswith(";") and '(set-logic' in l for l in lines)

        with open(smtfile, "w") as copy:
            if not setLogicPresent:
                copy.write("(set-logic QF_SLIA)\n")

            for l in lines:
                # Plain-text removal; the original used non-raw regex
                # strings ("\(") that are invalid escape sequences.
                for command in ("(get-model)", "(check-sat)", "(exit)",
                                "(set-info :status sat)",
                                "(set-info :status unsat)"):
                    l = l.replace(command, '')

                if "(set-logic" in l:
                    l = re.sub(r'\(set-logic.*?\)', '(set-logic QF_SLIA)', l)
                copy.write(l)

            copy.write("\n(check-sat)")

        clock = timer.Timer()
        try:
            out = subprocess.check_output([
                path, "--lang", "smtlib2.5", "--no-interactive",
                "--no-interactive-prompt", "--strings-exp", "--dump-models",
                "--tlimit-per",
                str(timeout) + "000", smtfile
            ], timeout=int(timeout)).decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout * 1000, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            if clock.getTime() >= timeout:
                return utils.Result(None, clock.getTime_ms(), True, 1)
            out = "Error in " + eq + ": " + str(e.output)
            return utils.Result(None, clock.getTime_ms(), False, 1, out)
        clock.stop()
    finally:
        # Previously the directory leaked on the timeout and error returns.
        shutil.rmtree(tempd, ignore_errors=True)

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime_ms(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime_ms(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, clock.getTime_ms(), True, 1)
    if "unknown" in out:
        return utils.Result(None, clock.getTime_ms(), False, 1, out)
    # must be an error
    return utils.Result(None, clock.getTime_ms(), False, 1,
                        f"Error in {eq} # stdout: {out}")
예제 #16
0
def run(eq, timeout, ploc, wd):
    """Run ``Z3`` (seq string solver) on the SMT file *eq* and classify its output.

    A copy of *eq* with every ``(get-model)``, ``(check-sat)``, ``(exit)``
    and ``(set-info :status ...)`` command stripped and a single
    ``(check-sat)`` appended is written to a temporary directory and passed
    to the solver. The temporary directory is removed on every exit path.

    Args:
        eq: Path of the input file.
        timeout: Limit given to ``subprocess.check_output`` and compared
            against ``Timer.getTime()``.
        ploc: Object whose ``findProgram(name)`` returns the executable to
            run, or a falsy value when it cannot be found.
        wd: Unused by this runner.

    Returns:
        A ``utils.Result``. Its first argument is ``False`` when the output
        contains ``unsat``, ``True`` when it contains ``sat`` (everything
        after the first output line is attached as the model) and ``None``
        otherwise. The third argument is ``True`` on timeout.

    Raises:
        RuntimeError: If ``ploc`` cannot locate the ``Z3`` executable.
    """
    path = ploc.findProgram("Z3")
    if not path:
        # Raising a plain string is a TypeError in Python 3.
        raise RuntimeError("Z3 Not in Path")

    tempd = tempfile.mkdtemp()
    try:
        smtfile = os.path.join(tempd, "z3seq_out.smt")

        # The original also scanned the file for "(set-logic" but never used
        # the answer; that pass has been removed.
        with open(eq, "r") as source, open(smtfile, "w") as copy:
            for l in source:
                # Plain-text removal; the original used non-raw regex
                # strings ("\(") that are invalid escape sequences.
                for command in ("(get-model)", "(check-sat)", "(exit)",
                                "(set-info :status sat)",
                                "(set-info :status unsat)"):
                    l = l.replace(command, '')
                copy.write(l)

            copy.write("\n(check-sat)")

        clock = timer.Timer()
        try:
            out = subprocess.check_output(
                [path, "smt.string_solver=seq", "dump_models=true", smtfile],
                timeout=timeout)
            out = out.decode().strip()
        except subprocess.TimeoutExpired:
            return utils.Result(None, timeout * 1000, True, 1)
        except subprocess.CalledProcessError as e:
            clock.stop()
            out = "Error in " + eq + ": " + str(e.output)
            return utils.Result(None, clock.getTime_ms(), False, 1, out)
        clock.stop()
    finally:
        # Previously the directory leaked on the timeout and error returns.
        shutil.rmtree(tempd, ignore_errors=True)

    if "NOT IMPLEMENTED YET!" in out and not clock.getTime() >= timeout:
        out = "Error in " + eq + ": " + out

    # "unsat" must be tested first because it contains "sat".
    if "unsat" in out:
        return utils.Result(False, clock.getTime_ms(), False, 1, out)
    if "sat" in out:
        return utils.Result(True, clock.getTime_ms(), False, 1, out,
                            "\n".join(out.split("\n")[1:]))
    if clock.getTime() >= timeout:
        return utils.Result(None, clock.getTime_ms(), True, 1)

    return utils.Result(None, clock.getTime_ms(), False, 1, out)
def validate(val_loader, models, logr, epoch, args, device):
    """Evaluate the four-model pipeline on *val_loader* and log the averages.

    ``models`` unpacks into ``(as_model, bprox, mdn, depthcomp)``. For every
    batch the ``mdn`` output is stored as ``'mdi'``, ``as_model`` produces
    the value stored as ``'d'``, ``bprox`` produces ``'bproxi'`` and
    ``depthcomp`` produces the prediction that is evaluated. ``writer`` is
    not a parameter; it is read from the enclosing scope.

    Returns:
        ``(avg, is_best)`` where ``avg`` is ``average_meter.average()`` and
        ``is_best`` is the value returned by ``logr.rank_save_best``.
    """
    as_model, bprox, mdn, depthcomp = models
    as_model.eval()
    block_meter = utils.AverageMeter()
    overall_meter = utils.AverageMeter()

    if args.ret_samples:
        samps_tot = 0

    for step, raw_batch in enumerate(val_loader):
        # Move every non-None entry of the batch to the target device.
        batch_data = {
            name: tensor.to(device)
            for name, tensor in raw_batch.items() if tensor is not None
        }

        # The mdn output is normalised with utils.DepthNorm, clamped, divided
        # by 1000 and multiplied by a per-dataset factor (85 for kitti, 10
        # for nyu_v2). The same factor is used below to scale the bprox
        # inputs to the 1000 range and its output back.
        if args.dataset == 'kitti':
            depth_scale = 85
            gt_key = 'gt_sparse'
            with torch.no_grad():
                upsampled = F.interpolate(
                    utils.DepthNorm(mdn(batch_data['rgb'])), scale_factor=2)
                mdi = torch.clamp(upsampled, 2, 1000) / 1000 * depth_scale
        elif args.dataset == 'nyu_v2':
            depth_scale = 10
            gt_key = 'gt'
            with torch.no_grad():
                mdi = torch.clamp(utils.DepthNorm(mdn(batch_data['rgb'])), 10,
                                  1000) / 1000 * depth_scale
        else:
            print("invalid dataset")
            exit()

        batch_data['mdi'] = mdi

        # as_model returns an extra fourth value when args.ret_samples is set.
        sampler_out = as_model(batch_data)
        if args.ret_samples:
            pred_sparse_depth, vector_field, grid_to_sample, samps = sampler_out
            samps_tot += samps
        else:
            pred_sparse_depth, vector_field, grid_to_sample = sampler_out

        batch_data['d'] = pred_sparse_depth

        proxy_input = torch.cat([
            mdi / depth_scale * 1000, pred_sparse_depth / depth_scale * 1000
        ], dim=1)
        inpainted = bprox(proxy_input) / 1000 * depth_scale

        batch_data['bproxi'] = inpainted
        pred = depthcomp(batch_data)

        # nyu_v2 is scored against 'gt', kitti against 'gt_sparse'.
        result = utils.Result()
        result.evaluate(pred, batch_data[gt_key])
        block_meter.update(result)
        overall_meter.update(result)

        # Report progress every 200 batches.
        if (step + 1) % 200 == 0:
            logr.print(step, epoch, args.lr, len(val_loader), block_meter,
                       overall_meter)

        logr.conditional_save_img_comparison(step, batch_data, pred, epoch,
                                             args.dataset)

    avg = overall_meter.average()
    is_best = logr.rank_save_best(avg, epoch)
    logr.summarize(avg, is_best)

    writer.add_scalar('data/val_loss_rmse', avg.rmse, epoch)

    if args.ret_samples:
        print("AVERAGE SAMPLES:")
        print(samps_tot / len(val_loader))

    return avg, is_best
def train_epoch(train_loader, models, optimizer, logr, epoch, args, device):
    """Run one training pass of the four-model pipeline over *train_loader*.

    ``models`` unpacks into ``(as_model, bprox, mdn, depthcomp)``; only
    ``as_model`` is switched to training mode here. For every batch the
    ``mdn`` output is stored as ``'mdi'``, ``as_model`` produces the value
    stored as ``'d'``, ``bprox`` produces ``'bproxi'`` and ``depthcomp``
    produces the prediction fed to ``utils.adaptive_loss``. ``writer`` is not
    a parameter; it is read from the enclosing scope.
    """
    as_model, bprox, mdn, depthcomp = models
    as_model.train()
    block_meter = utils.AverageMeter()
    epoch_meter = utils.AverageMeter()

    for step, raw_batch in enumerate(train_loader):
        # Move every non-None entry of the batch to the target device.
        batch_data = {
            name: tensor.to(device)
            for name, tensor in raw_batch.items() if tensor is not None
        }

        # The mdn output is normalised with utils.DepthNorm, clamped, divided
        # by 1000 and multiplied by a per-dataset factor (85 for kitti, 10
        # for nyu_v2). Only the nyu_v2 branch runs mdn under torch.no_grad().
        if args.dataset == 'kitti':
            depth_scale = 85
            upsampled = F.interpolate(utils.DepthNorm(mdn(batch_data['rgb'])),
                                      scale_factor=2)
            mdi = torch.clamp(upsampled, 2, 1000) / 1000 * depth_scale
        elif args.dataset == 'nyu_v2':
            depth_scale = 10
            with torch.no_grad():
                mdi = torch.clamp(utils.DepthNorm(mdn(batch_data['rgb'])), 10,
                                  1000) / 1000 * depth_scale
        else:
            print("invalid dataset")
            exit()

        batch_data['mdi'] = mdi

        # predict sampling locations
        pred_sparse_depth, vector_field, grid_to_sample = as_model(batch_data)
        batch_data['d'] = pred_sparse_depth

        # bprox is fed inputs scaled to the 1000 range; its output is scaled
        # back with the same per-dataset factor.
        proxy_input = torch.cat([
            mdi / depth_scale * 1000, pred_sparse_depth / depth_scale * 1000
        ], dim=1)
        inpainted = bprox(proxy_input) / 1000 * depth_scale

        batch_data['bproxi'] = inpainted
        pred = depthcomp(batch_data)

        loss = utils.adaptive_loss(pred, vector_field, grid_to_sample,
                                   batch_data, args.dataset, args.grid_reg,
                                   args.image_reg)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('data/train_loss_full',
                          loss.cpu().data.numpy(),
                          epoch * len(train_loader) + step)

        result = utils.Result()
        result.evaluate(pred, batch_data['gt'])
        block_meter.update(result)
        epoch_meter.update(result)

        # Report progress every 20 batches.
        if (step + 1) % 20 == 0:
            logr.print(step, epoch, args.lr, len(train_loader), block_meter,
                       epoch_meter)

        writer.add_scalar('data/train_loss_rmse', result.rmse,
                          epoch * len(train_loader) + step)
예제 #19
0
def validate(val_loader, models, logr, epoch):
    """Evaluate the first model in *models* on *val_loader* and log the averages.

    ``models`` unpacks into ``(model, bprox, mdn)``. For every batch the
    output of ``mdn`` and the result of ``bprox`` are added to the batch as
    ``'mdi'`` and ``'bproxi'`` before ``model`` is called; the prediction is
    evaluated against ``batch_data['gt']``. ``args`` and ``device`` are not
    parameters; they are read from the enclosing scope.

    Returns:
        ``(avg, is_best)`` where ``avg`` is ``average_meter.average()`` and
        ``is_best`` is the value returned by ``logr.rank_save_best``.
    """
    net, bprox, mdn = models
    net = net.eval()
    block_meter = utils.AverageMeter()
    overall_meter = utils.AverageMeter()

    for step, raw_batch in enumerate(val_loader):
        # Move every non-None entry of the batch to the target device.
        batch_data = {
            name: tensor.to(device)
            for name, tensor in raw_batch.items() if tensor is not None
        }

        # The mdn output is normalised with utils.DepthNorm, clamped, divided
        # by 1000 and multiplied by a per-dataset factor (85 for kitti, 10
        # for nyu_v2). bprox is fed inputs scaled back by that factor to the
        # 1000 range and its output is rescaled the same way.
        if args.dataset == 'kitti':
            depth_scale = 85
            upsampled = F.interpolate(utils.DepthNorm(mdn(batch_data['rgb'])),
                                      scale_factor=2)
            mdi = torch.clamp(upsampled, 2, 1000) / 1000 * depth_scale
        elif args.dataset == 'nyu_v2':
            depth_scale = 10
            mdi = torch.clamp(utils.DepthNorm(mdn(batch_data['rgb'])), 10,
                              1000) / 1000 * depth_scale
        else:
            print("invalid dataset")
            exit()

        proxy_input = torch.cat(
            [mdi / depth_scale * 1000, batch_data['d'] / depth_scale * 1000],
            dim=1)
        inpainted = bprox(proxy_input) / 1000 * depth_scale

        # add bilateral proxy and monocular depth estimate to the data to
        # feed to the inpainting model
        batch_data['bproxi'] = inpainted
        batch_data['mdi'] = mdi

        pred = net(batch_data)

        result = utils.Result()
        result.evaluate(pred, batch_data['gt'])
        block_meter.update(result)
        overall_meter.update(result)

        # Report progress every 20 batches.
        if (step + 1) % 20 == 0:
            logr.print(step, epoch, args.lr, len(val_loader), block_meter,
                       overall_meter)

        logr.conditional_save_img_comparison(step, batch_data, pred, epoch,
                                             args.dataset)

    avg = overall_meter.average()
    is_best = logr.rank_save_best(avg, epoch)
    logr.summarize(avg, is_best)

    return avg, is_best