def main(argv):
    """Run prediction, or evaluation, with previously trained ngram-feature models.

    Loads the function list for ``FLAGS.function`` from
    ``<FLAGS.resources>/<FLAGS.function>.pkl``, initialises the GO DAG id map
    and the feature-extraction tables, then builds a test iterator over
    ``FLAGS.inputfile``. Mode depends on ``FLAGS.evaluate``:

    - set: ``predict_evaluate`` is called with the models in ``FLAGS.modelsdir``;
    - unset: ``predict`` is called with the same models.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    # Fixed decision threshold passed to both predict_evaluate and predict.
    threshold = 0.2

    log.info('Beginning prediction')
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources,
                     '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    try:
        # Both modes share this iterator configuration; prediction mode
        # additionally passes test=True.
        iter_kwargs = dict(batchsize=FLAGS.batchsize,
                           size=FLAGS.testsize,
                           dataloader=data,
                           functype=FLAGS.function,
                           featuretype='ngrams')
        if FLAGS.evaluate:
            test_dataiter = DataIterator(**iter_kwargs)
            predict_evaluate(test_dataiter, threshold, FLAGS.modelsdir)
        else:
            test_dataiter = DataIterator(test=True, **iter_kwargs)
            predict(test_dataiter, threshold, FLAGS.modelsdir, funcs)
    finally:
        # Release the loader even if prediction/evaluation raises.
        data.close()
예제 #2
0
def main(argv):
    """Train (or reuse) a CHARCNN encoder with a hierarchical GO decoder, then test it.

    Behaviour depends on ``FLAGS.predict``:

    - non-empty: training is skipped. ``FLAGS.predict`` is used as the model
      directory name and the test step uses a fixed threshold of 0.1.
    - empty: a new model is trained with one validation pass per epoch. A
      checkpoint is saved whenever the F1 at the best threshold improves by
      more than 1e-3. Training stops once that has not happened for more than
      ``maxwait`` epochs.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    funcs = pd.read_pickle(os.path.join(FLAGS.resources, '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))
    if FLAGS.predict != '':
        modelsavename = FLAGS.predict
        bestthres = 0.1
        log.info('no training')
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function, featuretype='onehot')

        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='onehot')
        # The iterators are advanced once but not otherwise used in this mode.
        # NOTE(review): presumably done for an initialisation side effect; confirm.
        next(valid_dataiter)
        next(train_iter)
    else:
        with tf.Session() as sess:
            valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                          dataloader=data, functype=FLAGS.function, featuretype='onehot')

            train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                      seqlen=FLAGS.maxseqlen, dataloader=data,
                                      numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                      functype=FLAGS.function, featuretype='onehot')

            encoder = CHARCNNEncoder(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                                     inputsize=train_iter.expectedshape).build()
            log.info('built encoder')
            decoder = HierarchicalGODecoder(funcs, encoder.outputs, FLAGS.function).build(GODAG)
            log.info('built decoder')
            init = tf.global_variables_initializer()
            init.run(session=sess)
            chkpt = tf.train.Saver(max_to_keep=4)
            train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train',
                                                 sess.graph)

            test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
            step = 0
            maxwait = 1
            wait = 0
            # F1 is non-negative, so the first validation always checkpoints
            # and sets bestthres.
            bestf1 = -1
            # Only the first checkpoint writes the meta graph.
            metagraphFlag = True
            log.info('starting epochs')
            for epoch in range(FLAGS.num_epochs):
                for x, y in train_iter:
                    if x.shape[0] != y.shape[0]:
                        raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                    _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                                feed_dict={decoder.ys_: y, encoder.xs_: x,
                                                           decoder.threshold: [0.2]})
                    train_writer.add_summary(summary, step)
                    log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                    step += 1

                log.info('beginning validation')
                prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
                prec_r = np.round(prec, 2)
                recall_r = np.round(recall, 2)
                f1_r = np.round(f1, 2)
                # Index into THRESHOLD_RANGE of the threshold with the best rounded F1.
                thres = np.argmax(f1_r)
                log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                                 prec_r[thres],
                                                                                 recall_r[thres],
                                                                                 f1_r[thres]))
                log.info('precision mat {}'.format(str(prec_r)))
                log.info('recall mat {}'.format(str(recall_r)))
                log.info('f1 mat {}'.format(str(f1_r)))

                # Log the same threshold value that is recorded as bestthres below.
                log.info('selected threshold is {}'.format(THRESHOLD_RANGE[thres]))
                if f1[thres] > (bestf1 + 1e-3):
                    bestf1 = f1[thres]
                    bestthres = THRESHOLD_RANGE[thres]
                    wait = 0
                    chkpt.save(sess, os.path.join(FLAGS.outputdir, modelsavename,
                                                  'model_{}_{}'.format(FLAGS.function, step)),
                               global_step=step, write_meta_graph=metagraphFlag)
                    metagraphFlag = False
                else:
                    wait += 1
                    if wait > maxwait:
                        log.info('f1 didnt improve for last {} validation steps, so stopping'.format(maxwait))
                        break

                # The validation pass also advances the step counter used for
                # summaries and checkpoint names.
                step += 1

                train_iter.reset()

    log.info('testing model')
    test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                 dataloader=data, functype=FLAGS.function, featuretype='onehot')
    prec, recall, f1 = predict_evaluate(test_dataiter, [bestthres], os.path.join(FLAGS.outputdir, modelsavename))
    log.info('test results')
    log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2), round(f1, 2)))
    data.close()
