def prepare_data(args):
    """Build a restored v3 model and load the dataset it will be run on.

    The network module is ``clairvoyante_v3_slim`` when ``args.slim == True``
    and ``clairvoyante_v3`` otherwise.  The model is initialised and its
    parameters are restored from ``args.chkpnt_fn``.  The dataset is read from
    the pickle file ``args.bin_fn`` when that is set, otherwise it is built by
    ``utils.GetTrainingArray`` from ``args.tensor_fn``, ``args.var_fn`` and
    ``args.bed_fn``.

    Returns the 6-tuple ``(model, utils module, total, X, Y, positions)``.
    """
    import utils_v2 as utils  # v3 network is using v2 utils
    if args.slim == True:
        import clairvoyante_v3_slim as cv
    else:
        import clairvoyante_v3 as cv

    utils.SetupEnv()
    model = cv.Clairvoyante()
    model.init()
    model.restoreParameters(args.chkpnt_fn)

    if args.bin_fn is None:
        count, xPacked, yPacked, posPacked = utils.GetTrainingArray(
            args.tensor_fn, args.var_fn, args.bed_fn)
    else:
        # The binary file holds four consecutive pickles, in this order.
        with open(args.bin_fn, "rb") as binFile:
            count, xPacked, yPacked, posPacked = [
                pickle.load(binFile) for _ in range(4)
            ]

    return model, utils, count, xPacked, yPacked, posPacked
