logging.info("Running evaluation!") numRounds = params['numRounds'] if 'ckpt_iterid' in params: iterId = params['ckpt_iterid'] + 1 else: iterId = -1 for split in splits: if split == 'train': splitName = 'full train - {}'.format(params['evalTitle']) if split == 'val': splitName = 'full Val - {}'.format(params['evalTitle']) if split == 'test': splitName = 'test - {}'.format(params['evalTitle']) logging.info("Using split %s" % split) dataset.split = split # if params['evalModeList'] == 'ABotRank': if 'ABotRank' in params['evalModeList']: #print("Performing ABotRank evaluation") logging.info( "Performing ABotRank evaluation on split {}".format(split)) if params['qaCategory'] and params['categoryMap']: logging.info( "Evaluating only on rounds in the category \"{}\"".format( params['qaCategory'])) rankMetrics_category = rankABot_category_specific( aBot, dataset, split, params['qaCategory'],
# parameters.extend(qBot.parameters())
# qBot = nn.DataParallel(qBot)

# Load the AQM questioner ("AQM-Bot") only for the AQM training modes.
if params['trainMode'] in ['aqmbot-ind', 'aqmbot-dep']:
    aqmBot, loadedParams, optim_state = utils.loadModel(params, 'AQM-qbot')
    # Parameters stored in the checkpoint override the current ones.
    params.update(loadedParams)
    # Hand only the trainable (requires_grad) tensors to the optimizer list.
    parameters.extend(
        weight for weight in aqmBot.parameters() if weight.requires_grad)

# Training dataloader over the 'train' split; the last incomplete batch is
# dropped so every batch has exactly params['batchSize'] dialogs.
dataset.split = 'train'
dataloader = DataLoader(
    dataset,
    batch_size=params['batchSize'],
    shuffle=False,
    num_workers=params['numWorkers'],
    drop_last=True,
    collate_fn=dataset.collate_fn,
    pin_memory=False,
)

# Visdom environment used for plotting; disabled unless enableVisdom is set.
viz = VisdomVisualize(
    enable=bool(params['enableVisdom']),
    env_name=params['visdomEnv'],
    server=params['visdomServer'],
    port=params['visdomServerPort'],
)
pprint.pprint(params)
viz.addText(pprint.pformat(params, indent=4))
def _restoreExcludedParams(params, dlparams, excludeParams):
    """Reset every key in ``excludeParams`` to its value from ``dlparams``.

    Loading a checkpoint copies the checkpoint's ``loadedParams`` into
    ``params``. The keys in ``excludeParams`` must keep the values given for
    this evaluation run (snapshotted in ``dlparams``), so they are restored
    after each load.
    """
    for key in excludeParams:
        params[key] = dlparams[key]


def _buildAQMRunner(aqmBot, aBot, dataset, split, params, aqmSetting):
    """Return an ``AQMRunner`` configured from the evaluation ``params``.

    Shared by the 'AQMBotRank' and 'AQMdialog' modes so both run AQM with
    identical settings.

    Args:
        aqmBot: AQM questioner agent.
        aBot: answerer agent passed through to ``AQMRunner``.
        dataset: dataset to run on (the caller sets ``dataset.split``).
        split: name of the split to run on.
        params: evaluation options; all AQM hyper-parameters are read here.
        aqmSetting: result of ``getAQMSetting(params)``, or None.
    """
    return AQMRunner(
        aqmBot, aBot, dataset, split,
        beamSize=params['beamSize'],
        realQA=params['aqmRealQA'],
        saveLogs=params['saveLogs'],
        showQA=params['showQA'],
        expLowerLimit=params['expLowerLimit'],
        expUpperLimit=params['expUpperLimit'],
        selectedBatchIdxs=params['selectedBatchIdxs'],
        numRounds=params['runRounds'],
        lda=params['lambda'],
        onlyGuesser=params['onlyGuesser'],
        numQ=params['numQ'],
        qbeamSize=params['qbeamSize'],
        numImg=params['numImg'],
        alpha=params['alpha'],
        numA=params['numA'],
        randQ=params['randQ'],
        randA=params['randA'],
        zeroCaption=params['zeroCaption'],
        randomCaption=params['randomCaption'],
        gamma=params['gamma'],
        delta=params['delta'],
        gen1Q=params['gen1Q'],
        gtQ=params['gtQ'],
        noHistory=params['noHistory'],
        slGuesser=params['slGuesser'],
        resampleEveryDialog=params['resampleEveryDialog'],
        aqmSetting=aqmSetting,
    )


