def _train_and_checkpoint(data, labelembedding, modelsavename):
    """Build the encoder/decoder graph, train it and checkpoint on validation F1.

    Runs up to ``FLAGS.num_epochs`` epochs inside a fresh ``tf.Session``. After
    every epoch the model is validated; a checkpoint is written whenever the
    rounded F1 is at least as good as the best seen so far, and training stops
    early once F1 has failed to do so more than ``maxwait`` times in a row.

    Args:
        data: the loader handed to both the training and validation
            ``DataIterator`` as ``dataloader``.
        labelembedding: label-embedding matrix passed to ``GORNNDecoder``.
        modelsavename: directory name (joined onto ``FLAGS.outputdir``) under
            which checkpoints are saved.
    """
    with tf.Session() as sess:
        valid_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.validationsize,
                                      dataloader=data, functype=FLAGS.function, featuretype='onehot',
                                      onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)

        train_iter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.trainsize,
                                  seqlen=FLAGS.maxseqlen, dataloader=data,
                                  numfiles=np.floor((FLAGS.trainsize * FLAGS.batchsize) / 250000),
                                  functype=FLAGS.function, featuretype='onehot',
                                  onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)

        # Alternative encoder, kept for reference:
        # encoder = CNNEncoder(vocab_size=len(FeatureExtractor.ngrammap) + 1, inputsize=train_iter.expectedshape).build()
        encoder = MultiCharCNN(vocab_size=len(FeatureExtractor.aminoacidmap) + 1,
                               inputsize=train_iter.expectedshape, with_dilation=False, charfilter=32,
                               poolsize=80, poolstride=48).build()
        log.info('built encoder')

        decoder = GORNNDecoder(encoder.outputs, labelembedding, numfuncs=FLAGS.maxnumfuncs,
                               trainlabelEmbedding=FLAGS.trainlabel, distancefunc=FLAGS.distancefunc,
                               godag=GODAG).build()
        log.info('built decoder')

        init = tf.global_variables_initializer()
        init.run(session=sess)
        chkpt = tf.train.Saver(max_to_keep=4)
        train_writer = tf.summary.FileWriter(FLAGS.outputdir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.outputdir + '/test')

        # The writers buffer events; close them even if training fails so
        # that whatever was recorded is flushed to disk.
        try:
            step = 0
            maxwait = 2
            wait = 0
            bestf1 = 0
            # NOTE(review): this flag is never set to False, so the meta graph
            # is written with every checkpoint -- confirm that is intended.
            metagraphFlag = True
            log.info('starting epochs')
            log.info('params - trainsize-{}, validsie-{}, rootfunc-{}, batchsize-{}'.format(
                FLAGS.trainsize, FLAGS.validationsize, FLAGS.function, FLAGS.batchsize))

            for epoch in range(FLAGS.num_epochs):
                for x, y in train_iter:
                    if x.shape[0] != y.shape[0]:
                        raise Exception('invalid, x-{}, y-{}'.format(str(x.shape), str(y.shape)))

                    negatives = get_negatives(y, 10)
                    _, loss, summary = sess.run(
                        [decoder.train, decoder.loss, decoder.summary],
                        feed_dict={decoder.ys_: y[:, :FLAGS.maxnumfuncs], encoder.xs_: x,
                                   decoder.negsamples: negatives, decoder.istraining: [True]})
                    train_writer.add_summary(summary, step)
                    log.info('step-{}, loss-{}'.format(step, round(loss, 2)))
                    step += 1

                log.info('beginning validation')
                prec, recall, f1 = validate(valid_dataiter, sess, encoder, decoder, test_writer)
                log.info('epoch: {} \n precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                                 np.round(prec, 2),
                                                                                 np.round(recall, 2),
                                                                                 np.round(f1, 2)))
                if np.round(f1, 2) >= bestf1:
                    # New best (ties count): remember it and checkpoint.
                    bestf1 = np.round(f1, 2)
                    wait = 0
                    log.info('saving meta graph')
                    chkpt.save(sess,
                               os.path.join(FLAGS.outputdir, modelsavename,
                                            'model_{}_{}'.format(FLAGS.function, step)),
                               global_step=step, write_meta_graph=metagraphFlag)
                    metagraphFlag = True
                else:
                    wait += 1
                    if wait > maxwait:
                        log.info('f1 didnt improve for last {} validation steps, so stopping'.format(maxwait))
                        break

                # Report metrics on the training data as well, then rewind the
                # iterator so the next epoch starts from the beginning.
                train_iter.reset()
                prec, recall, f1 = validate(train_iter, sess, encoder, decoder, None)
                log.info('training error,epoch-{}, precision: {}, recall: {}, f1: {}'.format(epoch,
                                                                                             np.round(prec, 2),
                                                                                             np.round(recall, 2),
                                                                                             np.round(f1, 2)))
                train_iter.reset()
        finally:
            train_writer.close()
            test_writer.close()