def main(argv):
    """Train a MultiCharCNN encoder with a GO RNN decoder over label embeddings, then test it.

    Label embeddings are loaded from ``<FLAGS.resources>/goEmbeddings.txt``.
    A zero row is prepended to the embeddings and every GO id in
    ``GODAG.idmap`` is shifted up by one, so that index 0 can stand for
    ``STOPGO``.

    Behaviour depends on ``FLAGS.predict``:

    - empty: a new model is trained. A checkpoint is saved whenever the
      rounded validation F1 does not decrease. Training stops once it falls
      below the best for more than ``maxwait`` epochs.
    - non-empty: it is used as the name of the model directory to test.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).

    Raises:
        ValueError: if the number of label embeddings differs from the number
            of known GO ids.
    """
    goids = GODAG.initialize_idmap(None, None)

    labelembedding = load_labelembedding(os.path.join(FLAGS.resources, 'goEmbeddings.txt'), goids)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if labelembedding.shape[0] != len(goids):
        raise ValueError('label embeddings and known go ids differ')

    # Add a row of zeros (index 0) to refer to NOGO or STOPGO.
    labelembedding = np.vstack([np.zeros(labelembedding.shape[1]), labelembedding]).astype(np.float32)

    # Shift all GO ids by 1 so that id 0 is free for STOPGO.
    GODAG.idmap = {key: (val + 1) for key, val in GODAG.idmap.items()}
    log.info('min go index - {}'.format(min(list(GODAG.idmap.values()))))
    GODAG.idmap['STOPGO'] = 0
    GODAG.GOIDS.insert(0, 'STOPGO')
    log.info('first from main-{}, from goids-{},  from idmap-{}, by reversemap-{}'.format(goids[0], GODAG.GOIDS[1], GODAG.id2node(1), GODAG.get_id(goids[0])))

    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = FLAGS.predict
    if FLAGS.predict == "":
        modelsavename = 'savedmodels_{}'.format(int(time.time()))
        with tf.Session() as sess:
            valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                          dataloader=data, functype=FLAGS.function, featuretype='onehot',
                                          onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)

            train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                      seqlen=FLAGS.maxseqlen, dataloader=data,
                                      numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                      functype=FLAGS.function, featuretype='onehot', onlyLeafNodes=True,
                                      numfuncs=FLAGS.maxnumfuncs)

            encoder = MultiCharCNN(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                                   inputsize=train_iter.expectedshape, with_dilation=False, charfilter=32,
                                   poolsize=80, poolstride=48).build()

            log.info('built encoder')
            decoder = GORNNDecoder(encoder.outputs, labelembedding, numfuncs=FLAGS.maxnumfuncs,
                                   trainlabelEmbedding=FLAGS.trainlabel, distancefunc=FLAGS.distancefunc,
                                   godag=GODAG).build()
            log.info('built decoder')

            init = tf.global_variables_initializer()
            init.run(session=sess)
            chkpt = tf.train.Saver(max_to_keep=4)
            train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train',
                                                 sess.graph)

            test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
            step = 0
            maxwait = 2
            wait = 0
            bestf1 = 0
            # The flag is never cleared, so every checkpoint writes its own
            # meta graph (.meta file).
            metagraphFlag = True
            log.info('starting epochs')
            log.info('params - trainsize-{}, validsie-{}, rootfunc-{}, batchsize-{}'.format(FLAGS.trainsize, FLAGS.validationsize,
                                                                                            FLAGS.function, FLAGS.batchsize))
            for epoch in range(FLAGS.num_epochs):
                for x, y in train_iter:
                    if x.shape[0] != y.shape[0]:
                        raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                    negatives = get_negatives(y, 10)
                    # Labels are truncated to the first FLAGS.maxnumfuncs columns.
                    _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                                feed_dict={decoder.ys_: y[:, :FLAGS.maxnumfuncs], encoder.xs_: x,
                                                           decoder.negsamples: negatives, decoder.istraining: [True]})
                    train_writer.add_summary(summary, step)
                    log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                    step += 1

                log.info('beginning validation')
                prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
                log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                                 np.round(prec, 2),
                                                                                 np.round(recall, 2),
                                                                                 np.round(f1, 2)))
                # `>=`: a tie with the best rounded F1 also counts as progress.
                if np.round(f1, 2) >= (bestf1):
                    bestf1 = np.round(f1, 2)
                    wait = 0
                    log.info('saving meta graph')
                    chkpt.save(sess, os.path.join(FLAGS.outputdir, modelsavename,
                                                  'model_{}_{}'.format(FLAGS.function, step)),
                               global_step=step, write_meta_graph=metagraphFlag)
                    metagraphFlag = True
                else:
                    wait += 1
                    if wait > maxwait:
                        log.info('f1 didnt improve for last {} validation steps, so stopping'.format(maxwait))
                        break

                # Report metrics on the training data as well.
                train_iter.reset()
                prec, recall, f1 = validate(train_iter, sess, encoder, decoder, None)
                log.info('training error,epoch-{}, precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                                             np.round(prec, 2),
                                                                                             np.round(recall, 2),
                                                                                             np.round(f1, 2)))

                train_iter.reset()

    log.info('testing model')
    test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                 dataloader=data, functype=FLAGS.function, featuretype='onehot',
                                 onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)
    prec, recall, f1 = predict_evaluate(test_dataiter, os.path.join(FLAGS.outputdir, modelsavename))
    log.info('test results')
    log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2), round(f1, 2)))
    data.close()
