示例#1
0
def predAnalysis(args):
    """Restore the question-attention model and run prediction on the val/test split.

    Builds an ``Attention_GPUConfig`` from ``args``, creates a non-training
    input processor over the config's test annotation / question / image
    files, restores the checkpoint named by ``config.restoreModel`` and
    ``config.restoreModelPath``, and calls ``runPredict`` with the fixed
    output name ``Pred_QnAtt47.9.csv``.  The model and the reader are both
    torn down afterwards.

    Args:
        args: Passed straight through to ``Attention_GPUConfig``.

    Returns:
        None. The result of ``runPredict`` is discarded.
    """
    print('Running Val Test')
    output_csv = 'Pred_QnAtt47.9.csv'

    cfg = Attention_GPUConfig(load=True, args=args)
    data_reader = AttModelInputProcessor(
        cfg.testAnnotFile,
        cfg.rawQnValTestFile,
        cfg.valImgFile,
        cfg,
        is_training=False,
    )

    net = QnAttentionModel(cfg)
    net.loadTrainedModel(cfg.restoreModel, cfg.restoreModelPath)
    net.runPredict(data_reader, output_csv)

    # Tear down in the same order as construction was reversed: model, then reader.
    net.destruct()
    data_reader.destruct()
示例#2
0
def visQnImgAtt():
    """Run the sigmoid question-attention model on one batch and render attention.

    Restores the checkpoint identified by ``config.restoreQuAttSigmoidModel``
    / ``config.restoreQuAttSigmoidModelPath``, runs ``runPredict`` in ``mini``
    mode, and hands the seven returned sequences to
    ``OutputGenerator.displayQnImgAttSaveSplit``.

    NOTE(review): ``args`` is not a parameter and is not assigned here, so it
    is resolved from the enclosing (module) scope -- confirm it is defined
    before this function is called.
    NOTE(review): another ``def visQnImgAtt`` appears later in this file; if
    both live in the same module the later one rebinds the name.

    Returns:
        None.
    """
    print('Running qn Visuals')
    cfg = Attention_LapConfig(load=True, args=args)

    # Validation/test split.  The author's alternative (left disabled) used
    # cfg.trainAnnotFile, cfg.rawQnTrain and cfg.trainImgFile instead.
    data_reader = AttModelInputProcessor(
        cfg.testAnnotFile,
        cfg.rawQnValTestFile,
        cfg.valImgFile,
        cfg,
        is_training=False,
    )

    net = QnAttentionModel(cfg)
    save_flag = True

    # Disabled alternative checkpoint: cfg.restoreQnImAttModel /
    # cfg.restoreQnImAttModelPath.
    net.loadTrainedModel(
        cfg.restoreQuAttSigmoidModel,
        cfg.restoreQuAttSigmoidModelPath,
    )

    # The meaning of 200 and chooseBatch=30 is defined by runPredict, which is
    # not visible here -- presumably a sample count and a batch index.
    (qn_weights, img_weights, image_ids, questions,
     predictions, top_k, labels) = net.runPredict(
        data_reader, cfg.csvResults, 200, mini=True, chooseBatch=30)
    net.destruct()

    # Disabled alternatives: OutputGenerator(cfg.trainImgPaths) and
    # renderer.displayQnImgAttention(...) with the same arguments.
    renderer = OutputGenerator(cfg.valImgPaths)
    renderer.displayQnImgAttSaveSplit(
        qn_weights, img_weights, image_ids, questions,
        predictions, top_k, labels, save_flag)