def Test22(args, m):
    """Run model *m* over the hard-coded chr22 dataset and report to stderr.

    The dataset is loaded through the module-level ``utils`` from the fixed
    ``../training/*chr22`` paths.  Predictions are made in batches of
    ``param.predictBatchSize`` and the four outputs of ``m.predict`` are
    compared with label columns 0:4 (base change, top-1/top-2 accuracy),
    4:6 (zygosity), 6:10 (variant type) and 10:16 (indel length); the last
    three are printed as tab-separated confusion matrices whose rows are the
    annotated class and whose columns are the predicted class.

    *args* is accepted for signature compatibility but is not read.
    """
    def report(text):
        # Emits exactly what ``print >> sys.stderr, text`` did on Python 2,
        # while remaining valid on Python 3.
        sys.stderr.write(text + "\n")

    logging.info("Loading the chr22 dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray("../training/tensor_can_chr22",
                               "../training/var_chr22",
                               "../training/bed")

    logging.info("Testing on the chr22 dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize

    bases = []
    zs = []
    ts = []
    ls = []

    def predictBatch(ptr):
        """Predict one batch starting at *ptr*; return the decompressor's end flag."""
        XBatch, _, endFlag = utils.DecompressArray(
            XArrayCompressed, ptr, predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        return endFlag

    # The first batch is always predicted; its end flag is deliberately
    # ignored, as in the original control flow.
    predictBatch(0)
    datasetPtr = predictBatchSize
    while datasetPtr < total:
        endFlag = predictBatch(datasetPtr)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break

    bases = np.concatenate(bases)
    zs = np.concatenate(zs)
    ts = np.concatenate(ts)
    ls = np.concatenate(ls)
    report("Prediciton time elapsed: %.2f s" % (time.time() - predictStart))

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)

    report("Version 2 model, evaluation on base change:")
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:, 0:4]):
        allBaseCount += 1
        ranked = predictV.argsort()[::-1]
        truth = np.argmax(annotateV)
        if truth == ranked[0]:
            top1Count += 1
            top2Count += 1
        elif truth == ranked[1]:
            top2Count += 1
    # Avoid ZeroDivisionError when nothing was evaluated.
    denominator = allBaseCount if allBaseCount else 1
    report("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" % (
        allBaseCount, top1Count, top2Count,
        float(top1Count) / denominator * 100,
        float(top2Count) / denominator * 100))

    # (title, predictions, first label column, one-past-last label column)
    confusionTasks = (
        ("Version 2 model, evaluation on Zygosity:", zs, 4, 6),
        ("Version 2 model, evaluation on variant type:", ts, 6, 10),
        ("Version 2 model, evaluation on indel length:", ls, 10, 16),
    )
    for title, predictions, lo, hi in confusionTasks:
        report(title)
        size = hi - lo
        # Builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
        ed = np.zeros((size, size), dtype=int)
        for predictV, annotateV in zip(predictions, YArray[:, lo:hi]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for row in ed:
            report("\t".join([str(count) for count in row]))
def Convert(args, utils):
    """Serialise the training arrays to a binary file.

    Calls ``utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)``
    and writes its four results to ``args.bin_fn`` as four consecutive
    pickles, in the order: total, X array, Y array, position array.

    The file is opened in a ``with`` block so it is flushed and closed before
    the function returns, instead of relying on garbage collection.
    """
    logging.info("Loading the dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)

    logging.info("Writing to binary ...")
    with open(args.bin_fn, "wb") as fh:
        for item in (total, XArrayCompressed, YArrayCompressed,
                     posArrayCompressed):
            pickle.dump(item, fh)
def Prepare(args):
    """Build a restored v3 model and load its dataset without a BED filter.

    The network module is ``clairvoyante_v3_slim`` when ``args.slim == True``
    and ``clairvoyante_v3`` otherwise.  Parameters are restored from
    ``args.chkpnt_fn`` and the dataset comes from
    ``utils.GetTrainingArray(args.tensor_fn, args.var_fn, None)``.

    Returns the 6-tuple ``(model, utils module, total, X, Y, positions)``.
    """
    import utils_v2 as utils  # v3 network is using v2 utils
    if args.slim == True:
        import clairvoyante_v3_slim as cv
    else:
        import clairvoyante_v3 as cv

    utils.SetupEnv()

    model = cv.Clairvoyante()
    model.init()
    model.restoreParameters(args.chkpnt_fn)

    # The third argument (BED file) is always None in this variant.
    count, xPacked, yPacked, posPacked = utils.GetTrainingArray(
        args.tensor_fn, args.var_fn, None)
    return model, utils, count, xPacked, yPacked, posPacked
def Run(args):
    """Evaluate every checkpoint listed in ``args.chkpnt_list`` on one dataset.

    The network/utility modules are chosen from ``args.v2`` / ``args.v3`` and
    ``args.slim``.  The dataset is loaded once, from the pickle file
    ``args.bin_fn`` when set, otherwise through ``utils.GetTrainingArray``.
    Each non-empty line of ``args.chkpnt_list`` is treated as a checkpoint
    path: its parameters are restored into the model and ``Test`` is called
    with the preloaded dataset.

    Raises:
        ValueError: if neither ``args.v2`` nor ``args.v3`` is set.  This used
            to surface later as an UnboundLocalError on ``utils``.
    """
    # create a Clairvoyante
    if args.v2 == True:
        import utils_v2 as utils
        if args.slim == True:
            import clairvoyante_v2_slim as cv
        else:
            import clairvoyante_v2 as cv
    elif args.v3 == True:
        import utils_v2 as utils  # v3 network is using v2 utils
        if args.slim == True:
            import clairvoyante_v3_slim as cv
        else:
            import clairvoyante_v3 as cv
    else:
        raise ValueError("No network version selected: set args.v2 or args.v3")

    utils.SetupEnv()
    m = cv.Clairvoyante()
    m.init()

    if args.bin_fn is not None:
        # NOTE: unpickling can execute arbitrary code; bin_fn must be trusted.
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)

    with open(args.chkpnt_list) as fh:
        for row in fh:
            row = row.rstrip()
            if not row:
                # A blank line would otherwise be resolved by abspath("") to
                # the current directory and "restored" as a checkpoint.
                continue
            logging.info("Working on model: %s" % (row))
            m.restoreParameters(os.path.abspath(row))
            Test(args, m, utils, total, XArrayCompressed, YArrayCompressed,
                 posArrayCompressed)
def CalcAll(args, m, utils):
    """Print training and validation loss for each checkpoint in ``args.chkpnt_fn``.

    The dataset is loaded once (from the pickle file ``args.bin_fn`` when set,
    otherwise via ``utils.GetTrainingArray``).  For every checkpoint path the
    parameters are restored into *m* and the whole dataset is walked in
    batches: ``m.getLossNoRT`` runs on the current batch in a worker thread
    while the next batch is decompressed, and the result is read back from
    ``m.getLossLossRTVal``.  A batch's loss is counted as validation loss when
    the read pointer is at or beyond ``validationStart`` at the time it is
    collected, otherwise as training loss.  One line
    ``<checkpoint>\\t<training loss / trainingTotal>\\t<validation loss /
    numValItems>`` is written to stderr per checkpoint.

    Exits via ``sys.exit`` if the X and Y decompressors disagree on the batch
    length or end flag.
    """
    if args.bin_fn is not None:
        # NOTE: unpickling can execute arbitrary code; bin_fn must be trusted.
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)

    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart

    def loadBatch(ptr, size):
        """Decompress matching X/Y batches and verify they agree."""
        XBatch, XNum, XEndFlag = utils.DecompressArray(
            XArrayCompressed, ptr, size, total)
        YBatch, YNum, YEndFlag = utils.DecompressArray(
            YArrayCompressed, ptr, size, total)
        # The check uses the batch just decompressed; previously the loop
        # re-tested the stale values from the very first batch.
        if XNum != YNum or XEndFlag != YEndFlag:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (XNum, YNum))
        return XBatch, YBatch, XNum, XEndFlag

    for n in args.chkpnt_fn:
        m.restoreParameters(os.path.abspath(n))
        trainingLost = 0
        validationLost = 0

        XBatch, YBatch, XNum, _ = loadBatch(0, param.predictBatchSize)
        datasetPtr = XNum

        while True:
            # Compute the loss of the current batch in the background while
            # the next batch is being decompressed.
            worker = Thread(target=m.getLossNoRT, args=(XBatch, YBatch))
            worker.start()

            predictBatchSize = param.predictBatchSize
            if datasetPtr < validationStart:
                # Do not let a batch straddle the train/validation boundary.
                if (validationStart - datasetPtr) < predictBatchSize:
                    predictBatchSize = validationStart - datasetPtr
            elif (datasetPtr % predictBatchSize) != 0:
                # Shrink the batch so the pointer lands on a multiple of the
                # batch size again.
                predictBatchSize = predictBatchSize - (datasetPtr % predictBatchSize)

            XBatch2, YBatch2, XNum2, XEndFlag2 = loadBatch(
                datasetPtr, predictBatchSize)

            worker.join()
            XBatch = XBatch2
            YBatch = YBatch2

            if datasetPtr >= validationStart:
                validationLost += m.getLossLossRTVal
            else:
                trainingLost += m.getLossLossRTVal

            if XEndFlag2 != 0:
                # Last batch: no further prefetch, so evaluate it inline.
                m.getLossNoRT(XBatch, YBatch)
                if datasetPtr >= validationStart:
                    validationLost += m.getLossLossRTVal
                else:
                    trainingLost += m.getLossLossRTVal
                # Same bytes as the Python 2 ``print >> sys.stderr`` statement.
                sys.stderr.write("%s\t%.10f\t%.10f\n" % (
                    n, trainingLost / trainingTotal,
                    validationLost / numValItems))
                break

            datasetPtr += XNum2