예제 #4
0
def main(argv):
    """Train an ngram CNN encoder with a GO RNN decoder over label embeddings, then test it.

    Label embeddings are read from ``<FLAGS.data>/goEmbeddings.txt``. The
    model is trained with one validation pass per epoch. A checkpoint is
    saved whenever validation F1 improves by more than 1e-3, and training
    stops once that has not happened for more than ``maxwait`` epochs. The
    saved model is then evaluated on the test split.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).

    Raises:
        ValueError: if the embedding table does not have exactly one row more
            than the number of known GO ids.
    """
    goids = GODAG.initialize_idmap(None, None)
    labelembedding = load_labelembedding(os.path.join(FLAGS.data, 'goEmbeddings.txt'), goids)
    # The table is expected to carry one extra row beyond the known GO ids.
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if labelembedding.shape[0] != len(goids) + 1:
        raise ValueError('label embeddings and known go ids differ')
    FeatureExtractor.load(FLAGS.data)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader()
    modelsavename = 'savedmodels_{}'.format(int(time.time()))
    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function, featuretype='ngrams',
                                      onlyLeafNodes=True, limit=FLAGS.maxnumfuncs)

        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='ngrams', onlyLeafNodes=True,
                                  limit=FLAGS.maxnumfuncs)

        encoder = CNNEncoder(vocab_size=len(FeatureExtractor.ngrammap) + 1, inputsize=train_iter.expectedshape).build()
        log.info('built encoder')
        decoder = GORNNDecoder(encoder.outputs, labelembedding, numfuncs=FLAGS.maxnumfuncs).build()
        log.info('built decoder')
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train',
                                             sess.graph)

        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        maxwait = 1
        wait = 0
        bestf1 = 0
        # NOTE(review): bestthres is never updated below, so the test step
        # always passes threshold [0]. Confirm predict_evaluate ignores it for
        # this decoder.
        bestthres = 0
        # Only the first checkpoint writes the meta graph.
        metagraphFlag = True
        log.info('starting epochs')
        log.info('params - trainsize-{}, validsie-{}, rootfunc-{}, batchsize-{}'.format(FLAGS.trainsize, FLAGS.validationsize,
                                                                                        FLAGS.function, FLAGS.batchsize))
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                negatives = get_negatives(y, 10)
                _, loss, summary = sess.run([decoder.train, decoder.loss, decoder.summary],
                                            feed_dict={decoder.ys_: y, encoder.xs_: x,
                                                       decoder.negsamples: negatives})
                train_writer.add_summary(summary, step)
                log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                step += 1

            log.info('beginning validation')
            prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
            log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                             np.round(prec, 2),
                                                                             np.round(recall, 2),
                                                                             np.round(f1, 2)))
            if f1 > (bestf1 + 1e-3):
                bestf1 = f1
                wait = 0
                chkpt.save(sess, os.path.join(FLAGS.outputdir, modelsavename,
                                              'model_{}_{}'.format(FLAGS.function, step)),
                           global_step=step, write_meta_graph=metagraphFlag)
                metagraphFlag = False
            else:
                wait += 1
                if wait > maxwait:
                    log.info('f1 didnt improve for last {} validation steps, so stopping'.format(maxwait))
                    break

            train_iter.reset()

    log.info('testing model')
    test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                 dataloader=data, functype=FLAGS.function, featuretype='ngrams',
                                 onlyLeafNodes=True, limit=FLAGS.maxnumfuncs)
    prec, recall, f1 = predict_evaluate(test_dataiter, [bestthres], os.path.join(FLAGS.outputdir, modelsavename))
    log.info('test results')
    log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2), round(f1, 2)))
    data.close()
