def main_old(_):  # legacy entry point; superseded by the cache-based main() below
    trainX, trainY, testX, testY = None, None, None, None
    vocabulary_word2index, vocabulary_index2word = create_voabulary()
    vocab_size = len(vocabulary_word2index)
    vocabulary_word2index_label, vocabulary_index2word_label = create_voabulary_label()
    train,test = load_data_with_multilabels(vocabulary_word2index, vocabulary_word2index_label,FLAGS.training_path) #[1,11,3,1998,1998]
    trainX, trainY= train #TODO trainY1999
    testX, testY = test #TODO testY1999
    print("testX.shape:", np.array(testX).shape);print("testY.shape:", np.array(testY).shape)  # 2500个label
    # 2.Data preprocessing
    # Sequence padding
    print("start padding & transform to one hot...")
    trainX = pad_sequences(trainX, maxlen=FLAGS.sentence_len, value=0.)  # padding to max length
    testX = pad_sequences(testX, maxlen=FLAGS.sentence_len, value=0.)  # padding to max length
    print("end padding & transform to one hot...")

    #2.create session.
    config=tf.ConfigProto()
    config.gpu_options.allow_growth=True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        fast_text=fastText(FLAGS.label_size, FLAGS.learning_rate, FLAGS.batch_size, FLAGS.decay_steps, FLAGS.decay_rate,FLAGS.num_sampled,FLAGS.sentence_len,vocab_size,FLAGS.embed_size,FLAGS.is_training)
        #Initialize Save
        saver=tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir+"checkpoint"):
            print("Restoring Variables from Checkpoint")
            saver.restore(sess,tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding: #load pre-trained word embedding
                assign_pretrained_word_embedding(sess, vocabulary_index2word, vocab_size, fast_text)

        curr_epoch=sess.run(fast_text.epoch_step)
        #3.feed data & training
        number_of_training_data=len(trainX)
        batch_size=FLAGS.batch_size
        for epoch in range(curr_epoch,FLAGS.num_epochs):#range(start,stop,step_size)
            loss, acc, counter = 0.0, 0.0, 0
            for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
                if epoch==0 and counter==0:
                    print("trainX[start:end]:",trainX[start:end]) #2d-array. each element slength is a 100.
                    print("trainY[start:end]:",trainY[start:end]) #a list,each element is a list.element:may be has 1,2,3,4,5 labels.
                    #print("trainY1999[start:end]:",trainY1999[start:end])
                curr_loss,_=sess.run([fast_text.loss_val,fast_text.train_op],feed_dict={fast_text.sentence:trainX[start:end],fast_text.labels:trainY[start:end],}) #fast_text.labels_l1999:trainY1999[start:end]
                loss,counter=loss+curr_loss,counter+1 #acc+curr_acc,
                if counter %500==0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f" %(epoch,counter,loss/float(counter))) #\tTrain Accuracy:%.3f--->,acc/float(counter)

            #epoch increment
            print("going to increment epoch counter....")
            sess.run(fast_text.epoch_increment)

            # 4.validation
            print("epoch:",epoch,"validate_every:",FLAGS.validate_every,"validate or not:",(epoch % FLAGS.validate_every==0))
            if epoch % FLAGS.validate_every==0:
                eval_loss,eval_accuracy=do_eval(sess,fast_text,testX,testY,batch_size,vocabulary_index2word_label) #testY1999,eval_acc
                print("Epoch %d Validation Loss:%.3f\tValidation Accuracy: %.3f" % (epoch,eval_loss,eval_accuracy)) #,\tValidation Accuracy: %.3f--->eval_acc
                #save model to checkpoint
                save_path=FLAGS.ckpt_dir+"model.ckpt"
                saver.save(sess,save_path,global_step=epoch) #fast_text.epoch_step

        # 5. Finally, evaluate on the test set and report test accuracy.
        test_loss, test_acc = do_eval(sess, fast_text, testX, testY, batch_size, vocabulary_index2word_label) #testY1999
        print("Test Loss:%.3f\tTest Accuracy:%.3f" % (test_loss, test_acc))
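
# A minimal sketch of what the cache-based load_data() used by main() below might
# look like, assuming the h5py file stores the splits as named datasets and the
# pickle file stores the two vocabularies. The real load_data lives in a data
# utility module of this repository; the dataset keys ("train_X", "train_Y", ...)
# are assumptions for illustration only.
#
# import h5py
# import pickle
#
# def load_data(cache_file_h5py, cache_file_pickle):
#     with open(cache_file_pickle, 'rb') as data_f_pickle:
#         word2index, label2index = pickle.load(data_f_pickle)
#     with h5py.File(cache_file_h5py, 'r') as data_f:
#         trainX, trainY = data_f['train_X'][:], data_f['train_Y'][:]
#         vaildX, vaildY = data_f['vaild_X'][:], data_f['vaild_Y'][:]
#         testX, testY = data_f['test_X'][:], data_f['test_Y'][:]
#     return word2index, label2index, trainX, trainY, vaildX, vaildY, testX, testY
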
def main(_):
    # 1. Load cached data: word/label vocabularies plus pre-indexed, pre-padded splits.
    word2index, label2index, trainX, trainY, vaildX, vaildY, testX, testY = load_data(
        FLAGS.cache_file_h5py, FLAGS.cache_file_pickle)
    index2label = {v: k for k, v in label2index.items()}
    vocab_size = len(word2index)
    print("cnn_model.vocab_size:", vocab_size)
    num_classes = len(label2index)
    print("num_classes:", num_classes)
    num_examples, FLAGS.sentence_len = trainX.shape
    print("num_examples of training:", num_examples, ";sentence_len:",
          FLAGS.sentence_len)

    #2.create session.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        #Instantiate Model
        fast_text = fastText(num_classes, FLAGS.learning_rate,
                             FLAGS.batch_size, FLAGS.decay_steps,
                             FLAGS.decay_rate, FLAGS.num_sampled,
                             FLAGS.sentence_len, vocab_size, FLAGS.embed_size,
                             FLAGS.is_training)
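        # Note: FLAGS.num_sampled suggests a sampled loss (e.g.
        # tf.nn.sampled_softmax_loss or tf.nn.nce_loss), which consumes dense
        # integer label ids rather than one-hot vectors; that is why the training
        # loop below feeds the output of process_labels(). This reading is an
        # assumption, since the fastText model definition is not shown here.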
        #Initialize Save
        saver = tf.train.Saver()
        if os.path.exists(FLAGS.ckpt_dir + "checkpoint"):
            print("Restoring Variables from Checkpoint")
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.ckpt_dir))
        else:
            print('Initializing Variables')
            sess.run(tf.global_variables_initializer())
            if FLAGS.use_embedding:  #load pre-trained word embedding
                vocabulary_index2word = {v: k for k, v in word2index.items()}
                assign_pretrained_word_embedding(sess, vocabulary_index2word,
                                                 vocab_size, fast_text)

        curr_epoch = sess.run(fast_text.epoch_step)
        #3.feed data & training
        number_of_training_data = len(trainX)
        batch_size = FLAGS.batch_size
        for epoch in range(curr_epoch,
                           FLAGS.num_epochs):  #range(start,stop,step_size)
            loss, acc, counter = 0.0, 0.0, 0
            for start, end in zip(
                    range(0, number_of_training_data, batch_size),
                    range(batch_size, number_of_training_data, batch_size)):
                # convert this batch's (possibly multi-label) label lists into the
                # dense array shape the model expects; this must run for every batch
                train_Y_batch = process_labels(trainY[start:end])
                if epoch == 0 and counter == 0:
                    print("trainX[start:end]:", trainX[start:end])  # 2-D array; each row is padded to length sentence_len
                    print("trainY[start:end]:", trainY[start:end])  # each example may carry 1 to 5 labels
                    #print("trainY1999[start:end]:",trainY1999[start:end])
                curr_loss, _ = sess.run(
                    [fast_text.loss_val, fast_text.train_op],
                    feed_dict={
                        fast_text.sentence: trainX[start:end],
                        fast_text.labels: train_Y_batch
                    })  #fast_text.labels_l1999:trainY1999[start:end]
                loss, counter = loss + curr_loss, counter + 1  #acc+curr_acc,
                if counter % 500 == 0:
                    print("Epoch %d\tBatch %d\tTrain Loss:%.3f" %
                          (epoch, counter, loss / float(counter)))  #\tTrain Accuracy:%.3f--->,acc/float(counter)

            #epoch increment
            print("going to increment epoch counter....")
            sess.run(fast_text.epoch_increment)

            # 4.validation
            print("epoch:", epoch, "validate_every:", FLAGS.validate_every,
                  "validate or not:", (epoch % FLAGS.validate_every == 0))
            if epoch % FLAGS.validate_every == 0:
                eval_loss, eval_accuracy = do_eval(
                    sess, fast_text, testX, testY, batch_size,
                    index2label)  #testY1999,eval_acc
                print("Epoch %d Validation Loss:%.3f\tValidation Accuracy: %.3f" %
                      (epoch, eval_loss, eval_accuracy))  #,\tValidation Accuracy: %.3f--->eval_acc
                #save model to checkpoint
                save_path = FLAGS.ckpt_dir + "model.ckpt"
                saver.save(sess, save_path,
                           global_step=epoch)  #fast_text.epoch_step

        # 5. Finally, evaluate on the test set and report test accuracy.
        test_loss, test_acc = do_eval(sess, fast_text, testX, testY,
                                      batch_size, index2label)  #testY1999
        print("Test Loss:%.3f\tTest Accuracy:%.3f" % (test_loss, test_acc))
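
# A minimal sketch of what process_labels() might do, assuming the sampled loss
# expects one integer label id per example: reduce each (possibly multi-label)
# example to a dense [batch_size, 1] int array. The real process_labels is
# defined elsewhere in this repository; this commented sketch is illustration
# only, and keeping just the first label is an assumption.
#
# import numpy as np
#
# def process_labels(batch_labels):
#     dense = np.zeros((len(batch_labels), 1), dtype=np.int32)
#     for i, labels in enumerate(batch_labels):
#         first = labels[0] if isinstance(labels, (list, np.ndarray)) else labels
#         dense[i, 0] = first
#     return dense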