def Test(args, m, utils, total=None, XArrayCompressed=None,
         YArrayCompressed=None, posArrayCompressed=None):
    """Predict over a dataset with model *m* and log accuracy statistics.

    Args:
        args: options object; reads ``v2``, ``v3`` and, when the dataset has
            to be loaded here, ``bin_fn`` or ``tensor_fn``/``var_fn``/``bed_fn``.
        m: model exposing ``predict(XBatch)`` that returns four outputs.
        utils: module providing ``GetTrainingArray`` and ``DecompressArray``.
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed:
            optional preloaded dataset.  When ``total``, ``XArrayCompressed``
            and ``YArrayCompressed`` are all given, nothing is loaded, so a
            caller evaluating many checkpoints can load the data once.
            ``posArrayCompressed`` is accepted but not read.

    When ``args.v2`` or ``args.v3`` is set, predictions are made in batches of
    ``param.predictBatchSize`` and compared with label columns 0:4 (base
    change, top-1/top-2 accuracy), 4:6 (zygosity), 6:10 (variant type) and
    10:16 (indel length).  The last three are logged as tab-separated
    confusion matrices (rows: annotated class, columns: predicted class).
    """
    if total is None or XArrayCompressed is None or YArrayCompressed is None:
        logging.info("Loading the dataset ...")
        if args.bin_fn is not None:
            # NOTE: unpickling can execute arbitrary code; bin_fn must be trusted.
            with open(args.bin_fn, "rb") as fh:
                total = pickle.load(fh)
                XArrayCompressed = pickle.load(fh)
                YArrayCompressed = pickle.load(fh)
                posArrayCompressed = pickle.load(fh)
        else:
            total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
                utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)

    logging.info("Dataset size: %d" % total)
    logging.info("Testing on the dataset ...")
    predictBatchSize = param.predictBatchSize
    predictStart = time.time()
    isV2OrV3 = args.v2 == True or args.v3 == True

    if isV2OrV3:
        bases = []
        zs = []
        ts = []
        ls = []

        def predictBatch(ptr):
            """Predict one batch starting at *ptr*; return the end flag."""
            XBatch, _, endFlag = utils.DecompressArray(
                XArrayCompressed, ptr, predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base)
            zs.append(z)
            ts.append(t)
            ls.append(l)
            return endFlag

        # The first batch is always predicted and its end flag ignored,
        # matching the original control flow.
        predictBatch(0)
        datasetPtr = predictBatchSize
        while datasetPtr < total:
            endFlag = predictBatch(datasetPtr)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break

        bases = np.concatenate(bases)
        zs = np.concatenate(zs)
        ts = np.concatenate(ts)
        ls = np.concatenate(ls)

    logging.info("Prediciton time elapsed: %.2f s" % (time.time() - predictStart))

    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)

    if isV2OrV3:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            ranked = predictV.argsort()[::-1]
            truth = np.argmax(annotateV)
            if truth == ranked[0]:
                top1Count += 1
                top2Count += 1
            elif truth == ranked[1]:
                top2Count += 1
        # Avoid ZeroDivisionError when nothing was evaluated.
        denominator = allBaseCount if allBaseCount else 1
        logging.info("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" % (
            allBaseCount, top1Count, top2Count,
            float(top1Count) / denominator * 100,
            float(top2Count) / denominator * 100))

        # (title, predictions, first label column, one-past-last label column)
        confusionTasks = (
            ("Version 2 model, evaluation on Zygosity:", zs, 4, 6),
            ("Version 2 model, evaluation on variant type:", ts, 6, 10),
            ("Version 2 model, evaluation on indel length:", ls, 10, 16),
        )
        for title, predictions, lo, hi in confusionTasks:
            logging.info(title)
            size = hi - lo
            # Builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
            ed = np.zeros((size, size), dtype=int)
            for predictV, annotateV in zip(predictions, YArray[:, lo:hi]):
                ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
            for row in ed:
                logging.info("\t".join([str(count) for count in row]))