예제 #5
0
def main(argv):
    """Train the Keras DeepGO model for ``FLAGS.function`` and evaluate it on test data.

    Feature type:

    - ``'onehot'`` by default;
    - ``'ngrams'`` when ``FLAGS.pretrained`` names a pretrained ngram
      embedding.

    The same feature type drives the model's vocabulary size and every data
    iterator.

    Outputs under ``<FLAGS.outputdir>/models``:

    - the model architecture, as JSON;
    - the weights, as HDF5 (``ModelCheckpoint`` with ``save_best_only``).

    Both files are then used by ``predict_evaluate`` on the test split.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources,
                     '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')
    pretrained = None
    featuretype = 'onehot'
    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        # Use the ngram vocabulary returned alongside the pretrained embedding.
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    # os.path.join keeps the separator even when FLAGS.outputdir has no
    # trailing slash (plain concatenation would not).
    models_dir = os.path.join(FLAGS.outputdir, 'models')
    model_path = os.path.join(models_dir, 'model_seq_' + FLAGS.function + '.h5')
    model_jsonpath = os.path.join(models_dir,
                                  'model_{}.json'.format(FLAGS.function))

    with tf.Session() as sess:
        data = DataLoader(filename=FLAGS.inputfile)
        log.info('initializing validation data')
        # Every iterator must use `featuretype`: the vocabulary size below is
        # derived from it. Ngram ids fed to an amino-acid-sized vocabulary
        # would presumably index past the model's embedding.
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                      size=FLAGS.validationsize,
                                      dataloader=data,
                                      functype=FLAGS.function,
                                      featuretype=featuretype,
                                      numfuncs=len(funcs),
                                      all_labels=False,
                                      autoreset=True)

        log.info('initializing train data')
        train_iter = DataIterator(batchsize=FLAGS.batchsize,
                                  size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen,
                                  dataloader=data,
                                  numfiles=4,
                                  numfuncs=len(funcs),
                                  functype=FLAGS.function,
                                  featuretype=featuretype,
                                  all_labels=False,
                                  autoreset=True)

        # +1 on top of the vocabulary size, as in the rest of this file.
        if featuretype == 'ngrams':
            vocabsize = len(FeatureExtractor.ngrammap) + 1
        else:
            vocabsize = len(FeatureExtractor.aminoacidmap) + 1

        model = KerasDeepGO(funcs,
                            FLAGS.function,
                            GODAG,
                            train_iter.expectedshape,
                            vocabsize,
                            pretrained_embedding=pretrained).build()
        log.info('built encoder')
        log.info('built decoder')
        keras.backend.set_session(sess)
        log.info('starting epochs')

        checkpointer = keras.callbacks.ModelCheckpoint(filepath=model_path,
                                                       verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True)
        earlystopper = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                     patience=10,
                                                     verbose=1)

        # Create the output directory so writing the JSON and the checkpoint
        # does not fail on a fresh output dir.
        os.makedirs(models_dir, exist_ok=True)
        with open(model_jsonpath, 'w') as f:
            f.write(model.to_json())

        try:
            # NOTE(review): epochs is hard-coded to 5 rather than FLAGS.num_epochs.
            model.fit_generator(train_iter,
                                steps_per_epoch=FLAGS.trainsize,
                                epochs=5,
                                validation_data=valid_dataiter,
                                validation_steps=FLAGS.validationsize,
                                max_queue_size=128,
                                callbacks=[checkpointer, earlystopper])
        finally:
            valid_dataiter.close()
            train_iter.close()

    log.info('initializing test data')
    test_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                 size=FLAGS.testsize,
                                 seqlen=FLAGS.maxseqlen,
                                 dataloader=data,
                                 numfiles=4,
                                 numfuncs=len(funcs),
                                 functype=FLAGS.function,
                                 featuretype=featuretype,
                                 all_labels=True)

    prec, recall, f1 = predict_evaluate(test_dataiter, model_jsonpath,
                                        model_path)
    log.info('testing error, prec-{}, recall-{}, f1-{}'.format(
        np.round(prec, 3), np.round(recall, 3), np.round(f1, 3)))
    data.close()
예제 #6
0
def main(argv):
    """Train the convolutional autoencoder (ConvAutoEncoder) on input sequences.

    Builds iterators over ``FLAGS.inputfile`` with ``functype=''`` and trains
    ``encoder.train`` for ``FLAGS.num_epochs`` epochs. A checkpoint is saved
    under ``<FLAGS.outputdir>/<modelsavename>`` at the end of every epoch.
    This variant performs no validation and no early stopping.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    funcs = GODAG.initialize_idmap(None, None)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))

    # NOTE(review): `pretrained` is loaded below but never passed to
    # ConvAutoEncoder; only the returned ngram map is used. Confirm intended.
    pretrained = None
    featuretype = FLAGS.featuretype
    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        # Replace the ngram vocabulary with the one returned alongside the
        # pretrained embedding; this also forces ngram features.
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    log.info('using feature type - {}'.format(featuretype))
    with tf.Session() as sess:
        # NOTE(review): the validation iterator is built but never consumed
        # in this function.
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                      size=FLAGS.validationsize,
                                      dataloader=data,
                                      functype='',
                                      featuretype=featuretype)

        train_iter = DataIterator(
            batchsize=FLAGS.batchsize,
            size=FLAGS.trainsize,
            seqlen=FLAGS.maxseqlen,
            dataloader=data,
            numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
            functype='',
            featuretype=featuretype)

        # Vocabulary size follows the selected feature type (no +1 here,
        # unlike the encoder/decoder variants).
        vocabsize = ((len(FeatureExtractor.ngrammap)) if featuretype
                     == 'ngrams' else (len(FeatureExtractor.aminoacidmap)))

        encoder = ConvAutoEncoder(vocab_size=vocabsize,
                                  maxlen=train_iter.expectedshape).build()

        # Variables must exist (encoder built above) before initialisation
        # and before the Saver is created.
        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        # The train writer records the graph; no per-step summaries are
        # added (the add_summary call is commented out).
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train',
                                             sess.graph)

        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        # maxwait/wait are not used in this variant (no early stopping).
        maxwait = 1
        wait = 0
        # Never cleared, so every epoch checkpoint writes a meta graph.
        metagraphFlag = True
        log.info('starting epochs')
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                # y is only used for this batch-size sanity check; the
                # autoencoder is fed x alone.
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(
                        str(x.shape), str(y.shape)))

                _, loss = sess.run([encoder.train, encoder.loss],
                                   feed_dict={encoder.xs_: x})
                # train_writer.add_summary(summary, step)
                #ipdb.set_trace()
                log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                step += 1

            # One checkpoint per epoch, regardless of loss.
            chkpt.save(sess,
                       os.path.join(FLAGS.outputdir, modelsavename,
                                    'model_epoch{}'.format(epoch)),
                       global_step=step,
                       write_meta_graph=metagraphFlag)
            train_iter.reset()

    data.close()
