示例#1
0
def main():
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    # set gpus
    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    data, embeddings, answerDict = preprocessor.preprocessData()
    print("took {} seconds".format(
        bcolored("{:.2f}".format(time.time() - start), "blue")))

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict, raw=config.raw_image)
    print("took {} seconds".format(
        bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver, resnet_saver = \
        savers["saver"], savers["emaSaver"], savers['resnet_saver']

    # sessionConfig
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:

        # ensure no more ops are added after model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize epoch variable
        epoch = loadWeights(sess, saver, init, resnet_saver)

        if config.train:
            start0 = time.time()

            bestEpoch = epoch
            bestRes = None
            prevRes = None

            # epoch in [restored + 1, epochs]
            for epoch in range(config.restoreEpoch + 1, config.epochs + 1):
                print(bcolored("Training epoch {}...".format(epoch), "green"))
                start = time.time()

                # train
                # calle = lambda: model.runEpoch(), collectRuntimeStats, writer
                trainingData, alterData = chooseTrainingData(data)
                trainRes = runEpoch(sess,
                                    model,
                                    trainingData,
                                    train=True,
                                    epoch=epoch,
                                    saver=saver,
                                    alterData=alterData,
                                    raw=config.raw_image)

                # save weights
                saver.save(sess, config.weightsFile(epoch))
                if config.saveSubset:
                    subsetSaver.save(sess, config.subsetWeightsFile(epoch))

                # load EMA weights
                if config.useEMA:
                    print(bold("Restoring EMA weights"))
                    emaSaver.restore(sess, config.weightsFile(epoch))

                # evaluation
                evalRes = runEvaluation(sess,
                                        model,
                                        data["main"],
                                        epoch,
                                        raw=config.raw_image)
                extraEvalRes = runEvaluation(sess,
                                             model,
                                             data["extra"],
                                             epoch,
                                             evalTrain=not config.extraVal,
                                             raw=config.raw_image)

                # restore standard weights
                if config.useEMA:
                    print(bold("Restoring standard weights"))
                    saver.restore(sess, config.weightsFile(epoch))

                print("")

                epochTime = time.time() - start
                print("took {:.2f} seconds".format(epochTime))

                # print results
                printDatasetResults(trainRes, evalRes, extraEvalRes)

                # stores predictions and optionally attention maps
                if config.getPreds:
                    print(bcolored("Writing predictions...", "white"))
                    writePreds(preprocessor, evalRes, extraEvalRes)

                logRecord(epoch, epochTime, config.lr, trainRes, evalRes,
                          extraEvalRes)

                # update best result
                # compute curr and prior
                currRes = {
                    "train": trainRes,
                    "val": evalRes["val"],
                    'test': evalRes['test']
                }
                curr = {"res": currRes, "epoch": epoch}

                if bestRes is None or better(currRes, bestRes):
                    bestRes = currRes
                    bestEpoch = epoch

                prior = {
                    "best": {
                        "res": bestRes,
                        "epoch": bestEpoch
                    },
                    "prev": {
                        "res": prevRes,
                        "epoch": epoch - 1
                    }
                }

                # lr reducing
                if config.lrReduce:
                    if not improveEnough(curr, prior, config.lr):
                        config.lr *= config.lrDecayRate
                        print(
                            colored("Reducing LR to {}".format(config.lr),
                                    "red"))

                # early stopping
                if config.earlyStopping > 0:
                    if epoch - bestEpoch > config.earlyStopping:
                        break

                # update previous result
                prevRes = currRes

            # reduce epoch back to the last one we trained on
            epoch -= 1
            print("Training took {:.2f} seconds ({:} epochs)".format(
                time.time() - start0, epoch - config.restoreEpoch))

        if config.finalTest:
            print("Testing on epoch {}...".format(epoch))

            start = time.time()
            if epoch > 0:
                if config.useEMA:
                    emaSaver.restore(sess, config.weightsFile(epoch))
                else:
                    saver.restore(sess, config.weightsFile(epoch))

            evalRes = runEvaluation(sess,
                                    model,
                                    data["main"],
                                    epoch,
                                    evalTest=True,
                                    raw=config.raw_image)
            extraEvalRes = runEvaluation(sess,
                                         model,
                                         data["extra"],
                                         epoch,
                                         evalTrain=not config.extraVal,
                                         evalTest=True,
                                         raw=config.raw_image)

            print("took {:.2f} seconds".format(time.time() - start))
            printDatasetResults(None, evalRes, extraEvalRes)

            print("Writing predictions...")
            writePreds(preprocessor, evalRes, extraEvalRes)

        print(bcolored("Done!", "white"))
示例#2
0
def main():
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    tf.set_random_seed(config.tfseed)

    # set gpus
    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus

    tf.logging.set_verbosity(tf.logging.ERROR)

    # process data
    print(bold("Preprocess data..."))
    start = time.time()
    preprocessor = Preprocesser()
    data, embeddings, answerDict, questionDict = preprocessor.preprocessData()
    print("took {} seconds".format(
        bcolored("{:.2f}".format(time.time() - start), "blue")))

    nextElement = None
    dataOps = None

    # build model
    print(bold("Building model..."))
    start = time.time()
    model = MACnet(embeddings, answerDict, questionDict, nextElement)
    print("took {} seconds".format(
        bcolored("{:.2f}".format(time.time() - start), "blue")))

    # initializer
    init = tf.global_variables_initializer()

    # savers
    savers = setSavers(model)
    saver, emaSaver = savers["saver"], savers["emaSaver"]

    # sessionConfig
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:

        # ensure no more ops are added after model is built
        sess.graph.finalize()

        # restore / initialize weights, initialize epoch variable
        epoch = loadWeights(sess, saver, init)

        trainRes, evalRes = None, None

        if config.train:
            start0 = time.time()

            bestEpoch = epoch
            bestRes = None
            prevRes = None

            # epoch in [restored + 1, epochs]
            for epoch in range(config.restoreEpoch + 1, config.epochs + 1):
                print(bcolored("Training epoch {}...".format(epoch), "green"))
                start = time.time()

                # train
                # calle = lambda: model.runEpoch(), collectRuntimeStats, writer
                trainingData, alterData = chooseTrainingData(data)
                trainRes = runEpoch(
                    sess,
                    model,
                    trainingData,
                    dataOps,
                    train=True,
                    epoch=epoch,
                    saver=saver,
                    alterData=alterData,
                    maxAcc=trainRes["maxAcc"] if trainRes else 0.0,
                    minLoss=trainRes["minLoss"] if trainRes else float("inf"),
                )

                # save weights
                saver.save(sess, config.weightsFile(epoch))
                if config.saveSubset:
                    subsetSaver.save(sess, config.subsetWeightsFile(epoch))

                # load EMA weights
                if config.useEMA:
                    print(bold("Restoring EMA weights"))
                    emaSaver.restore(sess, config.weightsFile(epoch))

                # evaluation
                getPreds = config.getPreds or (config.analysisType != "")

                evalRes = runEvaluation(sess,
                                        model,
                                        data["main"],
                                        dataOps,
                                        epoch,
                                        getPreds=getPreds,
                                        prevRes=evalRes)
                extraEvalRes = runEvaluation(sess,
                                             model,
                                             data["extra"],
                                             dataOps,
                                             epoch,
                                             evalTrain=not config.extraVal,
                                             getPreds=getPreds)

                # restore standard weights
                if config.useEMA:
                    print(bold("Restoring standard weights"))
                    saver.restore(sess, config.weightsFile(epoch))

                print("")

                epochTime = time.time() - start
                print("took {:.2f} seconds".format(epochTime))

                # print results
                printDatasetResults(trainRes, evalRes, extraEvalRes)

                # stores predictions and optionally attention maps
                if config.getPreds:
                    print(bcolored("Writing predictions...", "white"))
                    writePreds(preprocessor, evalRes, extraEvalRes)

                logRecord(epoch, epochTime, config.lr, trainRes, evalRes,
                          extraEvalRes)

                # update best result
                # compute curr and prior
                currRes = {"train": trainRes, "val": evalRes["val"]}
                curr = {"res": currRes, "epoch": epoch}

                if bestRes is None or better(currRes, bestRes):
                    bestRes = currRes
                    bestEpoch = epoch

                prior = {
                    "best": {
                        "res": bestRes,
                        "epoch": bestEpoch
                    },
                    "prev": {
                        "res": prevRes,
                        "epoch": epoch - 1
                    }
                }

                # lr reducing
                if config.lrReduce:
                    if not improveEnough(curr, prior, config.lr):
                        config.lr *= config.lrDecayRate
                        print(
                            colored("Reducing LR to {}".format(config.lr),
                                    "red"))

                # early stopping
                if config.earlyStopping > 0:
                    if epoch - bestEpoch > config.earlyStopping:
                        break

                # update previous result
                prevRes = currRes

            # reduce epoch back to the last one we trained on
            epoch -= 1
            print("Training took {:.2f} seconds ({:} epochs)".format(
                time.time() - start0, epoch - config.restoreEpoch))

        if config.finalTest:
            print("Testing on epoch {}...".format(epoch))

            start = time.time()
            if epoch > 0:
                if config.useEMA:
                    emaSaver.restore(sess, config.weightsFile(epoch))
                else:
                    saver.restore(sess, config.weightsFile(epoch))

            evalRes = runEvaluation(sess,
                                    model,
                                    data["main"],
                                    dataOps,
                                    epoch,
                                    evalTest=False,
                                    getPreds=True)
            extraEvalRes = runEvaluation(sess,
                                         model,
                                         data["extra"],
                                         dataOps,
                                         epoch,
                                         evalTrain=not config.extraVal,
                                         evalTest=False,
                                         getPreds=True)

            print("took {:.2f} seconds".format(time.time() - start))
            printDatasetResults(trainRes, evalRes, extraEvalRes)

            print("Writing predictions...")
            writePreds(preprocessor, evalRes, extraEvalRes)

        if config.interactive:
            if epoch > 0:
                if config.useEMA:
                    emaSaver.restore(sess, config.weightsFile(epoch))
                else:
                    saver.restore(sess, config.weightsFile(epoch))

            tier = config.interactiveTier
            images = data["main"][tier]["images"]

            imgsInfoFilename = config.imgsInfoFile(tier)
            with open(imgsInfoFilename, "r") as file:
                imageIndex = json.load(file)

            openImageFiles(images)

            resInter = {"preds": []}

            while True:

                text = inp("Enter <imageId>_<question>\n")
                if len(text) == 0:
                    break

                imageId, questionStr = text.split("_")

                imageInfo = imageIndex[imageId]

                imageId = {
                    "group": tier,
                    "id": imageId,
                    "idx": imageInfo["idx"]
                }  # int(imageId)
                question = preprocessor.encodeQuestionStr(questionStr)
                instance = {
                    "questionStr": questionStr,
                    "question": question,
                    "answer": "yes",  # Dummy answer
                    "answerFreq": ["yes"],  # Dummy answer
                    "imageId": imageId,
                    "tier": tier,
                    "index": 0
                }

                if config.imageObjects:
                    instance["objectsNum"] = imageInfo["objectsNum"]

                print(instance)

                datum = preprocessor.vectorizeData([instance])
                image = loadImageBatch(images, {"imageIds": [imageId]})
                res = model.runBatch(sess,
                                     datum,
                                     image,
                                     train=False,
                                     getPreds=True,
                                     getAtt=True)
                resInter["preds"].append(instance)

                print(instance["prediction"])

            if config.getPreds:
                print(bcolored("Writing predictions...", "white"))
                preprocessor.writePreds(resInter, "interactive".format())

            closeImageFiles(images)

        print(bcolored("Done!", "white"))