def TrainAll(args, m):
    """Train *m* on the hard-coded chr21 dataset, then evaluate it on all of it.

    The dataset is loaded through the module-level ``utils``.  The first
    ``param.trainingDatasetPercentage`` of the items are used for training in
    batches of ``param.trainBatchSize``; the items from ``validationStart``
    onward are decompressed once and used for validation at the end of every
    pass over the training part.  ``i`` counts training batches.

    Learning-rate schedule: after at least six validations since the last
    change, if three or more of the last five validation-loss steps did not
    decrease, ``m.setLearningRate()`` is called; training stops once this has
    happened ``param.maxLearningRateSwitch`` times.

    After training, the model is run over the whole dataset and top-1/top-2
    base accuracy plus three confusion matrices are written to stderr.

    Exits via ``sys.exit`` if the X and Y decompressors disagree.
    """
    def report(text):
        # Emits exactly what ``print >> sys.stderr, text`` did on Python 2,
        # while remaining valid on Python 3.
        sys.stderr.write(text + "\n")

    logging.info("Loading the training dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray("../training/tensor_can_chr21",
                               "../training/var_chr21",
                               "../training/bed")
    logging.info("The size of training dataset: {}".format(total))

    # op to write logs to Tensorboard
    summaryWriter = None
    if args.olog is not None:
        summaryWriter = m.summaryFileWriter(args.olog)

    # training and save the parameters, we train on the first 90% variant
    # sites and validate on the last 10% variant sites
    logging.info("Start training ...")
    trainingStart = time.time()
    trainBatchSize = param.trainBatchSize
    validationLosts = []
    logging.info("Start at learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    c = 0  # validations since the last learning-rate change
    maxLearningRateSwitch = param.maxLearningRateSwitch
    epochStart = time.time()
    datasetPtr = 0
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    valXArray, _, _ = utils.DecompressArray(XArrayCompressed, validationStart, numValItems, total)
    valYArray, _, _ = utils.DecompressArray(YArrayCompressed, validationStart, numValItems, total)
    logging.info("Number of variants for validation: %d" % len(valXArray))

    batchLimit = 1 + int(param.maxEpoch * trainingTotal / trainBatchSize + 0.499)
    i = 1
    while i < batchLimit:
        XBatch, num, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, trainBatchSize, trainingTotal)
        YBatch, num2, endFlag2 = utils.DecompressArray(YArrayCompressed, datasetPtr, trainBatchSize, trainingTotal)
        if num != num2 or endFlag != endFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" % (num, num2))
        loss, summary = m.train(XBatch, YBatch)
        if summaryWriter is not None:
            summaryWriter.add_summary(summary, i)

        if endFlag != 0:
            # End of one pass over the training part: validate.
            validationLost = m.getLoss(valXArray, valYArray)
            logging.info(" ".join([
                str(i), "Training loss:", str(loss / trainBatchSize),
                "Validation loss: ", str(validationLost / numValItems)
            ]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))
            validationLosts.append((validationLost, i))
            c += 1

            # Count how many of the last five steps failed to improve.
            flipFlop = 0
            if c >= 6:
                recent = [entry[0] for entry in validationLosts[-6:]]
                for earlier, later in zip(recent, recent[1:]):
                    if earlier - later <= 0:
                        flipFlop += 1
            if flipFlop >= 3:
                maxLearningRateSwitch -= 1
                if maxLearningRateSwitch == 0:
                    break
                logging.info("New learning rate: %.2e" % m.setLearningRate())
                c = 0

            epochStart = time.time()
            # NOTE(review): in the flattened original the pointer reset was
            # followed directly by the per-batch advance, which would start
            # every later epoch at the second batch.  The nesting could not
            # be recovered with certainty; restarting at 0 is assumed to be
            # the intent - confirm against the upstream layout.
            datasetPtr = 0
            i += 1
            continue

        i += 1
        datasetPtr += trainBatchSize

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))

    # show the parameter set with the smallest validation loss
    if validationLosts:  # empty when no pass over the data completed
        validationLosts.sort()
        i = validationLosts[0][1]
        logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize

    bases = []
    zs = []
    ts = []
    ls = []

    def predictBatch(ptr):
        """Predict one batch starting at *ptr*; return the end flag."""
        XBatch, _, endFlag = utils.DecompressArray(
            XArrayCompressed, ptr, predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        return endFlag

    # The first batch is always predicted and its end flag ignored.
    predictBatch(0)
    datasetPtr = predictBatchSize
    while datasetPtr < total:
        endFlag = predictBatch(datasetPtr)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break

    bases = np.concatenate(bases)
    zs = np.concatenate(zs)
    ts = np.concatenate(ts)
    ls = np.concatenate(ls)
    report("Prediciton time elapsed: %.2f s" % (time.time() - predictStart))

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)

    report("Version 2 model, evaluation on base change:")
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:, 0:4]):
        allBaseCount += 1
        ranked = predictV.argsort()[::-1]
        truth = np.argmax(annotateV)
        if truth == ranked[0]:
            top1Count += 1
            top2Count += 1
        elif truth == ranked[1]:
            top2Count += 1
    # Avoid ZeroDivisionError when nothing was evaluated.
    denominator = allBaseCount if allBaseCount else 1
    report("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" % (
        allBaseCount, top1Count, top2Count,
        float(top1Count) / denominator * 100,
        float(top2Count) / denominator * 100))

    # (title, predictions, first label column, one-past-last label column)
    confusionTasks = (
        ("Version 2 model, evaluation on Zygosity:", zs, 4, 6),
        ("Version 2 model, evaluation on variant type:", ts, 6, 10),
        ("Version 2 model, evaluation on indel length:", ls, 10, 16),
    )
    for title, predictions, lo, hi in confusionTasks:
        report(title)
        size = hi - lo
        # Builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
        ed = np.zeros((size, size), dtype=int)
        for predictV, annotateV in zip(predictions, YArray[:, lo:hi]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for row in ed:
            report("\t".join([str(count) for count in row]))
def TrainAll(args, m, utils):
    """Train *m* with per-epoch validation, checkpointing and rate decay.

    The dataset comes from the pickle file ``args.bin_fn`` when set, otherwise
    from ``utils.GetTrainingArray``.  Each epoch walks the whole dataset:
    while the read pointer is below ``validationStart`` the current batch is
    passed to ``m.trainNoRT``, afterwards to ``m.getLossNoRT``; either call
    runs in a worker thread while the next batch is decompressed.  ``i`` is
    the epoch counter and starts after the number encoded in the last
    ``param.parameterOutputPlaceHolder`` characters of ``args.chkpnt_fn``
    when that is given.

    At the end of each epoch the losses are logged, the parameters are saved
    to ``<args.ochk_prefix>-<zero-padded i>`` (if a prefix is set) and the
    validation history is inspected: once six epochs have passed since the
    last change, the learning rate and L2 lambda are updated when the last
    five validation-loss differences strictly alternate in sign or the first
    of them is zero.  Training stops after ``param.maxLearningRateSwitch``
    such updates or when ``i`` reaches ``param.maxEpoch``.

    Finally, when ``args.v2`` or ``args.v3`` is set, the model is run over the
    whole dataset and accuracy statistics are logged.

    Exits via ``sys.exit`` if the X and Y decompressors disagree.
    """
    logging.info("Loading the training dataset ...")
    if args.bin_fn is not None:
        # NOTE: unpickling can execute arbitrary code; bin_fn must be trusted.
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)
    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    summaryWriter = None
    if args.olog_dir is not None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    # Train and save the parameters, we train on the first 90% variant sites
    # and validate on the last 10% variant sites
    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" % m.setL2RegularizationLambda(args.lambd))

    validationLosses = []

    # Model Constants
    trainingStart = time.time()
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    maxLearningRateSwitch = param.maxLearningRateSwitch

    def loadBatch(ptr, size):
        """Decompress matching X/Y batches and verify they agree."""
        XBatch, XNum, XEndFlag = utils.DecompressArray(
            XArrayCompressed, ptr, size, total)
        YBatch, YNum, YEndFlag = utils.DecompressArray(
            YArrayCompressed, ptr, size, total)
        if XNum != YNum or XEndFlag != YEndFlag:
            # Report the batch that actually failed (the message used to
            # print the counts of an earlier batch).
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (XNum, YNum))
        return XBatch, YBatch, XNum, XEndFlag

    # Variables reset per epoch
    epochStart = time.time()
    trainLossSum = 0
    validationLossSum = 0

    # Variables reset per learning rate decay
    c = 0

    if args.chkpnt_fn is None:
        i = 1
    else:
        i = int(args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1

    XBatch, YBatch, XNum, _ = loadBatch(0, param.trainBatchSize)
    datasetPtr = XNum

    while i < param.maxEpoch:
        # The pointer marks the end of the batch held in XBatch/YBatch; it
        # decides whether that batch is trained on or only evaluated.
        isTraining = datasetPtr < validationStart
        worker = Thread(
            target=m.trainNoRT if isTraining else m.getLossNoRT,
            args=(XBatch, YBatch))
        worker.start()

        if isTraining:
            # Never let a training batch cross into the validation part.
            batchSize = min(validationStart - datasetPtr, param.trainBatchSize)
        else:
            # Re-align to a multiple of predictBatchSize (a full batch when
            # already aligned).
            batchSize = param.predictBatchSize - (datasetPtr % param.predictBatchSize)

        XBatch2, YBatch2, XNum2, XEndFlag2 = loadBatch(datasetPtr, batchSize)

        worker.join()
        XBatch = XBatch2
        YBatch = YBatch2

        if isTraining:
            trainLossSum += m.trainLossRTVal
            summary = m.trainSummaryRTVal
            if summaryWriter is not None:
                summaryWriter.add_summary(summary, i)
        else:
            validationLossSum += m.getLossLossRTVal

        datasetPtr += XNum2

        if XEndFlag2 != 0:
            # The final batch has no successor to overlap with; evaluate it
            # synchronously.
            validationLossSum += m.getLoss(XBatch, YBatch)
            logging.info(" ".join([
                str(i), "Training loss:", str(trainLossSum / trainingTotal),
                "Validation loss: ", str(validationLossSum / numValItems)
            ]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))
            validationLosses.append((validationLossSum, i))

            # Output the model
            if args.ochk_prefix is not None:
                parameterOutputPath = "%s-%%0%dd" % (
                    args.ochk_prefix, param.parameterOutputPlaceHolder)
                m.saveParameters(os.path.abspath(parameterOutputPath % i))

            # Adaptive learning rate decay
            c += 1
            decay = False
            if c >= 6:
                recent = [entry[0] for entry in validationLosses[-6:]]
                diffs = [earlier - later
                         for earlier, later in zip(recent, recent[1:])]
                if diffs[0] > 0:
                    decay = (diffs[1] < 0 and diffs[2] > 0 and
                             diffs[3] < 0 and diffs[4] > 0)
                elif diffs[0] < 0:
                    decay = (diffs[1] > 0 and diffs[2] < 0 and
                             diffs[3] > 0 and diffs[4] < 0)
                else:
                    decay = True
            if decay:
                maxLearningRateSwitch -= 1
                if maxLearningRateSwitch == 0:
                    break
                logging.info("New learning rate: %.2e" % m.setLearningRate())
                logging.info("New L2 regularization lambda: %.2e" % m.setL2RegularizationLambda())
                c = 0

            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            validationLossSum = 0
            epochStart = time.time()
            XBatch, YBatch, XNum, _ = loadBatch(0, param.trainBatchSize)
            datasetPtr = XNum

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))

    # show the parameter set with the smallest validation loss
    if validationLosses:  # empty when no epoch completed
        validationLosses.sort()
        i = validationLosses[0][1]
        logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training and validation dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    isV2OrV3 = args.v2 == True or args.v3 == True

    if isV2OrV3:
        bases = []
        zs = []
        ts = []
        ls = []

        def predictBatch(ptr):
            """Predict one batch starting at *ptr*; return the end flag."""
            XBatch, _, endFlag = utils.DecompressArray(
                XArrayCompressed, ptr, predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base)
            zs.append(z)
            ts.append(t)
            ls.append(l)
            return endFlag

        # The first batch is always predicted and its end flag ignored.
        predictBatch(0)
        datasetPtr = predictBatchSize
        while datasetPtr < total:
            endFlag = predictBatch(datasetPtr)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break

        bases = np.concatenate(bases)
        zs = np.concatenate(zs)
        ts = np.concatenate(ts)
        ls = np.concatenate(ls)

    logging.info("Prediciton time elapsed: %.2f s" % (time.time() - predictStart))

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)

    if isV2OrV3:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            ranked = predictV.argsort()[::-1]
            truth = np.argmax(annotateV)
            if truth == ranked[0]:
                top1Count += 1
                top2Count += 1
            elif truth == ranked[1]:
                top2Count += 1
        # Avoid ZeroDivisionError when nothing was evaluated.
        denominator = allBaseCount if allBaseCount else 1
        logging.info("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" % (
            allBaseCount, top1Count, top2Count,
            float(top1Count) / denominator * 100,
            float(top2Count) / denominator * 100))

        # (title, predictions, first label column, one-past-last label column)
        confusionTasks = (
            ("Version 2 model, evaluation on Zygosity:", zs, 4, 6),
            ("Version 2 model, evaluation on variant type:", ts, 6, 10),
            ("Version 2 model, evaluation on indel length:", ls, 10, 16),
        )
        for title, predictions, lo, hi in confusionTasks:
            logging.info(title)
            size = hi - lo
            # Builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
            ed = np.zeros((size, size), dtype=int)
            for predictV, annotateV in zip(predictions, YArray[:, lo:hi]):
                ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
            for row in ed:
                logging.info("\t".join([str(count) for count in row]))