예제 #7
0
def main(argv):
    """Train the convolutional autoencoder with periodic validation and early stopping.

    Every ``validate_every`` (100) training steps one validation batch is
    scored. A checkpoint is saved whenever its loss is within ``slack`` of the
    best (lowest) validation loss seen so far. Training stops once that has
    failed to happen for more than ``maxwait`` consecutive checks. A
    checkpoint is also saved at the end of every epoch.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    funcs = GODAG.initialize_idmap(None, None)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    modelsavename = 'savedmodels_{}_{}'.format(__processor__, int(time.time()))

    # NOTE(review): `pretrained` is loaded below but never passed to
    # ConvAutoEncoder; only the returned ngram map is used. Confirm intended.
    pretrained = None
    featuretype = FLAGS.featuretype

    if FLAGS.pretrained != '':
        log.info('loading pretrained embedding')
        pretrained, ngrammap = load_pretrained_embedding(FLAGS.pretrained)
        FeatureExtractor.ngrammap = ngrammap
        featuretype = 'ngrams'

    log.info('using feature type - {}'.format(featuretype))

    # Early-stopping configuration.
    validate_every = 100  # training steps between validation checks
    maxwait = 10          # consecutive failed checks tolerated before stopping
    slack = 1e-5          # tolerance when comparing against the best loss

    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                      size=FLAGS.validationsize,
                                      seqlen=FLAGS.maxseqlen,
                                      dataloader=data,
                                      functype='',
                                      featuretype=featuretype)

        train_iter = DataIterator(
            batchsize=FLAGS.batchsize,
            size=FLAGS.trainsize,
            seqlen=FLAGS.maxseqlen,
            dataloader=data,
            numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
            functype='',
            featuretype=featuretype)

        vocabsize = ((len(FeatureExtractor.ngrammap)) if featuretype
                     == 'ngrams' else (len(FeatureExtractor.aminoacidmap)))

        cae_model_obj = ConvAutoEncoder(vocab_size=vocabsize,
                                        maxlen=train_iter.expectedshape,
                                        batch_size=FLAGS.batchsize,
                                        embedding_dim=256)
        cae_model_obj.build()

        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        # The writers are kept for their side effects (graph / event files).
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train',
                                             sess.graph)

        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')
        step = 0
        metagraphFlag = True
        log.info('starting epochs')

        # -------------------------- #
        #  Start training the autoencoder
        # -------------------------- #
        # np.inf rather than np.infty: the alias was removed in NumPy 2.0.
        bestloss = np.inf
        wait = 0
        earlystop = False
        for epoch in range(FLAGS.num_epochs):
            for x, y in train_iter:
                if x.shape[0] != y.shape[0]:
                    raise Exception('invalid, x-{}, y-{}'.format(
                        str(x.shape), str(y.shape)))

                _, sp_ss_loss, spmloss, smloss = sess.run(
                    [
                        cae_model_obj.train, cae_model_obj.loss,
                        cae_model_obj.loss1, cae_model_obj.loss2
                    ],
                    feed_dict={cae_model_obj.x_input: x})

                log.info('step :: {}, sploss :: {}, spm:{}, sm:{}'.format(
                    step, round(sp_ss_loss, 3), round(spmloss, 3),
                    round(smloss, 3)))
                step += 1
                if step % validate_every == 0:
                    valid_x, _ = next(valid_dataiter)
                    valid_loss, tp = sess.run(
                        [cae_model_obj.loss, cae_model_obj.max_out],
                        feed_dict={cae_model_obj.x_input: valid_x})
                    # "precision" = fraction of positions where max_out equals
                    # the input; np.mean handles inputs of any rank.
                    log.info(
                        'validation loss at step: {} is {}, precision is: {}'.
                        format(step, round(valid_loss, 3),
                               np.mean(tp == valid_x)))

                    if valid_loss <= bestloss + slack:
                        wait = 0
                        # Keep the reference as a running minimum. Otherwise a
                        # series of near-ties could ratchet it upward and
                        # postpone early stopping indefinitely.
                        bestloss = min(bestloss, valid_loss)
                        chkpt.save(sess,
                                   os.path.join(FLAGS.outputdir, modelsavename,
                                                'model_epoch{}'.format(epoch)),
                                   global_step=step,
                                   write_meta_graph=metagraphFlag)
                    else:
                        wait += 1
                        if wait > maxwait:
                            earlystop = True
                            break

            # NOTE(review): this end-of-epoch save also runs after an early
            # stop, so the newest checkpoint may not be the best one.
            chkpt.save(sess,
                       os.path.join(FLAGS.outputdir, modelsavename,
                                    'model_epoch{}'.format(epoch)),
                       global_step=step,
                       write_meta_graph=metagraphFlag)
            train_iter.reset()
            valid_dataiter.reset()
            if earlystop:
                log.info(
                    'stopping early at epoch: {}, step:{}, loss:{}'.format(
                        epoch, step, bestloss))
                break

    data.close()
def main(argv):
    """Train the Keras autoencoder returned by ``get_AE`` on one-hot sequence batches.

    Outputs under ``<FLAGS.outputdir>/models``:

    - the model architecture, as JSON;
    - the weights, as HDF5 (``ModelCheckpoint`` with ``save_best_only``).

    Training uses ``fit_generator`` with early stopping on ``val_loss``.

    Args:
        argv: command-line arguments (unused; configuration comes from FLAGS).
    """
    funcs = pd.read_pickle(
        os.path.join(FLAGS.resources,
                     '{}.pkl'.format(FLAGS.function)))['functions'].values
    funcs = GODAG.initialize_idmap(funcs, FLAGS.function)

    log.info('GO DAG initialized. Updated function list-{}'.format(len(funcs)))
    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    # os.path.join keeps the separator even when FLAGS.outputdir has no
    # trailing slash (plain concatenation would not).
    models_dir = os.path.join(FLAGS.outputdir, 'models')
    model_path = os.path.join(models_dir, 'model_seq_' + FLAGS.function + '.h5')
    model_jsonpath = os.path.join(models_dir,
                                  'model_{}.json'.format(FLAGS.function))

    with tf.Session() as sess:
        data = DataLoader(filename=FLAGS.inputfile)
        log.info('initializing validation data')
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize,
                                      size=FLAGS.validationsize,
                                      dataloader=data,
                                      functype=FLAGS.function,
                                      featuretype='onehot',
                                      numfuncs=len(funcs),
                                      all_labels=False,
                                      autoreset=True)

        log.info('initializing train data')
        train_iter = DataIterator(batchsize=FLAGS.batchsize,
                                  size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen,
                                  dataloader=data,
                                  numfiles=4,
                                  numfuncs=len(funcs),
                                  functype=FLAGS.function,
                                  featuretype='onehot',
                                  all_labels=False,
                                  autoreset=True)

        # Published as a module-level global.
        # NOTE(review): presumably read by get_AE or its loss; verify.
        global original_dim
        original_dim = train_iter.expectedshape

        model = get_AE(train_iter.expectedshape,
                       len(FeatureExtractor.aminoacidmap),
                       train_iter.expectedshape)
        keras.backend.set_session(sess)
        log.info('starting epochs')

        checkpointer = keras.callbacks.ModelCheckpoint(filepath=model_path,
                                                       verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True)
        earlystopper = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                     patience=10,
                                                     verbose=1)

        # Create the output directory so writing the JSON and the checkpoint
        # does not fail on a fresh output dir.
        os.makedirs(models_dir, exist_ok=True)
        with open(model_jsonpath, 'w') as f:
            f.write(model.to_json())

        try:
            # datawrapper adapts the iterators for fit_generator.
            # NOTE(review): exact (inputs, targets) format is defined
            # elsewhere; confirm.
            model.fit_generator(datawrapper(train_iter),
                                steps_per_epoch=FLAGS.trainsize,
                                epochs=5,
                                validation_data=datawrapper(valid_dataiter),
                                validation_steps=FLAGS.validationsize,
                                callbacks=[checkpointer, earlystopper])
        finally:
            valid_dataiter.close()
            train_iter.close()
            data.close()