def main(argv):
    log.info('Beginning prediction')
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    if FLAGS.evaluate:
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     dataloader=data, functype=FLAGS.function,
                                     featuretype='ngrams')
        predict_evaluate(test_dataiter, 0.2, FLAGS.modelsdir)
    else:
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     dataloader=data, functype=FLAGS.function,
                                     featuretype='ngrams', test=True)
        predict(test_dataiter, 0.2, FLAGS.modelsdir, funcs)
def main(FLAGS):
    origfuncs = pd.read_pickle(
        os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    revmap = dict(zip([GODAG.get(node) for node in origfuncs], origfuncs))
    funcs = GODAG.initialize_idmap(origfuncs, FLAGS.function)
    print(len(funcs))

    # map each (possibly remapped) function back to its column in the original ordering
    new_order = [
        np.where(origfuncs == revmap.get(node, node))[0][0] for node in funcs
    ]

    avgprec = np.zeros_like(THRESHOLD_RANGE)
    avgrecall = np.zeros_like(THRESHOLD_RANGE)
    avgf1 = np.zeros_like(THRESHOLD_RANGE)
    step = 0
    for x, y in read_batches(FLAGS.infile, FLAGS.function, funcs=funcs,
                             batchsize=128, order=new_order):
        prec, recall, f1 = [], [], []
        for thres in THRESHOLD_RANGE:
            p, r, f = numpy_calc_performance_metrics(x, y, thres)
            prec.append(p)
            recall.append(r)
            f1.append(f)

        avgprec += prec
        avgrecall += recall
        avgf1 += f1
        step += 1.0
        if step % 1000 == 0:
            print('prec-{}, recall-{}, f1-{}'.format(np.round(avgprec / step, 2),
                                                     np.round(avgrecall / step, 2),
                                                     np.round(avgf1 / step, 2)))

    print(THRESHOLD_RANGE)
    print('prec-{}, recall-{}, f1-{}'.format(np.round(avgprec / step, 2),
                                             np.round(avgrecall / step, 2),
                                             np.round(avgf1 / step, 2)))
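# `numpy_calc_performance_metrics` is called above but not defined in this file. The helper
# below is only a sketch of what it might compute, assuming the (true labels, predicted
# scores, threshold) argument order used by the caller and a micro-averaged
# precision/recall/F1 over the batch. Names and the averaging choice are assumptions.
import numpy as np


def numpy_calc_performance_metrics(ytrue, ypred, threshold, eps=1e-8):
    """Hypothetical helper: micro-averaged precision/recall/F1 at a score threshold."""
    predicted = (ypred >= threshold).astype(np.int32)
    tp = np.sum(predicted * ytrue)                    # true positives across the batch
    precision = tp / (np.sum(predicted) + eps)
    recall = tp / (np.sum(ytrue) + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1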
def read_batches(infile, functype, funcs=None, batchsize=32, order=[]):
    batchtrue, batchpred = [], []
    with gzip.open(infile, 'rt') as inf:
        for ln in SeqIO.parse(inf, 'fasta'):
            # the FASTA description carries a JSON payload with true GO ids and predictions
            msg = json.loads(ln.description.split(' ', 1)[1])
            truelabels = [
                i['go_id'] for i in msg['go_ids']
                if i['aspect'].lower() == functype[-1].lower()
            ]
            truelabels = GODAG.to_npy(truelabels)
            predictions = np.array(msg['prediction'])[order]
            batchtrue.append(truelabels)
            batchpred.append(predictions)
            if len(batchtrue) == batchsize:
                if funcs is None:
                    # pad predictions with zero columns so they line up with the label matrix
                    preds = np.concatenate([
                        np.vstack(batchpred),
                        np.zeros((len(batchpred),
                                  batchtrue[0].shape[0] - len(batchpred[0])))
                    ], axis=1)
                else:
                    batchtrue = np.vstack(batchtrue)[:, :len(funcs)]
                    mask = batchtrue.any(axis=1)
                    preds = np.vstack(batchpred)[mask, :]
                    batchtrue = batchtrue[mask, :]

                yield (np.vstack(batchtrue), preds)
                batchpred, batchtrue = [], []

    if batchtrue:
        preds = np.concatenate([
            np.vstack(batchpred),
            np.zeros((len(batchpred),
                      batchtrue[0].shape[0] - len(batchpred[0])))
        ], axis=1)
        # yield (not return) the final partial batch so callers iterating the generator see it
        yield (np.vstack(batchtrue), preds)
def main(argv):
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))
    if FLAGS.predict != '':
        # prediction-only mode: reuse an existing model directory, skip training
        modelsavename = FLAGS.predict
        bestthres = 0.1
        log.info('no training')
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function,
                                      featuretype='onehot')
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='onehot')
        next(valid_dataiter)
        next(train_iter)
    else:
        with tf.Session() as sess:
            valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                          dataloader=data, functype=FLAGS.function,
                                          featuretype='onehot')
            train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                      seqlen=FLAGS.maxseqlen, dataloader=data,
                                      numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                      functype=FLAGS.function, featuretype='onehot')
            encoder = CHARCNNEncoder(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                                     inputsize=train_iter.expectedshape).build()
            log.info('built encoder')
            decoder = HierarchicalGODecoder(funcs, encoder.outputs, FLAGS.function).build(GODAG)
            log.info('built decoder')
            init = tf.global_variables_initializer()
            init.run(session=sess)
            chkpt = tf.train.Saver(max_to_keep=4)
            train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
            test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
            step = 0
            maxwait = 1
            wait = 0
            bestf1 = -1
            metagraphFlag = True
            log.info('starting epochs')
            for epoch in range(FLAGS.num_epochs):
                for x, y in train_iter:
                    if x.shape[0] != y.shape[0]:
                        raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                    _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                                feed_dict={decoder.ys_: y, encoder.xs_: x,
                                                           decoder.threshold: [0.2]})
                    train_writer.add_summary(summary, step)
                    log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                    step += 1

                # validate once per epoch and keep the threshold with the best F1
                log.info('beginning validation')
                prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
                thres = np.argmax(np.round(f1, 2))
                log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(
                    epoch, np.round(prec, 2)[thres], np.round(recall, 2)[thres],
                    np.round(f1, 2)[thres]))
                log.info('precision mat {}'.format(str(np.round(prec, 2))))
                log.info('recall mat {}'.format(str(np.round(recall, 2))))
                log.info('f1 mat {}'.format(str(np.round(f1, 2))))
                log.info('selected threshold is {}'.format(thres / 10 + 0.1))
                if f1[thres] > (bestf1 + 1e-3):
                    bestf1 = f1[thres]
                    bestthres = THRESHOLD_RANGE[thres]
                    wait = 0
                    chkpt.save(sess,
                               os.path.join(FLAGS.outputdir, modelsavename,
                                            'model_{}_{}'.format(FLAGS.function, step)),
                               global_step=step, write_meta_graph=metagraphFlag)
                    metagraphFlag = False
                else:
                    wait += 1
                    if wait > maxwait:
                        log.info("f1 didn't improve for last {} validation steps, "
                                 "so stopping".format(maxwait))
                        break
                step += 1
                train_iter.reset()

    log.info('testing model')
    test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                 dataloader=data, functype=FLAGS.function, featuretype='onehot')
    prec, recall, f1 = predict_evaluate(test_dataiter, [bestthres],
                                        os.path.join(FLAGS.outputdir, modelsavename))
    log.info('test results')
    log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2),
                                                        round(f1, 2)))
    data.close()
def main(argv):
    goids = GODAG.initialize_idmap(None, None)
    labelembedding = load_labelembedding(os.path.join(FLAGS.resources, 'goEmbeddings.txt'), goids)
    assert labelembedding.shape[0] == len(goids), 'label embeddings and known go ids differ'
    # Add a row of zeros to refer to NOGO or STOPGO
    labelembedding = np.vstack([np.zeros(labelembedding.shape[1]),
                                labelembedding]).astype(np.float32)
    labelembeddingsize = labelembedding.shape[1]

    # shift all goids by 1, to allow STOPGO at index 0
    GODAG.idmap = {key: (val + 1) for key, val in GODAG.idmap.items()}
    log.info('min go index - {}'.format(min(list(GODAG.idmap.values()))))
    GODAG.idmap['STOPGO'] = 0
    GODAG.GOIDS.insert(0, 'STOPGO')
    log.info('first from main-{}, from goids-{}, from idmap-{}, by reversemap-{}'.format(
        goids[0], GODAG.GOIDS[1], GODAG.id2node(1), GODAG.get_id(goids[0])))

    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = FLAGS.predict
    if FLAGS.predict == "":
        modelsavename = 'savedmodels_{}'.format(int(time.time()))

    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function,
                                      featuretype='onehot', onlyLeafNodes=True,
                                      numfuncs=FLAGS.maxnumfuncs)
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='onehot',
                                  onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)
        encoder = MultiCharCNN(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                               inputsize=train_iter.expectedshape, with_dilation=False,
                               charfilter=32, poolsize=80, poolstride=48).build()
        log.info('built encoder')
        decoder = GORNNDecoder(encoder.outputs, labelembedding, numfuncs=FLAGS.maxnumfuncs,
                               trainlabelEmbedding=FLAGS.trainlabel,
                               distancefunc=FLAGS.distancefunc, godag=GODAG).build()
        log.info('built decoder')
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        maxwait = 2
        wait = 0
        bestf1 = 0
        bestthres = 0
        metagraphFlag = True
        log.info('starting epochs')
        log.info('params - trainsize-{}, validsize-{}, rootfunc-{}, batchsize-{}'.format(
            FLAGS.trainsize, FLAGS.validationsize, FLAGS.function, FLAGS.batchsize))
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                negatives = get_negatives(y, 10)
                _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                            feed_dict={decoder.ys_: y[:, :FLAGS.maxnumfuncs],
                                                       encoder.xs_: x,
                                                       decoder.negsamples: negatives,
                                                       decoder.istraining: [True]})
                train_writer.add_summary(summary, step)
                log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                step += 1

            log.info('beginning validation')
            prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
            log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(
                epoch, np.round(prec, 2), np.round(recall, 2), np.round(f1, 2)))
            if np.round(f1, 2) >= bestf1:
                bestf1 = np.round(f1, 2)
                wait = 0
                log.info('saving meta graph')
                chkpt.save(sess,
                           os.path.join(FLAGS.outputdir, modelsavename,
                                        'model_{}_{}'.format(FLAGS.function, step)),
                           global_step=step, write_meta_graph=metagraphFlag)
                metagraphFlag = True
            else:
                wait += 1
                if wait > maxwait:
                    log.info("f1 didn't improve for last {} validation steps, "
                             "so stopping".format(maxwait))
                    break

            train_iter.reset()
            prec, recall, f1 = validate(train_iter, sess, encoder, decoder, None)
            log.info('training error, epoch-{}, precision: {}, recall: {}, f1: {}'.format(
                epoch, np.round(prec, 2), np.round(recall, 2), np.round(f1, 2)))
            train_iter.reset()

        log.info('testing model')
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     dataloader=data, functype=FLAGS.function,
                                     featuretype='onehot', onlyLeafNodes=True,
                                     numfuncs=FLAGS.maxnumfuncs)
        prec, recall, f1 = predict_evaluate(test_dataiter,
                                            os.path.join(FLAGS.outputdir, modelsavename))
        log.info('test results')
        log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2),
                                                            round(f1, 2)))
        data.close()
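# `get_negatives` is called in the training loops above but not defined in this file. The
# function below is only a sketch of one plausible behaviour: for each example, draw `count`
# GO term indices that are not among that example's positive labels, assuming `y` is a
# (batch, k) matrix of integer GO ids padded with the STOPGO index 0. The name, signature,
# and label encoding are assumptions, not the project's confirmed API.
import numpy as np


def get_negatives(y, count, numterms=None):
    """Hypothetical negative sampler: returns a (batch, count) matrix of negative GO ids."""
    if numterms is None:
        numterms = len(GODAG.GOIDS)          # total number of known GO ids (incl. STOPGO)
    negatives = np.empty((y.shape[0], count), dtype=np.int32)
    for row, positives in enumerate(y):
        # candidate negatives: every term id except the row's positives (and the pad id 0)
        candidates = np.setdiff1d(np.arange(1, numterms), positives)
        negatives[row] = np.random.choice(candidates, size=count, replace=True)
    return negatives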
def main(argv):
    goids = GODAG.initialize_idmap(None, None)
    labelembedding = load_labelembedding(os.path.join(FLAGS.data, 'goEmbeddings.txt'), goids)
    assert labelembedding.shape[0] == (len(goids) + 1), 'label embeddings and known go ids differ'
    labelembeddingsize = labelembedding.shape[1]
    FeatureExtractor.load(FLAGS.data)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader()
    modelsavename = 'savedmodels_{}'.format(int(time.time()))
    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function,
                                      featuretype='ngrams', onlyLeafNodes=True,
                                      limit=FLAGS.maxnumfuncs)
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='ngrams',
                                  onlyLeafNodes=True, limit=FLAGS.maxnumfuncs)
        encoder = CNNEncoder(vocab_size=len(FeatureExtractor.ngrammap) + 1,
                             inputsize=train_iter.expectedshape).build()
        log.info('built encoder')
        decoder = GORNNDecoder(encoder.outputs, labelembedding,
                               numfuncs=FLAGS.maxnumfuncs).build()
        log.info('built decoder')
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        maxwait = 1
        wait = 0
        bestf1 = 0
        bestthres = 0
        metagraphFlag = True
        log.info('starting epochs')
        log.info('params - trainsize-{}, validsize-{}, rootfunc-{}, batchsize-{}'.format(
            FLAGS.trainsize, FLAGS.validationsize, FLAGS.function, FLAGS.batchsize))
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                negatives = get_negatives(y, 10)
                _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                            feed_dict={decoder.ys_: y, encoder.xs_: x,
                                                       decoder.negsamples: negatives})
                train_writer.add_summary(summary, step)
                log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                step += 1

            log.info('beginning validation')
            prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
            log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(
                epoch, np.round(prec, 2), np.round(recall, 2), np.round(f1, 2)))
            if f1 > (bestf1 + 1e-3):
                bestf1 = f1
                wait = 0
                chkpt.save(sess,
                           os.path.join(FLAGS.outputdir, modelsavename,
                                        'model_{}_{}'.format(FLAGS.function, step)),
                           global_step=step, write_meta_graph=metagraphFlag)
                metagraphFlag = False
            else:
                wait += 1
                if wait > maxwait:
                    log.info("f1 didn't improve for last {} validation steps, "
                             "so stopping".format(maxwait))
                    break
            train_iter.reset()

        log.info('testing model')
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     dataloader=data, functype=FLAGS.function,
                                     featuretype='ngrams', onlyLeafNodes=True,
                                     limit=FLAGS.maxnumfuncs)
        prec, recall, f1 = predict_evaluate(test_dataiter, [bestthres],
                                            os.path.join(FLAGS.outputdir, modelsavename))
        log.info('test results')
        log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2),
                                                            round(f1, 2)))
        data.close()
def main(argv):
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    pretrained = None
    featuretype = 'onehot'
    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    with tf.Session() as sess:
        data = DataLoader(filename=FLAGS.inputfile)
        log.info('initializing validation data')
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function,
                                      featuretype='ngrams', numfuncs=len(funcs),
                                      all_labels=False, autoreset=True)
        log.info('initializing train data')
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data, numfiles=4,
                                  numfuncs=len(funcs), functype=FLAGS.function,
                                  featuretype='ngrams', all_labels=False, autoreset=True)
        vocabsize = ((len(FeatureExtractor.ngrammap) + 1) if featuretype == 'ngrams'
                     else (len(FeatureExtractor.aminoacidmap) + 1))
        model = KerasDeepGO(funcs, FLAGS.function, GODAG, train_iter.expectedshape, vocabsize,
                            pretrained_embedding=pretrained).build()
        log.info('built model')
        keras.backend.set_session(sess)
        log.info('starting epochs')

        model_path = FLAGS.outputdir + 'models/model_seq_' + FLAGS.function + '.h5'
        checkpointer = keras.callbacks.ModelCheckpoint(filepath=model_path, verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True)
        earlystopper = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1)
        model_jsonpath = FLAGS.outputdir + 'models/model_{}.json'.format(FLAGS.function)
        with open(model_jsonpath, 'w') as f:
            f.write(model.to_json())

        model.fit_generator(train_iter, steps_per_epoch=FLAGS.trainsize, epochs=5,
                            validation_data=valid_dataiter,
                            validation_steps=FLAGS.validationsize,
                            max_queue_size=128, callbacks=[checkpointer, earlystopper])
        valid_dataiter.close()
        train_iter.close()

        log.info('initializing test data')
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     seqlen=FLAGS.maxseqlen, dataloader=data, numfiles=4,
                                     numfuncs=len(funcs), functype=FLAGS.function,
                                     featuretype='ngrams', all_labels=True)
        prec, recall, f1 = predict_evaluate(test_dataiter, model_jsonpath, model_path)
        log.info('testing error, prec-{}, recall-{}, f1-{}'.format(
            np.round(prec, 3), np.round(recall, 3), np.round(f1, 3)))
        data.close()
def main(argv):
    funcs = GODAG.initialize_idmap(None, None)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))
    pretrained = None
    featuretype = FLAGS.featuretype
    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    log.info('using feature type - {}'.format(featuretype))
    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype='', featuretype=featuretype)
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype='', featuretype=featuretype)
        vocabsize = (len(FeatureExtractor.ngrammap) if featuretype == 'ngrams'
                     else len(FeatureExtractor.aminoacidmap))
        encoder = ConvAutoEncoder(vocab_size=vocabsize, maxlen=train_iter.expectedshape).build()
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        maxwait = 1
        wait = 0
        metagraphFlag = True
        log.info('starting epochs')
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                _, loss = sess.run([encoder.train, encoder.loss],
                                   feed_dict={encoder.xs_: x})
                log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                step += 1

            chkpt.save(sess,
                       os.path.join(FLAGS.outputdir, modelsavename,
                                    'model_epoch{}'.format(epoch)),
                       global_step=step, write_meta_graph=metagraphFlag)
            train_iter.reset()

    data.close()
def printnodes(ids):
    print([GODAG.id2node(i) for i in ids])
def main(argv):
    funcs = GODAG.initialize_idmap(None, None)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))
    pretrained = None
    featuretype = FLAGS.featuretype
    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    log.info('using feature type - {}'.format(featuretype))
    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      seqlen=FLAGS.maxseqlen, dataloader=data,
                                      functype='', featuretype=featuretype)
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype='', featuretype=featuretype)
        vocabsize = (len(FeatureExtractor.ngrammap) if featuretype == 'ngrams'
                     else len(FeatureExtractor.aminoacidmap))
        cae_model_obj = ConvAutoEncoder(vocab_size=vocabsize, maxlen=train_iter.expectedshape,
                                        batch_size=FLAGS.batchsize, embedding_dim=256)
        cae_model_obj.build()
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        metagraphFlag = True
        log.info('starting epochs')

        # -------------------------- #
        # Start training the autoencoder
        # -------------------------- #
        bestloss = np.infty
        maxwait = 10
        wait = 0
        slack = 1e-5
        earlystop = False
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                _, sp_ss_loss, spmloss, smloss = sess.run(
                    [cae_model_obj.train, cae_model_obj.loss,
                     cae_model_obj.loss1, cae_model_obj.loss2],
                    feed_dict={cae_model_obj.x_input: x})
                log.info('step :: {}, sploss :: {}, spm:{}, sm:{}'.format(
                    step, round(sp_ss_loss, 3), round(spmloss, 3), round(smloss, 3)))
                step += 1

                if step % 100 == 0:
                    # periodic validation on a single held-out batch
                    x, y = next(valid_dataiter)
                    valid_loss, tp = sess.run([cae_model_obj.loss, cae_model_obj.max_out],
                                              feed_dict={cae_model_obj.x_input: x})
                    # per-symbol reconstruction accuracy on the validation batch
                    log.info('validation loss at step: {} is {}, precision is: {}'.format(
                        step, round(valid_loss, 3), sum(sum(tp == x)) / x.size))
                    if (valid_loss <= (bestloss + slack)) or (valid_loss + slack <= bestloss):
                        wait = 0
                        bestloss = valid_loss
                        chkpt.save(sess,
                                   os.path.join(FLAGS.outputdir, modelsavename,
                                                'model_epoch{}'.format(epoch)),
                                   global_step=step, write_meta_graph=metagraphFlag)
                    else:
                        wait += 1
                        if wait > maxwait:
                            earlystop = True
                            break

            chkpt.save(sess,
                       os.path.join(FLAGS.outputdir, modelsavename,
                                    'model_epoch{}'.format(epoch)),
                       global_step=step, write_meta_graph=metagraphFlag)
            train_iter.reset()
            valid_dataiter.reset()
            if earlystop:
                log.info('stopping early at epoch: {}, step:{}, loss:{}'.format(
                    epoch, step, bestloss))
                break

    data.close()
def main(argv):
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)
    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    with tf.Session() as sess:
        data = DataLoader(filename=FLAGS.inputfile)
        log.info('initializing validation data')
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function,
                                      featuretype='onehot', numfuncs=len(funcs),
                                      all_labels=False, autoreset=True)
        log.info('initializing train data')
        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data, numfiles=4,
                                  numfuncs=len(funcs), functype=FLAGS.function,
                                  featuretype='onehot', all_labels=False, autoreset=True)

        global original_dim
        original_dim = train_iter.expectedshape
        model = get_AE(train_iter.expectedshape, len(FeatureExtractor.aminoacidmap),
                       train_iter.expectedshape)
        keras.backend.set_session(sess)
        log.info('starting epochs')

        model_path = FLAGS.outputdir + 'models/model_seq_' + FLAGS.function + '.h5'
        checkpointer = keras.callbacks.ModelCheckpoint(filepath=model_path, verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True)
        earlystopper = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, verbose=1)
        model_jsonpath = FLAGS.outputdir + 'models/model_{}.json'.format(FLAGS.function)
        with open(model_jsonpath, 'w') as f:
            f.write(model.to_json())

        model.fit_generator(datawrapper(train_iter), steps_per_epoch=FLAGS.trainsize, epochs=5,
                            validation_data=datawrapper(valid_dataiter),
                            validation_steps=FLAGS.validationsize,
                            callbacks=[checkpointer, earlystopper])
        valid_dataiter.close()
        train_iter.close()
        data.close()
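# `datawrapper` and `get_AE` are defined elsewhere in the project and are not shown in this
# file. The adapter below is only a sketch of what `datawrapper` might do, assuming the
# autoencoder returned by `get_AE` is trained to reconstruct its own integer-encoded input:
# it re-yields each batch's features as both input and target for fit_generator. The name
# and the (x, x) convention are assumptions, not the project's confirmed API.
def datawrapper(dataiter):
    """Hypothetical generator adapter for keras fit_generator: yields (inputs, targets)."""
    for x, _ in dataiter:
        yield x, x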