def _prepare_labelembedding():
    """Load the GO label embeddings and reserve index 0 for the STOPGO token.

    Calls ``GODAG.initialize_idmap`` to obtain the known GO ids, loads
    ``goEmbeddings.txt`` from ``FLAGS.resources`` for those ids, and prepends
    an all-zero row (used for NOGO / STOPGO). To keep ids consistent with the
    shifted matrix, every value in ``GODAG.idmap`` is incremented by one,
    ``'STOPGO'`` is mapped to 0, and ``'STOPGO'`` is inserted at position 0 of
    ``GODAG.GOIDS``. These are side effects on the global ``GODAG``.

    Returns:
        The float32 embedding matrix with the extra zero row at index 0.
        NOTE(review): row order presumably follows the GO id ordering used
        by ``load_labelembedding`` -- confirm against that helper.

    Raises:
        ValueError: if the number of embedding rows differs from the number
            of known GO ids.
    """
    goids = GODAG.initialize_idmap(None, None)
    labelembedding = load_labelembedding(
        os.path.join(FLAGS.resources, 'goEmbeddings.txt'), goids)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if labelembedding.shape[0] != len(goids):
        raise ValueError('label embeddings and known go ids differ')

    # Add a row of zeros to refer to NOGO or STOPGO.
    labelembedding = np.vstack(
        [np.zeros(labelembedding.shape[1]), labelembedding]).astype(np.float32)

    # Shift all go ids by 1 so that id 0 is free for STOPGO.
    GODAG.idmap = {key: val + 1 for key, val in GODAG.idmap.items()}
    log.info('min go index - {}'.format(min(GODAG.idmap.values())))
    GODAG.idmap['STOPGO'] = 0
    GODAG.GOIDS.insert(0, 'STOPGO')
    log.info('first from main-{}, from goids-{}, from idmap-{}, by reversemap-{}'.format(
        goids[0], GODAG.GOIDS[1], GODAG.id2node(1), GODAG.get_id(goids[0])))
    return labelembedding