def main(argv):
    """Prepare GO label data, optionally train a model, then evaluate it.

    When ``FLAGS.predict`` is the empty string a new model is trained and
    checkpointed under ``FLAGS.outputdir/savedmodels_<unix time>``; otherwise
    ``FLAGS.predict`` is used as the saved-model directory name. In both cases
    the model is then evaluated on a test iterator and precision, recall and
    F1 are logged.

    Args:
        argv: not used by this function.
    """
    goids = GODAG.initialize_idmap(None, None)

    labelembedding = load_labelembedding(os.path.join(FLAGS.resources, 'goEmbeddings.txt'), goids)
    assert(labelembedding.shape[0] == (len(goids))) , 'label embeddings and known go ids differ'

    ## Add a row of zeros to refer to NOGO or STOPGO
    labelembedding = np.vstack([np.zeros(labelembedding.shape[1]), labelembedding]).astype(np.float32)

    # shift all goids by 1, to allow STOPGO
    GODAG.idmap = {key: (val + 1) for key, val in GODAG.idmap.items()}
    log.info('min go index - {}'.format(min(GODAG.idmap.values())))
    GODAG.idmap['STOPGO'] = 0
    GODAG.GOIDS.insert(0, 'STOPGO')
    # Sanity log: after the shift, index 1 should refer to the first real GO id.
    log.info('first from main-{}, from goids-{},  from idmap-{}, by reversemap-{}'.format(
        goids[0], GODAG.GOIDS[1], GODAG.id2node(1), GODAG.get_id(goids[0])))

    FeatureExtractor.load(FLAGS.resources)
    log.info('Loaded amino acid and ngram mapping data')

    data = DataLoader(filename=FLAGS.inputfile)
    # Guarantee the loader is closed even if training or testing raises.
    try:
        modelsavename = FLAGS.predict
        if FLAGS.predict == "":
            modelsavename = 'savedmodels_{}'.format(int(time.time()))
            _train_and_checkpoint(data, labelembedding, modelsavename)

        log.info('testing model')
        test_dataiter = DataIterator(batchsize=FLAGS.batchsize, size=FLAGS.testsize,
                                     dataloader=data, functype=FLAGS.function, featuretype='onehot',
                                     onlyLeafNodes=True, numfuncs=FLAGS.maxnumfuncs)
        prec, recall, f1 = predict_evaluate(test_dataiter, os.path.join(FLAGS.outputdir, modelsavename))
        log.info('test results')
        log.info('precision: {}, recall: {}, F1: {}'.format(round(prec, 2), round(recall, 2), round(f1, 2)))
    finally:
        data.close()
def printnodes(ids):
    """Print, as one list, the result of ``GODAG.id2node`` for each id in *ids*."""
    nodes = []
    for node_id in ids:
        nodes.append(GODAG.id2node(node_id))
    print(nodes)