def _load_graph_files(paths, set_label, sample_label):
    """Read every AMR/RDF file in ``paths`` into one shuffled instance list.

    Entries ending in ``.json`` are read with
    ``G2S_data_stream.read_amr_file`` and entries ending in ``.xml`` with
    ``G2S_data_stream.read_rdf_file``.  Any other entry is reported and
    skipped instead of being dropped silently.

    Args:
        paths: iterable of file path strings.
        set_label: name used in the "Loading ... set." progress message
            (e.g. 'train', 'dev', 'finetune').
        sample_label: name used in the "Number of ... samples" message.

    Returns:
        ``(instances, stats)`` where ``instances`` is the concatenated and
        shuffled list of loaded instances and ``stats`` is the 4-tuple
        ``(max_node, max_in_neigh, max_out_neigh, max_sent)`` holding the
        largest value of each statistic returned by the readers across all
        files (all zeros when nothing was loaded).
    """
    instances = []
    max_node = max_in_neigh = max_out_neigh = max_sent = 0
    for path in paths:
        print(path)
        extension = path.split('.')[-1]
        if extension == 'json':
            graph_kind, reader = 'amr', G2S_data_stream.read_amr_file
        elif extension == 'xml':
            graph_kind, reader = 'rdf', G2S_data_stream.read_rdf_file
        else:
            print('Skipping file with unsupported extension: {}'.format(path))
            continue

        print('Loading {} {} set.'.format(set_label, graph_kind))
        loaded, n_node, n_in_neigh, n_out_neigh, n_sent = reader(path)
        print('Number of {} samples: {}'.format(sample_label, len(loaded)))
        instances += list(loaded)
        max_node = max(max_node, n_node)
        max_in_neigh = max(max_in_neigh, n_in_neigh)
        max_out_neigh = max(max_out_neigh, n_out_neigh)
        max_sent = max(max_sent, n_sent)

    # Shuffle once, after all files are merged.
    random.shuffle(instances)
    return instances, (max_node, max_in_neigh, max_out_neigh, max_sent)


def _load_or_build_vocabs(trainset, path_prefix, best_path):
    """Load the vocabularies of an existing model, or build and dump new ones.

    A model counts as existing when ``best_path + ".index"`` is on disk.  In
    that case the char / edge-label / POS vocabularies previously dumped next
    to it are loaded; otherwise they are built from ``trainset`` (chars and
    edge labels) or from the fixed list ``['amr', 'rdf']`` (POS) and dumped
    under ``path_prefix``.  The word vocabulary is always read from
    ``FLAGS.word_vec_path``.

    Returns:
        ``(has_pretrained_model, word_vocab, char_vocab, edgelabel_vocab,
        POS_vocab)``; ``char_vocab`` is ``None`` unless ``FLAGS.with_char``.
    """
    char_vocab = None
    if os.path.exists(best_path + ".index"):
        print('!!Existing pretrained model. Loading vocabs.')
        word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
        print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
        if FLAGS.with_char:
            char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
            print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
        edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab",
                                fileformat='txt2')
        POS_vocab = Vocab(path_prefix + ".POS_vocab", fileformat='txt2')
        return True, word_vocab, char_vocab, edgelabel_vocab, POS_vocab

    print('Collecting vocabs.')
    (allWords, allChars,
     allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
    print('Number of words: {}'.format(len(allWords)))
    print('Number of allChars: {}'.format(len(allChars)))
    print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))

    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
    if FLAGS.with_char:
        char_vocab = Vocab(voc=allChars,
                           dim=FLAGS.char_dim,
                           fileformat='build')
        char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
    edgelabel_vocab = Vocab(voc=allEdgelabels,
                            dim=FLAGS.edgelabel_dim,
                            fileformat='build')
    edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")
    POS_vocab = Vocab(voc=['amr', 'rdf'],
                      dim=FLAGS.POS_dim,
                      fileformat='build')
    POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
    return False, word_vocab, char_vocab, edgelabel_vocab, POS_vocab