def visQnImgAtt():
    """Run the question+image attention model in mini mode and display attention.

    Restores the checkpoint identified by ``config.restoreQnImAttModel`` /
    ``config.restoreQnImAttModelPath``, calls ``runPredict`` with ``mini=True``
    and passes the six returned sequences to
    ``OutputGenerator.displayQnImgAttention``.

    NOTE(review): ``args`` is not a parameter and is not assigned here, so it
    is looked up in the enclosing (module) scope -- confirm it exists there.

    Returns:
        None.
    """
    print('Running qn Visuals')
    cfg = Attention_LapConfig(load=True, args=args)
    data_reader = AttModelInputProcessor(
        cfg.testAnnotFile,
        cfg.rawQnValTestFile,
        cfg.valImgFile,
        cfg,
        is_training=False,
    )

    net = QnAttentionModel(cfg)
    net.loadTrainedModel(cfg.restoreQnImAttModel, cfg.restoreQnImAttModelPath)

    # The third positional argument (5) is interpreted by runPredict, which is
    # not visible here -- presumably the number of examples to process.
    (qn_weights, img_weights, image_ids,
     questions, predictions, top_k) = net.runPredict(
        data_reader, cfg.csvResults, 5, mini=True)
    net.destruct()

    renderer = OutputGenerator(cfg.valImgPaths)
    renderer.displayQnImgAttention(
        qn_weights, img_weights, image_ids, questions, predictions, top_k)
示例#4
0
def validateInternalTestSet(args, model=None, restoreModelPath=None):
    """Predict on the internal validation/test split and score with the VQA tools.

    Builds an ``Attention_GPUConfig`` from ``args``, runs ``model.runPredict``
    over a non-training reader, evaluates the results with ``VQAEval`` and
    passes everything to ``writeToFile``.  The model and the reader are
    destructed after prediction, including a model supplied by the caller.

    Args:
        args: Parsed options.  ``args.att`` is read to pick the model class
            (``'qn'`` -> ``QnAttentionModel``, ``'im'`` ->
            ``ImageAttentionModel``) when ``model`` is None, and is always
            embedded in the prediction file name.
        model: An already-built model to evaluate.  When None, a model is
            constructed and restored from a checkpoint.
        restoreModelPath: Checkpoint path; also used as the prefix of the
            prediction CSV name and forwarded to ``writeToFile``.  Defaults
            to ``config.restoreModelPath``.

    Returns:
        None.

    Raises:
        ValueError: If ``model`` is None and ``args.att`` is not ``'qn'`` or
            ``'im'``.
    """
    from vqaTools.vqaInternal import VQA
    from vqaTools.vqaEval import VQAEval

    # Fail fast, before any reader/model is built, instead of hitting an
    # AttributeError on ``None.loadTrainedModel`` further down.
    if model is None and args.att not in ('qn', 'im'):
        raise ValueError(
            "Unsupported attention type {!r}; expected 'qn' or 'im'".format(
                args.att))

    #config = Attention_LapConfig(load=True, args)
    config = Attention_GPUConfig(load=True, args=args)

    print('Running Validation Test on Model')
    valTestReader = AttModelInputProcessor(config.testAnnotFile,
                                 config.rawQnValTestFile,
                                 config.valImgFile,
                                 config,
                                 is_training=False)
    if restoreModelPath is None:
        restoreModelPath = config.restoreModelPath

    if model is None:
        if args.att == 'qn':
            print('Attention over question and image model')
            model = QnAttentionModel(config)
        else:
            print('Attention over image model')
            model = ImageAttentionModel(config)
        # Previously ``restoreModel`` was only bound when restoreModelPath was
        # None, so passing an explicit path with no model raised
        # UnboundLocalError here.  The configured value is used in both cases.
        # NOTE(review): assumes config.restoreModel is also valid for a
        # caller-supplied restoreModelPath -- confirm against loadTrainedModel.
        model.loadTrainedModel(config.restoreModel, restoreModelPath)

    predFile = '{}PredsAtt{}.csv'.format(restoreModelPath, args.att)
    results, strictAcc = model.runPredict(valTestReader, predFile)
    model.destruct()
    valTestReader.destruct()
    print('predictions made')

    vqa = VQA(config.testAnnotFileUnresolved, config.originalValQns)
    vqaRes = vqa.loadRes(results, config.originalValQns)
    vqaEval = VQAEval(vqa, vqaRes, n=2)
    vqaEval.evaluate()

    print('Writing to file..')
    writeToFile(vqaEval, restoreModelPath, vqa, vqaRes, args, strictAcc)
    print('Internal test complete')