Example 1
def prepare_data(args):
    import utils_v2 as utils  # the v3 network uses the v2 utils
    if args.slim == True:
        import clairvoyante_v3_slim as cv
    else:
        import clairvoyante_v3 as cv

    utils.SetupEnv()
    m = cv.Clairvoyante()
    m.init()

    m.restoreParameters(args.chkpnt_fn)

    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    return m, utils, total, XArrayCompressed, YArrayCompressed, posArrayCompressed
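These functions are excerpted from a larger module and rely on names defined at module level: pickle, logging, numpy, the threading Thread class, and a project-specific param module assumed to define trainBatchSize, predictBatchSize, trainingDatasetPercentage, maxLearningRateSwitch, maxEpoch and parameterOutputPlaceHolder. Some snippets also assume a module-level utils (one of the utils_v2 variants) instead of receiving it as a parameter. A minimal sketch of the assumed preamble, shared by the examples below:

import os
import sys
import time
import pickle
import logging
from threading import Thread

import numpy as np

import param  # project-specific hyperparameter module (assumption)

Note that several snippets use the Python 2 print >> sys.stderr statement, so they predate Python 3.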
Example 2
def Test22(args, m):
    logging.info("Loading the chr22 dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
    utils.GetTrainingArray("../training/tensor_can_chr22",
                           "../training/var_chr22",
                           "../training/bed")

    logging.info("Testing on the chr22 dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    datasetPtr = 0
    XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
    bases = []
    zs = []
    ts = []
    ls = []
    base, z, t, l = m.predict(XBatch)
    bases.append(base)
    zs.append(z)
    ts.append(t)
    ls.append(l)
    datasetPtr += predictBatchSize
    while datasetPtr < total:
        XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break
    bases = np.concatenate(bases[:])
    zs = np.concatenate(zs[:])
    ts = np.concatenate(ts[:])
    ls = np.concatenate(ls[:])
    print >> sys.stderr, "Prediction time elapsed: %.2f s" % (time.time() - predictStart)

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    print >> sys.stderr, "Version 2 model, evaluation on base change:"
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:,0:4]):
        allBaseCount += 1
        sortPredictV = predictV.argsort()[::-1]
        if np.argmax(annotateV) == sortPredictV[0]: top1Count += 1; top2Count += 1
        elif np.argmax(annotateV) == sortPredictV[1]: top2Count += 1
    print >> sys.stderr, "all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %\
                (allBaseCount, top1Count, top2Count, float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100)
    print >> sys.stderr, "Version 2 model, evaluation on Zygosity:"
    ed = np.zeros( (2,2), dtype=np.int )
    for predictV, annotateV in zip(zs, YArray[:,4:6]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(2):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(2)])
    print >> sys.stderr, "Version 2 model, evaluation on variant type:"
    ed = np.zeros( (4,4), dtype=np.int )
    for predictV, annotateV in zip(ts, YArray[:,6:10]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(4):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(4)])
    print >> sys.stderr, "Version 2 model, evaluation on indel length:"
    ed = np.zeros( (6,6), dtype=np.int )
    for predictV, annotateV in zip(ls, YArray[:,10:16]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(6):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(6)])
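The prediction phase above follows a fixed pattern: decompress a batch, run m.predict, collect the four output heads, and stop once the end flag fires. A stand-alone sketch of that pattern (the function name and the decompress callable are hypothetical; it assumes DecompressArray's observed contract of returning the batch, the item count, and an end flag):

def predict_in_batches(m, XArrayCompressed, total, batchSize, decompress):
    # Hypothetical helper mirroring the prediction loop in Test22 above
    bases, zs, ts, ls = [], [], [], []
    datasetPtr = 0
    while datasetPtr < total:
        XBatch, num, endFlag = decompress(XArrayCompressed, datasetPtr, batchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += num
        if endFlag != 0:
            break
    return tuple(np.concatenate(a) for a in (bases, zs, ts, ls))

Advancing by the returned count rather than the nominal batch size sidesteps over-stepping the pointer on a final partial batch.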
Example 3
def Convert(args, utils):
    logging.info("Loading the dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
    utils.GetTrainingArray(args.tensor_fn,
                           args.var_fn,
                           args.bed_fn)

    logging.info("Writing to binary ...")
    with open(args.bin_fn, "wb") as fh:
        pickle.dump(total, fh)
        pickle.dump(XArrayCompressed, fh)
        pickle.dump(YArrayCompressed, fh)
        pickle.dump(posArrayCompressed, fh)
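Convert writes the four objects in a fixed order, and every loader in these examples (prepare_data, Run, CalcAll, Test, TrainAll) reads them back in exactly that order. The matching read side, as a sketch:

def LoadBinary(bin_fn):
    # Read back in the exact order Convert wrote: total, X, Y, pos
    with open(bin_fn, "rb") as fh:
        total = pickle.load(fh)
        XArrayCompressed = pickle.load(fh)
        YArrayCompressed = pickle.load(fh)
        posArrayCompressed = pickle.load(fh)
    return total, XArrayCompressed, YArrayCompressed, posArrayCompressed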
Example 4
def Prepare(args):
    import utils_v2 as utils  # the v3 network uses the v2 utils
    if args.slim == True:
        import clairvoyante_v3_slim as cv
    else:
        import clairvoyante_v3 as cv

    utils.SetupEnv()
    m = cv.Clairvoyante()
    m.init()

    m.restoreParameters(args.chkpnt_fn)

    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
    utils.GetTrainingArray(args.tensor_fn, args.var_fn, None)

    return m, utils, total, XArrayCompressed, YArrayCompressed, posArrayCompressed
Example 5
def Run(args):
    # create a Clairvoyante
    if args.v2 == True:
        import utils_v2 as utils
        if args.slim == True:
            import clairvoyante_v2_slim as cv
        else:
            import clairvoyante_v2 as cv
    elif args.v3 == True:
        import utils_v2 as utils # the v3 network uses the v2 utils
        if args.slim == True:
            import clairvoyante_v3_slim as cv
        else:
            import clairvoyante_v3 as cv
    utils.SetupEnv()
    m = cv.Clairvoyante()
    m.init()

    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    with open(args.chkpnt_list) as fh:
        for row in fh:
            row = row.rstrip()
            logging.info("Working on model: %s" % (row))
            m.restoreParameters(os.path.abspath(row))
            Test(args, m, utils, total, XArrayCompressed, YArrayCompressed, posArrayCompressed)
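Run expects args.chkpnt_list to be a plain-text file with one checkpoint path per line (each line is stripped and resolved with os.path.abspath). Given the zero-padded naming scheme saveParameters uses in the TrainAll examples below, a hypothetical list might read:

../models/run1-0000001
../models/run1-0000002
../models/run1-0000003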
Example 6
def CalcAll(args, m, utils):
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    predictBatchSize = param.predictBatchSize
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart

    for n in args.chkpnt_fn:
        m.restoreParameters(os.path.abspath(n))
        datasetPtr = 0
        trainingLost = 0
        validationLost = 0
        i = 1
        XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed,
                                                       datasetPtr,
                                                       predictBatchSize, total)
        YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed,
                                                       datasetPtr,
                                                       predictBatchSize, total)
        datasetPtr += XNum
        while True:
            threadPool = []
            threadPool.append(
                Thread(target=m.getLossNoRT, args=(
                    XBatch,
                    YBatch,
                )))
            for t in threadPool:
                t.start()
            predictBatchSize = param.predictBatchSize
            if datasetPtr < validationStart and (
                    validationStart - datasetPtr) < predictBatchSize:
                predictBatchSize = validationStart - datasetPtr
            elif datasetPtr >= validationStart and (datasetPtr %
                                                    predictBatchSize) != 0:
                predictBatchSize = predictBatchSize - (datasetPtr %
                                                       predictBatchSize)
            #print >> sys.stderr, "%d\t%d\t%d\t%d" % (datasetPtr, predictBatchSize, validationStart, total)
            XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(
                XArrayCompressed, datasetPtr, predictBatchSize, total)
            YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(
                YArrayCompressed, datasetPtr, predictBatchSize, total)
            if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
                sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                         (XNum2, YNum2))
            for t in threadPool:
                t.join()
            XBatch = XBatch2
            YBatch = YBatch2
            if datasetPtr >= validationStart:
                validationLost += m.getLossLossRTVal
            else:
                trainingLost += m.getLossLossRTVal
            if XEndFlag2 != 0:
                m.getLossNoRT(XBatch, YBatch)
                if datasetPtr >= validationStart:
                    validationLost += m.getLossLossRTVal
                else:
                    trainingLost += m.getLossLossRTVal
                print >> sys.stderr, "%s\t%.10f\t%.10f" % (
                    n, trainingLost / trainingTotal,
                    validationLost / numValItems)
                break
            i += 1
            datasetPtr += XNum2
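CalcAll overlaps computation with I/O: the loss for the current batch runs in a worker thread (m.getLossNoRT stores its result in m.getLossLossRTVal rather than returning it, since a Thread target cannot return a value) while the main thread decompresses the next batch. A stripped-down sketch of that double-buffering pattern, with hypothetical compute and load_next stand-ins:

def overlapped_loop(compute, load_next):
    # Hypothetical skeleton of the overlap used above: process the current
    # batch in a worker thread while the next batch is prepared
    current = load_next()
    while current is not None:
        worker = Thread(target=compute, args=(current,))
        worker.start()
        upcoming = load_next()  # runs concurrently with compute(current)
        worker.join()
        current = upcoming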
Example 7
def Test(args, m, utils):
    logging.info("Loading the dataset ...")

    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    logging.info("Dataset size: %d" % total)
    logging.info("Testing on the dataset ...")
    predictBatchSize = param.predictBatchSize
    predictStart = time.time()
    if args.v2 == True or args.v3 == True:
        datasetPtr = 0
        XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr,
                                             predictBatchSize, total)
        bases = []
        zs = []
        ts = []
        ls = []
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        while datasetPtr < total:
            XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed,
                                                       datasetPtr,
                                                       predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base)
            zs.append(z)
            ts.append(t)
            ls.append(l)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break
        bases = np.concatenate(bases[:])
        zs = np.concatenate(zs[:])
        ts = np.concatenate(ts[:])
        ls = np.concatenate(ls[:])
    logging.info("Prediciton time elapsed: %.2f s" %
                 (time.time() - predictStart))

    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    if args.v2 == True or args.v3 == True:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            sortPredictV = predictV.argsort()[::-1]
            if np.argmax(annotateV) == sortPredictV[0]:
                top1Count += 1
                top2Count += 1
            elif np.argmax(annotateV) == sortPredictV[1]:
                top2Count += 1
        logging.info(
            "all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %
            (allBaseCount, top1Count, top2Count, float(top1Count) /
             allBaseCount * 100, float(top2Count) / allBaseCount * 100))
        logging.info("Version 2 model, evaluation on Zygosity:")
        ed = np.zeros((2, 2), dtype=np.int)
        for predictV, annotateV in zip(zs, YArray[:, 4:6]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(2):
            logging.info("\t".join([str(ed[i][j]) for j in range(2)]))
        logging.info("Version 2 model, evaluation on variant type:")
        ed = np.zeros((4, 4), dtype=np.int)
        for predictV, annotateV in zip(ts, YArray[:, 6:10]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(4):
            logging.info("\t".join([str(ed[i][j]) for j in range(4)]))
        logging.info("Version 2 model, evaluation on indel length:")
        ed = np.zeros((6, 6), dtype=np.int)
        for predictV, annotateV in zip(ls, YArray[:, 10:16]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(6):
            logging.info("\t".join([str(ed[i][j]) for j in range(6)]))
Example 8
def TrainAll(args, m):
    logging.info("Loading the training dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
    utils.GetTrainingArray("../training/tensor_can_chr21",
                           "../training/var_chr21",
                           "../training/bed")

    logging.info("The size of training dataset: {}".format(total))

    # op to write logs to Tensorboard
    if args.olog != None:
        summaryWriter = m.summaryFileWriter(args.olog)

    # Train and save the parameters; we train on the first 90% of variant sites and validate on the last 10%
    logging.info("Start training ...")
    trainingStart = time.time()
    trainBatchSize = param.trainBatchSize
    validationLosts = []
    logging.info("Start at learning rate: %.2e" %
                 m.setLearningRate(args.learning_rate))
    c = 0
    maxLearningRateSwitch = param.maxLearningRateSwitch
    epochStart = time.time()
    datasetPtr = 0
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    valXArray, _, _ = utils.DecompressArray(XArrayCompressed, validationStart,
                                            numValItems, total)
    valYArray, _, _ = utils.DecompressArray(YArrayCompressed, validationStart,
                                            numValItems, total)
    logging.info("Number of variants for validation: %d" % len(valXArray))
    i = 1
    while i < (1 +
               int(param.maxEpoch * trainingTotal / trainBatchSize + 0.499)):
        XBatch, num, endFlag = utils.DecompressArray(XArrayCompressed,
                                                     datasetPtr,
                                                     trainBatchSize,
                                                     trainingTotal)
        YBatch, num2, endFlag2 = utils.DecompressArray(YArrayCompressed,
                                                       datasetPtr,
                                                       trainBatchSize,
                                                       trainingTotal)
        if num != num2 or endFlag != endFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (num, num2))
        loss, summary = m.train(XBatch, YBatch)
        if args.olog != None:
            summaryWriter.add_summary(summary, i)
        if endFlag != 0:
            validationLost = m.getLoss(valXArray, valYArray)
            logging.info(" ".join([
                str(i), "Training loss:",
                str(loss / trainBatchSize), "Validation loss: ",
                str(validationLost / numValItems)
            ]))
            logging.info("Epoch time elapsed: %.2f s" %
                         (time.time() - epochStart))
            validationLosts.append((validationLost, i))
            c += 1
            flipFlop = 0
            if c >= 6:
                # Count how many of the last five validation-loss deltas
                # failed to decrease
                for k in range(-6, -1):
                    if validationLosts[k][0] - validationLosts[k + 1][0] <= 0:
                        flipFlop += 1
            if flipFlop >= 3:
                maxLearningRateSwitch -= 1
                if maxLearningRateSwitch == 0:
                    break
                logging.info("New learning rate: %.2e" % m.setLearningRate())
                c = 0
            epochStart = time.time()
            datasetPtr = 0
        i += 1
        datasetPtr += trainBatchSize

    logging.info("Training time elapsed: %.2f s" %
                 (time.time() - trainingStart))

    # show the parameter set with the smallest validation loss
    validationLosts.sort()
    i = validationLosts[0][1]
    logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    datasetPtr = 0
    XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr,
                                         predictBatchSize, total)
    bases = []
    zs = []
    ts = []
    ls = []
    base, z, t, l = m.predict(XBatch)
    bases.append(base)
    zs.append(z)
    ts.append(t)
    ls.append(l)
    datasetPtr += predictBatchSize
    while datasetPtr < total:
        XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed,
                                                   datasetPtr,
                                                   predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break
    bases = np.concatenate(bases[:])
    zs = np.concatenate(zs[:])
    ts = np.concatenate(ts[:])
    ls = np.concatenate(ls[:])
    print >> sys.stderr, "Prediciton time elapsed: %.2f s" % (time.time() -
                                                              predictStart)

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    print >> sys.stderr, "Version 2 model, evaluation on base change:"
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:, 0:4]):
        allBaseCount += 1
        sortPredictV = predictV.argsort()[::-1]
        if np.argmax(annotateV) == sortPredictV[0]:
            top1Count += 1
            top2Count += 1
        elif np.argmax(annotateV) == sortPredictV[1]:
            top2Count += 1
    print >> sys.stderr, "all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %\
                (allBaseCount, top1Count, top2Count, float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100)
    print >> sys.stderr, "Version 2 model, evaluation on Zygosity:"
    ed = np.zeros((2, 2), dtype=np.int)
    for predictV, annotateV in zip(zs, YArray[:, 4:6]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(2):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(2)])
    print >> sys.stderr, "Version 2 model, evaluation on variant type:"
    ed = np.zeros((4, 4), dtype=np.int)
    for predictV, annotateV in zip(ts, YArray[:, 6:10]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(4):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(4)])
    print >> sys.stderr, "Version 2 model, evaluation on indel length:"
    ed = np.zeros((6, 6), dtype=np.int)
    for predictV, annotateV in zip(ls, YArray[:, 10:16]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(6):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(6)])
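Each of the zygosity, variant-type and indel-length evaluations builds a confusion matrix with truth on the rows and prediction on the columns. The nested loops could be collapsed into one hypothetical helper:

def ConfusionMatrix(predictions, annotations, k):
    # Same counts as the loops above: ed[truth][predicted] += 1 per item
    ed = np.zeros((k, k), dtype=int)
    np.add.at(ed, (np.argmax(annotations, axis=1),
                   np.argmax(predictions, axis=1)), 1)
    return ed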
Example 9
def TrainAll(args, m, utils):
    logging.info("Loading the training dataset ...")
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    if args.olog_dir != None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    # Train and save the parameters; we train on the first 90% of variant sites and validate on the last 10%
    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" %
                 m.setL2RegularizationLambda(args.lambd))

    validationLosses = []

    # Model Constants
    trainingStart = time.time()
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    maxLearningRateSwitch = param.maxLearningRateSwitch

    # Variables reset per epoch
    batchSize = param.trainBatchSize
    epochStart = time.time()
    trainLossSum = 0
    validationLossSum = 0
    datasetPtr = 0

    # Variables reset per learning rate decay
    c = 0

    i = 1 if args.chkpnt_fn == None else int(
        args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1
    XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed,
                                                   datasetPtr, batchSize,
                                                   total)
    YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed,
                                                   datasetPtr, batchSize,
                                                   total)
    datasetPtr += XNum
    while i < param.maxEpoch:
        threadPool = []
        if datasetPtr < validationStart:
            threadPool.append(
                Thread(target=m.trainNoRT, args=(
                    XBatch,
                    YBatch,
                )))
        elif datasetPtr >= validationStart:
            threadPool.append(
                Thread(target=m.getLossNoRT, args=(
                    XBatch,
                    YBatch,
                )))

        for t in threadPool:
            t.start()

        if datasetPtr < validationStart and (
                validationStart - datasetPtr) < param.trainBatchSize:
            batchSize = validationStart - datasetPtr
        elif datasetPtr < validationStart:
            batchSize = param.trainBatchSize
        elif datasetPtr >= validationStart and (datasetPtr %
                                                param.predictBatchSize) != 0:
            batchSize = param.predictBatchSize - (datasetPtr %
                                                  param.predictBatchSize)
        elif datasetPtr >= validationStart:
            batchSize = param.predictBatchSize

        XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(
            XArrayCompressed, datasetPtr, batchSize, total)
        YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(
            YArrayCompressed, datasetPtr, batchSize, total)
        if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (XNum2, YNum2))

        for t in threadPool:
            t.join()

        XBatch = XBatch2
        YBatch = YBatch2
        if datasetPtr < validationStart:
            trainLossSum += m.trainLossRTVal
            summary = m.trainSummaryRTVal
            if args.olog_dir != None:
                summaryWriter.add_summary(summary, i)
        elif datasetPtr >= validationStart:
            validationLossSum += m.getLossLossRTVal
        datasetPtr += XNum2

        if XEndFlag2 != 0:
            validationLossSum += m.getLoss(XBatch, YBatch)
            logging.info(" ".join([
                str(i), "Training loss:",
                str(trainLossSum / trainingTotal), "Validation loss: ",
                str(validationLossSum / numValItems)
            ]))
            logging.info("Epoch time elapsed: %.2f s" %
                         (time.time() - epochStart))
            validationLosses.append((validationLossSum, i))
            # Output the model
            if args.ochk_prefix != None:
                parameterOutputPath = "%s-%%0%dd" % (
                    args.ochk_prefix, param.parameterOutputPlaceHolder)
                m.saveParameters(os.path.abspath(parameterOutputPath % i))
            # Adaptive learning rate decay
            c += 1
            flag = 0
            if c >= 6:
                # Trigger decay when the last five validation-loss deltas
                # strictly alternate in sign (or the oldest delta is zero)
                diffs = [validationLosses[k][0] - validationLosses[k + 1][0]
                         for k in range(-6, -1)]
                if diffs[0] == 0:
                    flag = 1
                elif all(diffs[j] * diffs[j + 1] < 0 for j in range(4)):
                    flag = 1
            if flag == 1:
                maxLearningRateSwitch -= 1
                if maxLearningRateSwitch == 0:
                    break
                logging.info("New learning rate: %.2e" % m.setLearningRate())
                logging.info("New L2 regularization lambda: %.2e" %
                             m.setL2RegularizationLambda())
                c = 0
            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            validationLossSum = 0
            datasetPtr = 0
            epochStart = time.time()
            batchSize = param.trainBatchSize
            XBatch, XNum, XEndFlag = utils.DecompressArray(
                XArrayCompressed, datasetPtr, batchSize, total)
            YBatch, YNum, YEndFlag = utils.DecompressArray(
                YArrayCompressed, datasetPtr, batchSize, total)
            datasetPtr += XNum

    logging.info("Training time elapsed: %.2f s" %
                 (time.time() - trainingStart))

    # show the parameter set with the smallest validation loss
    validationLosses.sort()
    i = validationLosses[0][1]
    logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training and validation dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    if args.v2 == True or args.v3 == True:
        datasetPtr = 0
        XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr,
                                             predictBatchSize, total)
        bases = []
        zs = []
        ts = []
        ls = []
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        while datasetPtr < total:
            XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed,
                                                       datasetPtr,
                                                       predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base)
            zs.append(z)
            ts.append(t)
            ls.append(l)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break
        bases = np.concatenate(bases[:])
        zs = np.concatenate(zs[:])
        ts = np.concatenate(ts[:])
        ls = np.concatenate(ls[:])
    logging.info("Prediciton time elapsed: %.2f s" %
                 (time.time() - predictStart))

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    if args.v2 == True or args.v3 == True:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            sortPredictV = predictV.argsort()[::-1]
            if np.argmax(annotateV) == sortPredictV[0]:
                top1Count += 1
                top2Count += 1
            elif np.argmax(annotateV) == sortPredictV[1]:
                top2Count += 1
        logging.info("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %\
                    (allBaseCount, top1Count, top2Count, float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100))
        logging.info("Version 2 model, evaluation on Zygosity:")
        ed = np.zeros((2, 2), dtype=np.int)
        for predictV, annotateV in zip(zs, YArray[:, 4:6]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(2):
            logging.info("\t".join([str(ed[i][j]) for j in range(2)]))
        logging.info("Version 2 model, evaluation on variant type:")
        ed = np.zeros((4, 4), dtype=np.int)
        for predictV, annotateV in zip(ts, YArray[:, 6:10]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(4):
            logging.info("\t".join([str(ed[i][j]) for j in range(4)]))
        logging.info("Version 2 model, evaluation on indel length:")
        ed = np.zeros((6, 6), dtype=np.int)
        for predictV, annotateV in zip(ls, YArray[:, 10:16]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(6):
            logging.info("\t".join([str(ed[i][j]) for j in range(6)]))
Example 10
def TrainAll(args, m, utils):
    logging.info("Loading the training dataset ...")
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn,
                               args.var_fn,
                               args.bed_fn)

    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    if args.olog_dir != None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    # Train and save the parameters; we train on the first 90% of variant sites and validate on the last 10%
    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" %
                 m.setL2RegularizationLambda(args.lambd))

    # Model Constants
    trainingStart = time.time()
    trainingTotal = total
    maxLearningRateSwitch = param.maxLearningRateSwitch

    # Variables reset per epoch
    batchSize = param.trainBatchSize
    epochStart = time.time()
    trainLossSum = 0
    datasetPtr = 0

    i = 1 if args.chkpnt_fn == None else int(
        args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1
    XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed,
                                                   datasetPtr, batchSize,
                                                   total)
    YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed,
                                                   datasetPtr, batchSize,
                                                   total)
    datasetPtr += XNum
    while i < param.maxEpoch:
        threadPool = []
        threadPool.append(Thread(target=m.trainNoRT, args=(
            XBatch,
            YBatch,
        )))

        for t in threadPool:
            t.start()

        XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(
            XArrayCompressed, datasetPtr, batchSize, total)
        YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(
            YArrayCompressed, datasetPtr, batchSize, total)
        if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (XNum2, YNum2))

        for t in threadPool:
            t.join()

        XBatch = XBatch2
        YBatch = YBatch2
        trainLossSum += m.trainLossRTVal
        summary = m.trainSummaryRTVal
        if args.olog_dir != None:
            summaryWriter.add_summary(summary, i)
        datasetPtr += XNum2

        if XEndFlag2 != 0:
            logging.info(" ".join(
                [str(i), "Training loss:",
                 str(trainLossSum / trainingTotal)]))
            logging.info("Epoch time elapsed: %.2f s" %
                         (time.time() - epochStart))
            parameterOutputPath = "%s-%%0%dd" % (
                args.ochk_prefix, param.parameterOutputPlaceHolder)
            m.saveParameters(parameterOutputPath % i)

            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            datasetPtr = 0
            epochStart = time.time()
            batchSize = param.trainBatchSize
            XBatch, XNum, XEndFlag = utils.DecompressArray(
                XArrayCompressed, datasetPtr, batchSize, total)
            YBatch, YNum, YEndFlag = utils.DecompressArray(
                YArrayCompressed, datasetPtr, batchSize, total)
            datasetPtr += XNum

    logging.info("Training time elapsed: %.2f s" %
                 (time.time() - trainingStart))
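Both TrainAll variants rely on the NoRT naming convention: trainNoRT and getLossNoRT run inside worker threads and stash their results on the model as trainLossRTVal, trainSummaryRTVal and getLossLossRTVal, to be read after join(). A hypothetical sketch of how such methods might wrap their returning counterparts:

class ModelSketch(object):
    # Hypothetical illustration of the *NoRT / *RTVal convention used above
    def train(self, XBatch, YBatch):
        raise NotImplementedError  # one optimizer step; returns (loss, summary)

    def trainNoRT(self, XBatch, YBatch):
        # A Thread target cannot return a value, so store the results
        # on the instance for the caller to read after join()
        self.trainLossRTVal, self.trainSummaryRTVal = self.train(XBatch, YBatch)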