def main(_):
    """Train a graph-to-sequence model as configured by the global ``FLAGS``.

    Steps, in order:
      1. Create ``FLAGS.model_dir`` if needed, open ``G2S.<suffix>.log`` there
         and save the configuration as ``G2S.<suffix>.config.json``.
      2. Load the train, dev and (when ``FLAGS.finetune_path`` is non-empty)
         finetune sets and report the observed size maxima.
      3. Load or build the vocabularies and wrap each set in a
         ``G2SDataStream``.
      4. Build a training graph and a variable-sharing validation graph,
         restoring ``G2S.<suffix>.best.model`` when it exists.
      5. Run ``FLAGS.max_epochs`` passes over the training batches, calling
         ``fine_tune`` (when a finetune set was given) or
         ``validate_and_save`` at every checkpoint.

    Args:
        _: ignored.  NOTE(review): presumably the argv list passed by the
            application runner -- confirm against the caller.

    Raises:
        AssertionError: if ``FLAGS.mode`` is neither 'ce_train' nor
            'rl_train'.
    """
    print('Configurations:')
    print(FLAGS)

    log_dir = FLAGS.model_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
    log_file_path = path_prefix + ".log"
    print('Log file path: {}'.format(log_file_path))
    log_file = open(log_file_path, 'wt')
    # try/finally so the log is closed even when loading or training raises.
    try:
        log_file.write("{}\n".format(FLAGS))
        log_file.flush()

        # save configuration
        namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")

        trainset, trn_stats = _load_graph_files(FLAGS.train_path.split(','),
                                                'train', 'training')
        devset, dev_stats = _load_graph_files(FLAGS.test_path.split(','),
                                              'dev', 'dev')
        all_stats = [trn_stats, dev_stats]

        # ftset stays None when no finetune data is configured; the checks
        # further down rely on it always being bound.
        ftset = None
        if FLAGS.finetune_path != "":
            ftset, ft_stats = _load_graph_files(
                FLAGS.finetune_path.split(','), 'finetune', 'finetune')
            all_stats.append(ft_stats)

        # Column-wise maximum over the per-set statistics.
        max_node, max_in_neigh, max_out_neigh, max_sent = [
            max(column) for column in zip(*all_stats)
        ]

        print('Max node number: {}, while max allowed is {}'.format(
            max_node, FLAGS.max_node_num))
        print('Max parent number: {}, truncated to {}'.format(
            max_in_neigh, FLAGS.max_in_neigh_num))
        print('Max children number: {}, truncated to {}'.format(
            max_out_neigh, FLAGS.max_out_neigh_num))
        print('Max answer length: {}, truncated to {}'.format(
            max_sent, FLAGS.max_answer_len))

        best_path = path_prefix + ".best.model"
        (has_pretrained_model, word_vocab, char_vocab, edgelabel_vocab,
         POS_vocab) = _load_or_build_vocabs(trainset, path_prefix, best_path)

        print('word vocab size {}'.format(word_vocab.vocab_size))
        sys.stdout.flush()

        print('Build DataStream ... ')
        trainDataStream = G2S_data_stream.G2SDataStream(trainset,
                                                        word_vocab,
                                                        char_vocab,
                                                        edgelabel_vocab,
                                                        POS_vocab,
                                                        options=FLAGS,
                                                        isShuffle=True,
                                                        isLoop=True,
                                                        isSort=True)

        devDataStream = G2S_data_stream.G2SDataStream(devset,
                                                      word_vocab,
                                                      char_vocab,
                                                      edgelabel_vocab,
                                                      POS_vocab,
                                                      options=FLAGS,
                                                      isShuffle=False,
                                                      isLoop=False,
                                                      isSort=True)
        print('Number of instances in trainDataStream: {}'.format(
            trainDataStream.get_num_instance()))
        print('Number of instances in devDataStream: {}'.format(
            devDataStream.get_num_instance()))
        print('Number of batches in trainDataStream: {}'.format(
            trainDataStream.get_num_batch()))
        print('Number of batches in devDataStream: {}'.format(
            devDataStream.get_num_batch()))

        ftDataStream = None
        if ftset is not None:
            ftDataStream = G2S_data_stream.G2SDataStream(ftset,
                                                         word_vocab,
                                                         char_vocab,
                                                         edgelabel_vocab,
                                                         POS_vocab,
                                                         options=FLAGS,
                                                         isShuffle=True,
                                                         isLoop=True,
                                                         isSort=True)
            print('Number of instances in ftDataStream: {}'.format(
                ftDataStream.get_num_instance()))
            print('Number of batches in ftDataStream: {}'.format(
                ftDataStream.get_num_batch()))

        sys.stdout.flush()

        # initialize the best bleu and accu scores for current training
        # session; `in` replaces dict.has_key, which Python 3 removed.
        best_accu = FLAGS.best_accu if 'best_accu' in FLAGS.__dict__ else 0.0
        best_bleu = FLAGS.best_bleu if 'best_bleu' in FLAGS.__dict__ else 0.0
        if best_accu > 0.0:
            print('With initial dev accuracy {}'.format(best_accu))
        if best_bleu > 0.0:
            print('With initial dev BLEU score {}'.format(best_bleu))

        init_scale = 0.01
        with tf.Graph().as_default():
            initializer = tf.random_uniform_initializer(-init_scale,
                                                        init_scale)
            with tf.name_scope("Train"):
                with tf.variable_scope("Model",
                                       reuse=None,
                                       initializer=initializer):
                    train_graph = ModelGraph(word_vocab=word_vocab,
                                             Edgelabel_vocab=edgelabel_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             options=FLAGS,
                                             mode=FLAGS.mode)

            assert FLAGS.mode in (
                'ce_train',
                'rl_train',
            )
            valid_mode = ('evaluate'
                          if FLAGS.mode == 'ce_train' else 'evaluate_bleu')

            # reuse=True: the validation graph shares the "Model" variables
            # created for the training graph above.
            with tf.name_scope("Valid"):
                with tf.variable_scope("Model",
                                       reuse=True,
                                       initializer=initializer):
                    valid_graph = ModelGraph(word_vocab=word_vocab,
                                             Edgelabel_vocab=edgelabel_vocab,
                                             char_vocab=char_vocab,
                                             POS_vocab=POS_vocab,
                                             options=FLAGS,
                                             mode=valid_mode)

            init_op = tf.global_variables_initializer()

            # Only "Model/..." variables are checkpointed; word embeddings
            # are left out when FLAGS.fix_word_vec is set.
            vars_ = {}
            for var in tf.global_variables():
                if FLAGS.fix_word_vec and "word_embedding" in var.name:
                    continue
                if not var.name.startswith("Model"):
                    continue
                vars_[var.name.split(":")[0]] = var
                print(var)
            saver = tf.train.Saver(vars_)

            sess = tf.Session()
            sess.run(init_op)
            if has_pretrained_model:
                print("Restoring model from " + best_path)
                saver.restore(sess, best_path)
                print("DONE!")

                # No recorded score for the restored model yet: evaluate it
                # once so later checkpoints are compared to a real baseline.
                if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
                    print("Getting BLEU score for the model")
                    sys.stdout.flush()
                    best_bleu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_bleu']
                    FLAGS.best_bleu = best_bleu
                    namespace_utils.save_namespace(
                        FLAGS, path_prefix + ".config.json")
                    print('BLEU = %.4f' % best_bleu)
                    sys.stdout.flush()
                    log_file.write('BLEU = %.4f\n' % best_bleu)
                if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
                    print("Getting ACCU score for the model")
                    best_accu = evaluate(sess,
                                         valid_graph,
                                         devDataStream,
                                         options=FLAGS)['dev_accu']
                    FLAGS.best_accu = best_accu
                    namespace_utils.save_namespace(
                        FLAGS, path_prefix + ".config.json")
                    print('ACCU = %.4f' % best_accu)
                    log_file.write('ACCU = %.4f\n' % best_accu)

            print('Start the training loop.')
            train_size = trainDataStream.get_num_batch()
            max_steps = train_size * FLAGS.max_epochs
            total_loss = 0.0
            start_time = time.time()
            for step in xrange(max_steps):
                cur_batch = trainDataStream.nextBatch()
                if FLAGS.mode == 'rl_train':
                    loss_value = train_graph.run_rl_training_subsample(
                        sess, cur_batch, FLAGS)
                elif FLAGS.mode == 'ce_train':
                    loss_value = train_graph.run_ce_training(
                        sess, cur_batch, FLAGS)
                total_loss += loss_value

                if step % 100 == 0:
                    print('{} '.format(step), end="")
                    sys.stdout.flush()

                # Save a checkpoint and evaluate the model periodically:
                # at the end of every pass over the training batches, at the
                # very last step, and every 2000 steps when a pass has more
                # than 10000 batches.
                num_batch = trainDataStream.get_num_batch()
                steps_done = step + 1
                if (steps_done % num_batch == 0 or steps_done == max_steps
                        or (num_batch > 10000 and steps_done % 2000 == 0)):
                    print()
                    duration = time.time() - start_time
                    print('Step %d: loss = %.2f (%.3f sec)' %
                          (step, total_loss, duration))
                    log_file.write('Step %d: loss = %.2f (%.3f sec)\n' %
                                   (step, total_loss, duration))
                    log_file.flush()
                    sys.stdout.flush()
                    total_loss = 0.0

                    if ftset is not None:
                        best_accu, best_bleu = fine_tune(
                            sess, saver, FLAGS, log_file, ftDataStream,
                            devDataStream, train_graph, valid_graph,
                            path_prefix, best_accu, best_bleu)
                    else:
                        best_accu, best_bleu = validate_and_save(
                            sess, saver, FLAGS, log_file, devDataStream,
                            valid_graph, path_prefix, best_accu, best_bleu)
                    start_time = time.time()
    finally:
        log_file.close()
    POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
    print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
    edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
    print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
    char_vocab = None
    if FLAGS.with_char:
        char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
        print('char_vocab: {}'.format(char_vocab.word_vecs.shape))

    if graph_cate == "amr":
        print('Loading amr test set from {}.'.format(in_path))
        testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
        print('Number of samples: {}'.format(len(testset)))
    elif graph_cate == "rdf":
        print('Loading rdf test set from {}.'.format(in_path))
        testset, _, _, _, _ = G2S_data_stream.read_rdf_file(in_path,mode='test')
        print('Number of samples: {}'.format(len(testset)))
    else:
        testset = None

    print('Build DataStream ... ')
    batch_size=-1
    if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
        batch_size = 1

    devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab, char_vocab, edgelabel_vocab, POS_vocab, options=FLAGS,
                 isShuffle=False, isLoop=False, isSort=False, batch_size=batch_size)
    print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
    print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))

    best_path = model_prefix + ".best.model"