def visualize_embedding(args, m, utils, total, XArrayCompressed, YArrayCompressed, olog_dir, embed_count):
    XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, 0, embed_count, total)
    YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, 0, embed_count, total)
    embeddings1, embeddings2, embeddings3, embeddings4 = get_embeddings(m, XBatch)
    labels1, labels2, labels3, labels4 = get_labels(YBatch)

    # One tf.Variable per prediction head so each embedding can be selected
    # separately in the TensorBoard projector
    embedding1_values = embeddings1
    embedding1_labels = labels1
    embedding1_values = np.asarray(embedding1_values)
    embedding1_var = tf.Variable(embedding1_values, name="BaseChange")

    embedding2_values = embeddings2
    embedding2_labels = labels2
    embedding2_values = np.asarray(embedding2_values)
    embedding2_var = tf.Variable(embedding2_values, name="Zygosity")

    embedding3_values = embeddings3
    embedding3_labels = labels3
    embedding3_values = np.asarray(embedding3_values)
    embedding3_var = tf.Variable(embedding3_values, name="VarType")

    embedding4_values = embeddings4
    embedding4_labels = labels4
    embedding4_values = np.asarray(embedding4_values)
    embedding4_var = tf.Variable(embedding4_values, name="IndelLength")

    # Write a metadata .tsv (point labels) for each embedding
    metadata1_path = os.path.join(olog_dir, 'BaseChange.tsv')
    write_metadata(args, metadata1_path, embedding1_labels)
    metadata2_path = os.path.join(olog_dir, 'Zygosity.tsv')
    write_metadata(args, metadata2_path, embedding2_labels)
    metadata3_path = os.path.join(olog_dir, 'VarType.tsv')
    write_metadata(args, metadata3_path, embedding3_labels)
    metadata4_path = os.path.join(olog_dir, 'IndelLength.tsv')
    write_metadata(args, metadata4_path, embedding4_labels)

    # Save the embedding variables into a checkpoint inside olog_dir
    checkpoint_path = os.path.join(olog_dir, 'model.ckpt')
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, checkpoint_path, 1)

    # Register each embedding with its metadata file and write the projector config
    config = projector.ProjectorConfig()
    embedding1 = config.embeddings.add()
    embedding1.tensor_name = embedding1_var.name
    embedding1.metadata_path = metadata1_path
    embedding2 = config.embeddings.add()
    embedding2.tensor_name = embedding2_var.name
    embedding2.metadata_path = metadata2_path
    embedding3 = config.embeddings.add()
    embedding3.tensor_name = embedding3_var.name
    embedding3.metadata_path = metadata3_path
    embedding4 = config.embeddings.add()
    embedding4.tensor_name = embedding4_var.name
    embedding4.metadata_path = metadata4_path

    summary_writer = tf.summary.FileWriter(olog_dir, sess.graph)
    projector.visualize_embeddings(summary_writer, config)

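# Usage note (a sketch, not part of the original pipeline): visualize_embedding()
# leaves a checkpoint, one metadata .tsv per prediction head, and a
# projector_config.pbtxt inside olog_dir. The embeddings can then be browsed with
# the standard TensorBoard projector, e.g.:
#
#   tensorboard --logdir <olog_dir>
#
# and selecting the BaseChange / Zygosity / VarType / IndelLength tensors in the
# projector tab.
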
def Test(args, m, utils, total, XArrayCompressed, YArrayCompressed, posArrayCompressed):
    predictBatchSize = param.predictBatchSize
    predictStart = time.time()
    if args.v2 == True or args.v3 == True:
        datasetPtr = 0
        XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        bases = []; zs = []; ts = []; ls = []
        base, z, t, l = m.predict(XBatch)
        bases.append(base); zs.append(z); ts.append(t); ls.append(l)
        datasetPtr += predictBatchSize
        while datasetPtr < total:
            XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base); zs.append(z); ts.append(t); ls.append(l)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break
        bases = np.concatenate(bases[:])
        zs = np.concatenate(zs[:])
        ts = np.concatenate(ts[:])
        ls = np.concatenate(ls[:])
    logging.info("Prediction time elapsed: %.2f s" % (time.time() - predictStart))

    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)

    if args.v2 == True or args.v3 == True:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            sortPredictV = predictV.argsort()[::-1]
            if np.argmax(annotateV) == sortPredictV[0]:
                top1Count += 1
                top2Count += 1
            elif np.argmax(annotateV) == sortPredictV[1]:
                top2Count += 1
        logging.info("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %
                     (allBaseCount, top1Count, top2Count,
                      float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100))

        logging.info("Version 2 model, evaluation on Zygosity:")
        ed = np.zeros((2, 2), dtype=np.int)
        for predictV, annotateV in zip(zs, YArray[:, 4:6]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(2):
            logging.info("\t".join([str(ed[i][j]) for j in range(2)]))

        logging.info("Version 2 model, evaluation on variant type:")
        ed = np.zeros((4, 4), dtype=np.int)
        for predictV, annotateV in zip(ts, YArray[:, 6:10]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(4):
            logging.info("\t".join([str(ed[i][j]) for j in range(4)]))

        logging.info("Version 2 model, evaluation on indel length:")
        ed = np.zeros((6, 6), dtype=np.int)
        for predictV, annotateV in zip(ls, YArray[:, 10:16]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(6):
            logging.info("\t".join([str(ed[i][j]) for j in range(6)]))

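# Note on the label layout assumed by Test() above (derived from the slices it
# uses, not from an external spec): each row of YArray has 16 columns,
#   [0:4)   base change   -> compared against the `base` head of m.predict()
#   [4:6)   zygosity      -> compared against the `z` head
#   [6:10)  variant type  -> compared against the `t` head
#   [10:16) indel length  -> compared against the `l` head
# Base change is scored as top-1/top-2 accuracy; the other three heads are
# reported as confusion matrices with truth on rows and predictions on columns.
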
def Test22(args, m):
    logging.info("Loading the chr22 dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray("../training/tensor_can_chr22",
                               "../training/var_chr22",
                               "../training/bed")

    logging.info("Testing on the chr22 dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    datasetPtr = 0
    XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
    bases = []; zs = []; ts = []; ls = []
    base, z, t, l = m.predict(XBatch)
    bases.append(base); zs.append(z); ts.append(t); ls.append(l)
    datasetPtr += predictBatchSize
    while datasetPtr < total:
        XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base); zs.append(z); ts.append(t); ls.append(l)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break
    bases = np.concatenate(bases[:]); zs = np.concatenate(zs[:]); ts = np.concatenate(ts[:]); ls = np.concatenate(ls[:])
    print >> sys.stderr, "Prediction time elapsed: %.2f s" % (time.time() - predictStart)

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    print >> sys.stderr, "Version 2 model, evaluation on base change:"
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:, 0:4]):
        allBaseCount += 1
        sortPredictV = predictV.argsort()[::-1]
        if np.argmax(annotateV) == sortPredictV[0]:
            top1Count += 1; top2Count += 1
        elif np.argmax(annotateV) == sortPredictV[1]:
            top2Count += 1
    print >> sys.stderr, "all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %\
        (allBaseCount, top1Count, top2Count, float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100)

    print >> sys.stderr, "Version 2 model, evaluation on Zygosity:"
    ed = np.zeros((2, 2), dtype=np.int)
    for predictV, annotateV in zip(zs, YArray[:, 4:6]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(2):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(2)])

    print >> sys.stderr, "Version 2 model, evaluation on variant type:"
    ed = np.zeros((4, 4), dtype=np.int)
    for predictV, annotateV in zip(ts, YArray[:, 6:10]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(4):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(4)])

    print >> sys.stderr, "Version 2 model, evaluation on indel length:"
    ed = np.zeros((6, 6), dtype=np.int)
    for predictV, annotateV in zip(ls, YArray[:, 10:16]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(6):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(6)])

def CreatePNGs(args, m, utils, total, XArrayCompressed, YArrayCompressed, posArrayCompressed):
    for i in range(total):
        XArray, _, _ = utils.DecompressArray(XArrayCompressed, i, 1, total)
        YArray, _, _ = utils.DecompressArray(YArrayCompressed, i, 1, total)
        posArray, _, _ = utils.DecompressArray(posArrayCompressed, i, 1, total)
        varName = posArray[0]
        varName = "-".join(varName.split(":"))
        print >> sys.stderr, "Plotting %s..." % (varName)

        # Create folder
        if not os.path.exists(varName):
            os.makedirs(varName)

        # Plot tensors
        PlotTensor(varName + "/tensor.png", XArray)

        # Plot conv1
        units = GetActivations(m.conv1, XArray, m)
        PlotFiltersConv(varName + "/conv1.png", units, 1, 8, 9, 5, 5, 8)

        # Plot conv2
        units = GetActivations(m.conv2, XArray, m)
        PlotFiltersConv(varName + "/conv2.png", units, 1, 8, 18, 6, 6, 9)

        # Plot conv3
        units = GetActivations(m.conv3, XArray, m)
        PlotFiltersConv(varName + "/conv3.png", units, 1, 8, 24, 7, 7, 10)

        # Plot fc4
        units = GetActivations(m.fc4, XArray, m)
        PlotFiltersFC(varName + "/fc4.png", units, 10, 16, 1, 9, 9, 10)

        # Plot fc5
        units = GetActivations(m.fc5, XArray, m)
        PlotFiltersFC(varName + "/fc5.png", units, 10, 4, 1, 4, 4, 5)

        # Plot Predicted and Truth Y
        unitsX = [GetActivations(m.YBaseChangeSigmoid, XArray, m),
                  GetActivations(m.YZygositySoftmax, XArray, m),
                  GetActivations(m.YVarTypeSoftmax, XArray, m),
                  GetActivations(m.YIndelLengthSoftmax, XArray, m)]
        unitsX = np.concatenate(unitsX, axis=1)
        unitsY = np.reshape(YArray[0], (1, -1))
        PlotOutputArray(varName + "/output.png", unitsX, unitsY, 1, 4, 2, 4, 4, 5)

def CalcAll(args, m, utils):
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)

    predictBatchSize = param.predictBatchSize
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart

    for n in args.chkpnt_fn:
        m.restoreParameters(os.path.abspath(n))
        datasetPtr = 0
        trainingLost = 0
        validationLost = 0
        i = 1
        XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, datasetPtr, predictBatchSize, total)
        datasetPtr += XNum
        while True:
            # Compute the loss of the current batch in a worker thread while the
            # next batch is being decompressed
            threadPool = []
            threadPool.append(Thread(target=m.getLossNoRT, args=(XBatch, YBatch, )))
            for t in threadPool:
                t.start()

            # Shrink the next batch so it does not straddle the training/validation
            # boundary and stays aligned after the boundary
            predictBatchSize = param.predictBatchSize
            if datasetPtr < validationStart and (validationStart - datasetPtr) < predictBatchSize:
                predictBatchSize = validationStart - datasetPtr
            elif datasetPtr >= validationStart and (datasetPtr % predictBatchSize) != 0:
                predictBatchSize = predictBatchSize - (datasetPtr % predictBatchSize)
            #print >> sys.stderr, "%d\t%d\t%d\t%d" % (datasetPtr, predictBatchSize, validationStart, total)
            XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
            YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(YArrayCompressed, datasetPtr, predictBatchSize, total)
            if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
                sys.exit("Inconsistency between decompressed arrays: %d/%d" % (XNum2, YNum2))

            for t in threadPool:
                t.join()

            XBatch = XBatch2
            YBatch = YBatch2
            if datasetPtr >= validationStart:
                validationLost += m.getLossLossRTVal
            else:
                trainingLost += m.getLossLossRTVal

            if XEndFlag2 != 0:
                m.getLossNoRT(XBatch, YBatch)
                if datasetPtr >= validationStart:
                    validationLost += m.getLossLossRTVal
                else:
                    trainingLost += m.getLossLossRTVal
                print >> sys.stderr, "%s\t%.10f\t%.10f" % (n, trainingLost / trainingTotal, validationLost / numValItems)
                break
            i += 1
            datasetPtr += XNum2

def TrainAll(args, m):
    logging.info("Loading the training dataset ...")
    total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
        utils.GetTrainingArray("../training/tensor_can_chr21",
                               "../training/var_chr21",
                               "../training/bed")
    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    if args.olog != None:
        summaryWriter = m.summaryFileWriter(args.olog)

    # Train and save the parameters; we train on the first 90% of variant sites and validate on the last 10%
    logging.info("Start training ...")
    trainingStart = time.time()
    trainBatchSize = param.trainBatchSize
    validationLosts = []
    logging.info("Start at learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    c = 0
    maxLearningRateSwitch = param.maxLearningRateSwitch
    epochStart = time.time()
    datasetPtr = 0
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    valXArray, _, _ = utils.DecompressArray(XArrayCompressed, validationStart, numValItems, total)
    valYArray, _, _ = utils.DecompressArray(YArrayCompressed, validationStart, numValItems, total)
    logging.info("Number of variants for validation: %d" % len(valXArray))
    i = 1
    while i < (1 + int(param.maxEpoch * trainingTotal / trainBatchSize + 0.499)):
        XBatch, num, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, trainBatchSize, trainingTotal)
        YBatch, num2, endFlag2 = utils.DecompressArray(YArrayCompressed, datasetPtr, trainBatchSize, trainingTotal)
        if num != num2 or endFlag != endFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" % (num, num2))
        loss, summary = m.train(XBatch, YBatch)
        if args.olog != None:
            summaryWriter.add_summary(summary, i)
        if endFlag != 0:
            validationLost = m.getLoss(valXArray, valYArray)
            logging.info(" ".join([str(i), "Training loss:", str(loss / trainBatchSize),
                                   "Validation loss: ", str(validationLost / numValItems)]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))
            validationLosts.append((validationLost, i))
            c += 1
            flag = 0
            flipFlop = 0
            # Decay the learning rate when at least three of the last five
            # validation losses failed to improve
            if c >= 6:
                if validationLosts[-6][0] - validationLosts[-5][0] <= 0: flipFlop += 1
                if validationLosts[-5][0] - validationLosts[-4][0] <= 0: flipFlop += 1
                if validationLosts[-4][0] - validationLosts[-3][0] <= 0: flipFlop += 1
                if validationLosts[-3][0] - validationLosts[-2][0] <= 0: flipFlop += 1
                if validationLosts[-2][0] - validationLosts[-1][0] <= 0: flipFlop += 1
                if flipFlop >= 3:
                    maxLearningRateSwitch -= 1
                    if maxLearningRateSwitch == 0:
                        break
                    logging.info("New learning rate: %.2e" % m.setLearningRate())
                    c = 0
            epochStart = time.time()
            datasetPtr = 0
        i += 1
        datasetPtr += trainBatchSize

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))

    # Show the parameter set with the smallest validation loss
    validationLosts.sort()
    i = validationLosts[0][1]
    logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    datasetPtr = 0
    XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
    bases = []
    zs = []
    ts = []
    ls = []
    base, z, t, l = m.predict(XBatch)
    bases.append(base)
    zs.append(z)
    ts.append(t)
    ls.append(l)
    datasetPtr += predictBatchSize
    while datasetPtr < total:
        XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        if endFlag != 0:
            break
    bases = np.concatenate(bases[:])
    zs = np.concatenate(zs[:])
    ts = np.concatenate(ts[:])
    ls = np.concatenate(ls[:])
    print >> sys.stderr, "Prediction time elapsed: %.2f s" % (time.time() - predictStart)

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    print >> sys.stderr, "Version 2 model, evaluation on base change:"
    allBaseCount = top1Count = top2Count = 0
    for predictV, annotateV in zip(bases, YArray[:, 0:4]):
        allBaseCount += 1
        sortPredictV = predictV.argsort()[::-1]
        if np.argmax(annotateV) == sortPredictV[0]:
            top1Count += 1
            top2Count += 1
        elif np.argmax(annotateV) == sortPredictV[1]:
            top2Count += 1
    print >> sys.stderr, "all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %\
        (allBaseCount, top1Count, top2Count, float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100)

    print >> sys.stderr, "Version 2 model, evaluation on Zygosity:"
    ed = np.zeros((2, 2), dtype=np.int)
    for predictV, annotateV in zip(zs, YArray[:, 4:6]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(2):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(2)])

    print >> sys.stderr, "Version 2 model, evaluation on variant type:"
    ed = np.zeros((4, 4), dtype=np.int)
    for predictV, annotateV in zip(ts, YArray[:, 6:10]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(4):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(4)])

    print >> sys.stderr, "Version 2 model, evaluation on indel length:"
    ed = np.zeros((6, 6), dtype=np.int)
    for predictV, annotateV in zip(ls, YArray[:, 10:16]):
        ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
    for i in range(6):
        print >> sys.stderr, "\t".join([str(ed[i][j]) for j in range(6)])

def TrainAll(args, m, utils):
    logging.info("Loading the training dataset ...")
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)
    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    if args.olog_dir != None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    # Train and save the parameters; we train on the first 90% of variant sites and validate on the last 10%
    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" % m.setL2RegularizationLambda(args.lambd))

    validationLosses = []

    # Model Constants
    trainingStart = time.time()
    trainingTotal = int(total * param.trainingDatasetPercentage)
    validationStart = trainingTotal + 1
    numValItems = total - validationStart
    maxLearningRateSwitch = param.maxLearningRateSwitch

    # Variables reset per epoch
    batchSize = param.trainBatchSize
    epochStart = time.time()
    trainLossSum = 0
    validationLossSum = 0
    datasetPtr = 0

    # Variables reset per learning rate decay
    c = 0

    i = 1 if args.chkpnt_fn == None else int(args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1

    XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
    YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
    datasetPtr += XNum
    while i < param.maxEpoch:
        # Run the current batch in a worker thread (training step for the training
        # portion, loss only for the validation portion) while decompressing the next batch
        threadPool = []
        if datasetPtr < validationStart:
            threadPool.append(Thread(target=m.trainNoRT, args=(XBatch, YBatch, )))
        elif datasetPtr >= validationStart:
            threadPool.append(Thread(target=m.getLossNoRT, args=(XBatch, YBatch, )))
        for t in threadPool:
            t.start()

        # Shrink the next batch so it does not straddle the training/validation boundary
        if datasetPtr < validationStart and (validationStart - datasetPtr) < param.trainBatchSize:
            batchSize = validationStart - datasetPtr
        elif datasetPtr < validationStart:
            batchSize = param.trainBatchSize
        elif datasetPtr >= validationStart and (datasetPtr % param.predictBatchSize) != 0:
            batchSize = param.predictBatchSize - (datasetPtr % param.predictBatchSize)
        elif datasetPtr >= validationStart:
            batchSize = param.predictBatchSize

        XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
        YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
        if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" % (XNum2, YNum2))

        for t in threadPool:
            t.join()

        XBatch = XBatch2
        YBatch = YBatch2
        if datasetPtr < validationStart:
            trainLossSum += m.trainLossRTVal
            summary = m.trainSummaryRTVal
            if args.olog_dir != None:
                summaryWriter.add_summary(summary, i)
        elif datasetPtr >= validationStart:
            validationLossSum += m.getLossLossRTVal
        datasetPtr += XNum2

        if XEndFlag2 != 0:
            validationLossSum += m.getLoss(XBatch, YBatch)
            logging.info(" ".join([str(i), "Training loss:", str(trainLossSum / trainingTotal),
                                   "Validation loss: ", str(validationLossSum / numValItems)]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))
            validationLosses.append((validationLossSum, i))

            # Output the model
            if args.ochk_prefix != None:
                parameterOutputPath = "%s-%%0%dd" % (args.ochk_prefix, param.parameterOutputPlaceHolder)
                m.saveParameters(os.path.abspath(parameterOutputPath % i))

            # Adaptive learning rate decay: decay when the five most recent
            # validation-loss differences strictly alternate in sign, or when the
            # oldest difference is exactly zero
            c += 1
            flag = 0
            if c >= 6:
                if validationLosses[-6][0] - validationLosses[-5][0] > 0:
                    if validationLosses[-5][0] - validationLosses[-4][0] < 0:
                        if validationLosses[-4][0] - validationLosses[-3][0] > 0:
                            if validationLosses[-3][0] - validationLosses[-2][0] < 0:
                                if validationLosses[-2][0] - validationLosses[-1][0] > 0:
                                    flag = 1
                elif validationLosses[-6][0] - validationLosses[-5][0] < 0:
                    if validationLosses[-5][0] - validationLosses[-4][0] > 0:
                        if validationLosses[-4][0] - validationLosses[-3][0] < 0:
                            if validationLosses[-3][0] - validationLosses[-2][0] > 0:
                                if validationLosses[-2][0] - validationLosses[-1][0] < 0:
                                    flag = 1
                else:
                    flag = 1
                if flag == 1:
                    maxLearningRateSwitch -= 1
                    if maxLearningRateSwitch == 0:
                        break
                    logging.info("New learning rate: %.2e" % m.setLearningRate())
                    logging.info("New L2 regularization lambda: %.2e" % m.setL2RegularizationLambda())
                    c = 0

            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            validationLossSum = 0
            datasetPtr = 0
            epochStart = time.time()
            batchSize = param.trainBatchSize
            XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
            YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
            datasetPtr += XNum

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))

    # Show the parameter set with the smallest validation loss
    validationLosses.sort()
    i = validationLosses[0][1]
    logging.info("Best validation loss at batch: %d" % i)

    logging.info("Testing on the training and validation dataset ...")
    predictStart = time.time()
    predictBatchSize = param.predictBatchSize
    if args.v2 == True or args.v3 == True:
        datasetPtr = 0
        XBatch, _, _ = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
        bases = []
        zs = []
        ts = []
        ls = []
        base, z, t, l = m.predict(XBatch)
        bases.append(base)
        zs.append(z)
        ts.append(t)
        ls.append(l)
        datasetPtr += predictBatchSize
        while datasetPtr < total:
            XBatch, _, endFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, predictBatchSize, total)
            base, z, t, l = m.predict(XBatch)
            bases.append(base)
            zs.append(z)
            ts.append(t)
            ls.append(l)
            datasetPtr += predictBatchSize
            if endFlag != 0:
                break
        bases = np.concatenate(bases[:])
        zs = np.concatenate(zs[:])
        ts = np.concatenate(ts[:])
        ls = np.concatenate(ls[:])
    logging.info("Prediction time elapsed: %.2f s" % (time.time() - predictStart))

    # Evaluate the trained model
    YArray, _, _ = utils.DecompressArray(YArrayCompressed, 0, total, total)
    if args.v2 == True or args.v3 == True:
        logging.info("Version 2 model, evaluation on base change:")
        allBaseCount = top1Count = top2Count = 0
        for predictV, annotateV in zip(bases, YArray[:, 0:4]):
            allBaseCount += 1
            sortPredictV = predictV.argsort()[::-1]
            if np.argmax(annotateV) == sortPredictV[0]:
                top1Count += 1
                top2Count += 1
            elif np.argmax(annotateV) == sortPredictV[1]:
                top2Count += 1
        logging.info("all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" %
                     (allBaseCount, top1Count, top2Count,
                      float(top1Count)/allBaseCount*100, float(top2Count)/allBaseCount*100))

        logging.info("Version 2 model, evaluation on Zygosity:")
        ed = np.zeros((2, 2), dtype=np.int)
        for predictV, annotateV in zip(zs, YArray[:, 4:6]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(2):
            logging.info("\t".join([str(ed[i][j]) for j in range(2)]))

        logging.info("Version 2 model, evaluation on variant type:")
        ed = np.zeros((4, 4), dtype=np.int)
        for predictV, annotateV in zip(ts, YArray[:, 6:10]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(4):
            logging.info("\t".join([str(ed[i][j]) for j in range(4)]))

        logging.info("Version 2 model, evaluation on indel length:")
        ed = np.zeros((6, 6), dtype=np.int)
        for predictV, annotateV in zip(ls, YArray[:, 10:16]):
            ed[np.argmax(annotateV)][np.argmax(predictV)] += 1
        for i in range(6):
            logging.info("\t".join([str(ed[i][j]) for j in range(6)]))

def TrainAll(args, m, utils):
    logging.info("Loading the training dataset ...")
    if args.bin_fn != None:
        with open(args.bin_fn, "rb") as fh:
            total = pickle.load(fh)
            XArrayCompressed = pickle.load(fh)
            YArrayCompressed = pickle.load(fh)
            posArrayCompressed = pickle.load(fh)
    else:
        total, XArrayCompressed, YArrayCompressed, posArrayCompressed = \
            utils.GetTrainingArray(args.tensor_fn, args.var_fn, args.bed_fn)
    logging.info("The size of training dataset: {}".format(total))

    # Op to write logs to Tensorboard
    if args.olog_dir != None:
        summaryWriter = m.summaryFileWriter(args.olog_dir)

    # Train and save the parameters; this variant trains on all variant sites without a held-out validation split
    logging.info("Start training ...")
    logging.info("Learning rate: %.2e" % m.setLearningRate(args.learning_rate))
    logging.info("L2 regularization lambda: %.2e" % m.setL2RegularizationLambda(args.lambd))

    # Model Constants
    trainingStart = time.time()
    trainingTotal = total
    maxLearningRateSwitch = param.maxLearningRateSwitch
    batchSize = param.trainBatchSize

    # Variables reset per epoch
    batchSize = param.trainBatchSize
    epochStart = time.time()
    trainLossSum = 0
    datasetPtr = 0

    i = 1 if args.chkpnt_fn == None else int(args.chkpnt_fn[-param.parameterOutputPlaceHolder:]) + 1

    XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
    YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
    datasetPtr += XNum
    while i < param.maxEpoch:
        # Run the training step in a worker thread while decompressing the next batch
        threadPool = []
        threadPool.append(Thread(target=m.trainNoRT, args=(XBatch, YBatch, )))
        for t in threadPool:
            t.start()

        XBatch2, XNum2, XEndFlag2 = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
        YBatch2, YNum2, YEndFlag2 = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
        if XNum2 != YNum2 or XEndFlag2 != YEndFlag2:
            sys.exit("Inconsistency between decompressed arrays: %d/%d" % (XNum2, YNum2))

        for t in threadPool:
            t.join()

        XBatch = XBatch2
        YBatch = YBatch2
        trainLossSum += m.trainLossRTVal
        summary = m.trainSummaryRTVal
        if args.olog_dir != None:
            summaryWriter.add_summary(summary, i)
        datasetPtr += XNum2

        if XEndFlag2 != 0:
            logging.info(" ".join([str(i), "Training loss:", str(trainLossSum / trainingTotal)]))
            logging.info("Epoch time elapsed: %.2f s" % (time.time() - epochStart))
            parameterOutputPath = "%s-%%0%dd" % (args.ochk_prefix, param.parameterOutputPlaceHolder)
            m.saveParameters(parameterOutputPath % i)

            # Reset per epoch variables
            i += 1
            trainLossSum = 0
            datasetPtr = 0
            epochStart = time.time()
            batchSize = param.trainBatchSize
            XBatch, XNum, XEndFlag = utils.DecompressArray(XArrayCompressed, datasetPtr, batchSize, total)
            YBatch, YNum, YEndFlag = utils.DecompressArray(YArrayCompressed, datasetPtr, batchSize, total)
            datasetPtr += XNum

    logging.info("Training time elapsed: %.2f s" % (time.time() - trainingStart))

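# Sketch of the argparse fields the functions above read (collected from the code,
# not an authoritative CLI definition): args.bin_fn (pickled tensors) or
# args.tensor_fn / args.var_fn / args.bed_fn (raw inputs), args.chkpnt_fn
# (checkpoint(s) to resume from or evaluate), args.ochk_prefix (output checkpoint
# prefix), args.olog / args.olog_dir (TensorBoard log directory),
# args.learning_rate, args.lambd (L2 regularization lambda), and
# args.v2 / args.v3 (model version flags).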