def main(params):
    """Load the requested agents and run each mode in ``params['evalModeList']``.

    Modes handled here: 'ABotRank', 'QBotRank', 'QABotsRank', 'AQMBotRank',
    'dialog' and 'AQMdialog'.

    Args:
        params: dict of evaluation options. It is updated in place with the
            dataset's vocabSize/numOptions/numRounds and with parameters
            stored in loaded checkpoints, except for the keys listed in
            ``excludeParams``, which keep their command-line values.
    """
    aqmSetting = None
    aqmModes = ("AQMBotRank", "AQMdialog", "AQMdemo")
    if any(mode in params["evalModeList"] for mode in aqmModes):
        aqmSetting = getAQMSetting(params)

    # setup dataloader
    dlparams = params.copy()
    dlparams['useIm'] = True
    dlparams['useHistory'] = True
    dlparams['numRounds'] = 10
    splits = ['val', 'test']

    dataset = VisDialDataset(dlparams, splits)

    # Transferring dataset parameters
    transfer = ['vocabSize', 'numOptions', 'numRounds']
    for key in transfer:
        if hasattr(dataset, key):
            params[key] = getattr(dataset, key)

    if 'numRounds' not in params:
        params['numRounds'] = 10

    # Always load checkpoint parameters with continue flag
    params['continue'] = True

    # Options that must keep their command-line values even after a
    # checkpoint's stored parameters have been copied into `params`.
    excludeParams = [
        'batchSize', 'visdomEnv', 'startFrom', 'qstartFrom', 'trainMode',
        'evalModeList', 'inputImg', 'inputQues', 'inputJson', 'evalTitle',
        'beamSize', 'enableVisdom', 'visdomServer', 'visdomServerPort',
        'randomCaption', 'zeroCaption', 'numImg', 'numQ', 'numA', 'alpha',
        'qbeamSize', 'gamma', 'delta', 'lambda', 'onlyGuesser', 'randQ',
        'gen1Q', 'gtQ', 'randA', 'noHistory', 'slGuesser',
        'resampleEveryDialog',
    ]

    aBot = None
    qBot = None
    aqmBot = None

    # load aBot
    print('load aBot')
    if params['startFrom']:
        aBot, loadedParams, _ = utils.loadModel(params, 'abot', overwrite=True)
        assert aBot.encoder.vocabSize == dataset.vocabSize, "Vocab size mismatch!"
        params.update(loadedParams)
        aBot.eval()
    # Retaining certain dataloader parameters
    _restoreExcludedParams(params, dlparams, excludeParams)

    print('load qBot')
    # load qBot (in aqmstartFrom mode qstartFrom is used by the AQM questioner)
    if params['qstartFrom'] and not params['aqmstartFrom']:
        qBot, loadedParams, _ = utils.loadModel(params, 'qbot', overwrite=True)
        assert qBot.encoder.vocabSize == params['vocabSize'], "Vocab size mismatch!"
        params.update(loadedParams)
        qBot.eval()
    # Retaining certain dataloader parameters
    _restoreExcludedParams(params, dlparams, excludeParams)

    print('load AQM-Bot')
    # load aqmBot
    if params['aqmstartFrom']:
        # The AQM questioner below also needs a 'qbot' checkpoint, which is
        # taken from params['qstartFrom'] (cf. the aqmQStartFrom branch).
        assert params['qstartFrom']
        aqmBot, loadedParams, _ = utils.loadModel(params, 'AQM-qbot', overwrite=True)
        assert aqmBot.questioner.encoder.vocabSize == params['vocabSize'], \
            "Vocab size mismatch!"
        params.update(loadedParams)
        aqmBot.eval()
        _restoreExcludedParams(params, dlparams, excludeParams)

        # load qBot used as the AQM questioner
        aqmQ, loadedParams, _ = utils.loadModel(params, 'qbot', overwrite=True)
        assert aqmQ.encoder.vocabSize == params['vocabSize'], "Vocab size mismatch!"
        params.update(loadedParams)
        aqmQ.eval()
        _restoreExcludedParams(params, dlparams, excludeParams)
        aqmBot.setQuestioner(aqmQ)
    elif params['aqmQStartFrom']:
        from visdial.models.aqm_questioner import AQMQuestioner
        aqmBot = AQMQuestioner()
        aqmBot.eval()

        # Temporarily point qstartFrom at the AQM questioner checkpoint;
        # the restore below puts the command-line value back.
        params['qstartFrom'] = params['aqmQStartFrom']
        aqmQ, loadedParams, _ = utils.loadModel(params, 'qbot', overwrite=True)
        assert aqmQ.encoder.vocabSize == params['vocabSize'], "Vocab size mismatch!"
        params.update(loadedParams)
        aqmQ.eval()
        _restoreExcludedParams(params, dlparams, excludeParams)
        aqmBot.setQuestioner(aqmQ)

        # Same trick for the AQM approximate answerer via startFrom.
        params['startFrom'] = params['aqmAStartFrom']
        aqmA, loadedParams, _ = utils.loadModel(params, 'abot', overwrite=True)
        assert aqmA.encoder.vocabSize == dataset.vocabSize, "Vocab size mismatch!"
        params.update(loadedParams)
        aqmA.eval()
        aqmBot.setAppAnswerer(aqmA)
        _restoreExcludedParams(params, dlparams, excludeParams)

    pprint.pprint(params)
    print("Running evaluation!")

    split = 'test' if 'test' in splits else 'val'
    print("Using split %s" % split)
    dataset.split = split

    # NOTE: no Visdom `viz` object exists in main, so the metrics returned by
    # the ranking modes below are not plotted; add plotting here if needed.

    if 'ABotRank' in params['evalModeList']:
        # Use a local so that ranking AQM's approximate answerer does not
        # replace `aBot` for the evaluation modes that follow.
        rankedABot = aBot
        if params['aqmstartFrom']:
            rankedABot = aqmBot.appAnswerer
            print('evaluating appBot of AQM')
        print("Performing ABotRank evaluation")
        rankMetrics = rankABot(
            rankedABot, dataset, split,
            scoringFunction=utils.maskedNll,
            expLowerLimit=params['expLowerLimit'],
            expUpperLimit=params['expUpperLimit'])
        print(rankMetrics)

    if 'QBotRank' in params['evalModeList']:
        print("Performing QBotRank evaluation")
        rankQBot(
            qBot, dataset, split,
            expLowerLimit=params['expLowerLimit'],
            expUpperLimit=params['expUpperLimit'],
            verbose=1)

    if 'QABotsRank' in params['evalModeList']:
        print("Performing QABotsRank evaluation")
        rankQABots(
            qBot, aBot, dataset, split,
            beamSize=params['beamSize'],
            expLowerLimit=params['expLowerLimit'],
            expUpperLimit=params['expUpperLimit'],
            zeroCaption=params['zeroCaption'],
            randomCaption=params['randomCaption'],
            numRounds=params['runRounds'])

    if 'AQMBotRank' in params['evalModeList']:
        print("Performing AQMBotRank evaluation")
        _buildAQMRunner(
            aqmBot, aBot, dataset, split, params, aqmSetting).rankQuestioner()

    if 'dialog' in params['evalModeList']:
        print("Performing dialog generation...")
        split = 'test'
        outputFolder = "dialog_output/results"
        os.makedirs(outputFolder, exist_ok=True)
        outputPath = os.path.join(outputFolder, "results.json")
        dialogDump(
            params, dataset, split,
            aBot=aBot, qBot=qBot,
            expLowerLimit=params['expLowerLimit'],
            expUpperLimit=params['expUpperLimit'],
            beamSize=params['beamSize'],
            savePath=outputPath)

    if 'AQMdialog' in params['evalModeList']:
        print("Performing AQM dialog generation...")
        split = 'test'
        _buildAQMRunner(
            aqmBot, aBot, dataset, split, params, aqmSetting).dialogDump(params)