def main(argv):
    """Train the encoder/decoder model, early-stop on validation F1, then test.

    Relies on module-level names defined elsewhere in this file (``FLAGS``,
    ``GODAG``, ``log`` and the model / data helpers). ``argv`` is accepted
    for the entry-point calling convention and is not used.

    Flow:
        1. Prepare label embeddings and shift GO ids (see
           ``_prepare_labelembedding``).
        2. Build the ``MultiCharCNN`` encoder and ``GORNNDecoder`` decoder.
        3. For each of ``FLAGS.num_epochs`` epochs: run one pass over the
           training iterator, then validate. Save a checkpoint whenever the
           rounded validation F1 is >= the best so far. Stop once it has
           failed to improve more than ``maxwait`` times in a row.
        4. Evaluate the saved model on a test iterator with
           ``predict_evaluate``.

    Side effects:
        Writes TensorBoard summaries to ``FLAGS.outputdir + '/train'`` and
        ``FLAGS.outputdir + '/test'``. Writes checkpoints to
        ``os.path.join(FLAGS.outputdir, modelsavename)``. Mutates ``GODAG``.

    Raises:
        ValueError: if label embeddings and GO ids disagree in size, or a
            training batch has different feature and label batch sizes.
    """
    labelembedding = _prepare_labelembedding()

    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    data = DataLoader(filename=FLAGS.inputfile)
    try:
        modelsavename = FLAGS.predict
        if FLAGS.predict == "":
            modelsavename = 'savedmodels_{}'.format(int(time.time()))
        modeldir = os.path.join(FLAGS.outputdir, modelsavename)

        with tf.Session() as sess:
            valid_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                          size=FLAGS.validationsize,
                                          dataloader=data,
                                          functype=FLAGS.function,
                                          featuretype='onehot',
                                          onlyLeafNodes=True,
                                          numfuncs=FLAGS.maxnumfuncs)
            # NOTE(review): np.floor returns a float, and it is 0.0 whenever
            # trainsize * batchsize < 250000 -- confirm DataIterator accepts that.
            train_iter = DataIterator(batchsize=FLAGS.batchsize,
                                      size=FLAGS.trainsize,
                                      seqlen=FLAGS.maxseqlen,
                                      dataloader=data,
                                      numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                      functype=FLAGS.function,
                                      featuretype='onehot',
                                      onlyLeafNodes=True,
                                      numfuncs=FLAGS.maxnumfuncs)

            encoder = MultiCharCNN(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                                   inputsize=train_iter.expectedshape,
                                   with_dilation=False, charfilter=32,
                                   poolsize=80, poolstride=48).build()
            log.info('built encoder')
            decoder = GORNNDecoder(encoder.outputs, labelembedding,
                                   numfuncs=FLAGS.maxnumfuncs,
                                   trainlabelEmbedding=FLAGS.trainlabel,
                                   distancefunc=FLAGS.distancefunc,
                                   godag=GODAG).build()
            log.info('built decoder')

            init = tf.global_variables_initializer()
            init.run(session=sess)
            chkpt = tf.train.Saver(max_to_keep=4)
            train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
            test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
            try:
                step = 0
                maxwait = 2  # tolerated consecutive epochs without F1 improvement
                wait = 0
                bestf1 = 0
                log.info('starting epochs')
                log.info('params - trainsize-{}, validsize-{}, rootfunc-{}, batchsize-{}'.format(
                    FLAGS.trainsize, FLAGS.validationsize, FLAGS.function, FLAGS.batchsize))
                for epoch in range(FLAGS.num_epochs):
                    for x, y in train_iter:
                        if x.shape[0] != y.shape[0]:
                            raise ValueError('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))
                        negatives = get_negatives(y, 10)
                        # Only the first FLAGS.maxnumfuncs label columns are fed.
                        _, loss, summary = sess.run(
                            [decoder.train, decoder.loss, decoder.summary],
                            feed_dict={decoder.ys_: y[:, :FLAGS.maxnumfuncs],
                                       encoder.xs_: x,
                                       decoder.negsamples: negatives,
                                       decoder.istraining: [True]})
                        train_writer.add_summary(summary, step)
                        log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                        step += 1

                    log.info('beginning validation')
                    prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
                    log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(
                        epoch, np.round(prec, 2), np.round(recall, 2), np.round(f1, 2)))
                    roundedf1 = np.round(f1, 2)
                    if roundedf1 >= bestf1:
                        bestf1 = roundedf1
                        wait = 0
                        log.info('saving meta graph')
                        # A meta graph is written with every checkpoint (the
                        # original flag for this was never set to False).
                        chkpt.save(sess,
                                   os.path.join(modeldir, 'model_{}_{}'.format(FLAGS.function, step)),
                                   global_step=step, write_meta_graph=True)
                    else:
                        wait += 1
                        if wait > maxwait:
                            log.info('f1 didnt improve for last {} validation steps, so stopping'.format(wait))
                            break
                    # Report training-set metrics, rewinding the iterator
                    # before and after so the next epoch starts fresh.
                    train_iter.reset()
                    prec, recall, f1 = validate(train_iter, sess, encoder, decoder, None)
                    log.info('training error,epoch-{}, precision: {}, recall: {}, f1: {}'.format(
                        epoch, np.round(prec, 2), np.round(recall, 2), np.round(f1, 2)))
                    train_iter.reset()
            finally:
                # Flush pending summaries to disk even if training fails.
                train_writer.close()
                test_writer.close()

        log.info('testing model')
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                     size=FLAGS.testsize,
                                     dataloader=data,
                                     functype=FLAGS.function,
                                     featuretype='onehot',
                                     onlyLeafNodes=True,
                                     numfuncs=FLAGS.maxnumfuncs)
        prec, recall, f1 = predict_evaluate(test_dataiter, modeldir)
        log.info('test results')
        log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2), round(f1, 2)))
    finally:
        data.close()
def printnodes(ids):
    """Print, as one list, the result of ``GODAG.id2node`` for each id in *ids*.

    Debugging helper; returns ``None``.
    """
    nodes = list(map(GODAG.id2node, ids))
    print(nodes)