def TrainAll(args, m, utils):
    """Train *m* on the entire dataset (no validation split), one checkpoint per epoch.

    The dataset comes from the pickle file ``args.bin_fn`` when set, otherwise
    from ``utils.GetTrainingArray``.  Batches of ``param.trainBatchSize`` are
    passed to ``m.trainNoRT`` in a worker thread while the next batch is
    decompressed; the loss is read back from ``m.trainLossRTVal``.  ``i`` is
    the epoch counter and starts after the number encoded in the last
    ``param.parameterOutputPlaceHolder`` characters of ``args.chkpnt_fn``
    when that is given.  Training runs until ``i`` reaches ``param.maxEpoch``.

    When the decompressor signals the end of the data, the epoch's summed
    loss divided by the dataset size is logged and, if ``args.ochk_prefix``
    is set, the parameters are saved to ``<prefix>-<zero-padded i>``.

    Exits via ``sys.exit`` if the X and Y decompressors disagree.
    """
    logging.info("Loading the training dataset ...")
    if args.bin_fn is not None:
        # NOTE: unpickling can execute arbitrary code; bin_fn must be trusted.
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)
    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    summaryWriter = None
    if args.olog_dir is not None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" % m.setL2RegularizationLambda(args.lambd))

    # Model Constants
    trainingStart = time.time()
    trainingTotal = total
    batchSize = param.trainBatchSize

    def loadBatch(ptr):
        """Decompress matching X/Y batches and verify they agree."""
        XBatch, XNum, XEndFlag = utils.DecompressArray(
            XArrayCompressed, ptr, batchSize, total)
        YBatch, YNum, YEndFlag = utils.DecompressArray(
            YArrayCompressed, ptr, batchSize, total)
        if XNum != YNum or XEndFlag != YEndFlag:
            # Report the batch that actually failed (the message used to
            # print the counts of an earlier batch).
            sys.exit("Inconsistency between decompressed arrays: %d/%d" %
                     (XNum, YNum))
        return XBatch, YBatch, XNum, XEndFlag

    # Variables reset per epoch
    epochStart = time.time()
    trainLossSum = 0

    if args.chkpnt_fn is None:
        i = 1
    else:
        i = int(args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1

    XBatch, YBatch, XNum, _ = loadBatch(0)
    datasetPtr = XNum

    while i < param.maxEpoch:
        # Train on the current batch in the background while the next one is
        # being decompressed.
        worker = Thread(target=m.trainNoRT, args=(XBatch, YBatch))
        worker.start()

        XBatch2, YBatch2, XNum2, XEndFlag2 = loadBatch(datasetPtr)

        worker.join()
        XBatch = XBatch2
        YBatch = YBatch2

        trainLossSum += m.trainLossRTVal
        summary = m.trainSummaryRTVal
        if summaryWriter is not None:
            summaryWriter.add_summary(summary, i)

        datasetPtr += XNum2

        if XEndFlag2 != 0:
            # NOTE(review): the batch just decompressed (the last of the
            # epoch) is replaced by the reset below without being passed to
            # trainNoRT - confirm whether skipping it is intended.
            logging.info(" ".join(
                [str(i), "Training loss:", str(trainLossSum / trainingTotal)]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))

            # Without a prefix the path would be formatted as "None-000...".
            if args.ochk_prefix is not None:
                parameterOutputPath = "%s-%%0%dd" % (
                    args.ochk_prefix, param.parameterOutputPlaceHolder)
                m.saveParameters(parameterOutputPath % i)

            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            epochStart = time.time()
            XBatch, YBatch, XNum, _ = loadBatch(0)
            datasetPtr = XNum

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))