Example #1
import os

import tensorflow as tf  # TF 1.x API

# Hyperparameters (vocab_size, embedding_size, lstm_dim, lr, batchsize, ...)
# and the helpers used below (ModelAPI, SelfAttRNN, ActorNetwork, train, ...)
# come from the surrounding project.


def main():
    tf.reset_default_graph()

    with tf.Session() as sess:

        # Build the token embedding matrix shared by the classifier and the actor.
        embeddings = generate_embedding_mat(vocab_size, embedding_size)

        # Cache the embedding metadata to disk on the first run.
        embed_file = './data/emb_mat.pkl'
        if not os.path.exists(embed_file):
            save_embedding_info(embed_file)

        # Build the classifier graph on top of the shared embeddings.
        classifier = ModelAPI(model_dir, embed_file)
        classifier.load_config()
        classifier.config["token_emb_mat"] = embeddings
        model = SelfAttRNN()
        classifier.build_graph(sess, model, "/gpu:2")

        # Policy network used when training with reinforcement learning.
        actor = ActorNetwork(sess, lstm_dim, optimizer, lr, embeddings)

        # Restore previously saved weights, if any.
        saver = tf.train.Saver()
        model_file = "./checkpoints/{}".format(model_name)
        restore_model(sess, saver, model_file)
        # params = get_model_params(sess)
        # get_simplify(sess, actor)

        num_epochs = 5
        try:
            for epoch in range(num_epochs):
                if use_RL:
                    # Joint RL training of the actor and the classifier.
                    train(sess,
                          actor,
                          classifier,
                          batchsize,
                          classifier_trainable=True)
                else:
                    # Plain supervised training of the classifier alone.
                    train_classifier(sess, classifier)

                # Checkpoint after every epoch.
                saver.save(sess, model_file)

        except KeyboardInterrupt:
            print(
                '[INFO] Interrupted manually, saving a checkpoint...')
            saver.save(sess, model_file)
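
The `restore_model` helper is not shown in this example. Below is a minimal sketch of what it might do, assuming it falls back to fresh variable initialization when no checkpoint exists; that fallback behavior is an assumption, not the project's confirmed implementation:

def restore_model(sess, saver, model_file):
    # Restore weights if a checkpoint exists; otherwise initialize variables.
    ckpt = tf.train.latest_checkpoint(os.path.dirname(model_file))
    if ckpt is not None:
        saver.restore(sess, ckpt)
        print("[INFO] Restored model from {}".format(ckpt))
    else:
        sess.run(tf.global_variables_initializer())
        print("[INFO] No checkpoint found, starting from scratch.")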
Example #2

    def build_placeholder(self, config):

        self.config = config
        self.token_emb_mat = self.config["token_emb_mat"]
        self.vocab_size = int(self.config["vocab_size"])
        self.max_length = int(self.config["max_length"])
        self.emb_size = int(self.config["emb_size"])
        self.extra_symbol = self.config["extra_symbol"]
        self.scope = self.config["scope"]
        self.num_features = int(self.config["num_features"])
        self.num_classes = int(self.config["num_classes"])
        self.use_ema = self.config.get("ema", False)  # boolean flag; the EMA object is created later
        self.grad_clipper = float(self.config.get("grad_clipper", 10.0))

        print("--------vocab size---------", self.vocab_size)
        print("--------max length---------", self.max_length)
        print("--------emb size-----------", self.emb_size)
        print("--------extra symbol-------", self.extra_symbol)
        print("--------emb matrix---------", self.token_emb_mat.shape)

        # ---- placeholders ----
        self.sent1_token = tf.placeholder(tf.int32, [None, None],
                                          name='sent1_token')
        self.gold_label = tf.placeholder(tf.int32, [None], name='gold_label')
        self.sent1_len = tf.placeholder(tf.int32, [None],
                                        name='sent1_token_length')

        self.is_train = tf.placeholder(tf.bool, [], name='is_train')

        self.features = tf.placeholder(tf.float32,
                                       shape=[None, self.num_features],
                                       name="features")

        self.one_hot_label = tf.one_hot(self.gold_label, self.num_classes)

        # ------------ derived tensors ---------
        # Token id 0 is treated as padding, so non-zero ids form the mask.
        self.sent1_token_mask = tf.cast(self.sent1_token, tf.bool)
        self.sent1_token_len = tf.reduce_sum(
            tf.cast(self.sent1_token_mask, tf.int32), -1)

        # ---------------- for dynamic learning rate -------------------
        self.learning_rate = tf.placeholder(tf.float32, [], 'learning_rate')
        self.learning_rate_value = float(self.config["learning_rate"])

        self.dropout_keep_prob = tf.placeholder(tf.float32,
                                                name="dropout_keep_prob")

        self.emb_mat = generate_embedding_mat(self.vocab_size,
                                              emb_len=self.emb_size,
                                              init_mat=self.token_emb_mat,
                                              extra_symbol=self.extra_symbol,
                                              scope='gene_token_emb_mat')

        self.s1_emb = tf.nn.embedding_lookup(
            self.emb_mat, self.sent1_token)  # [batch, seq_len, emb_size]
        self.global_step = tf.Variable(0, name="global_step", trainable=False)

        self.learning_rate_updated = False
        # ------ fields filled in later during graph building ------
        self.pred_probs = None
        self.logits = None
        self.loss = None
        self.accuracy = None
        self.var_ema = None
        self.ema = None  # the EMA object itself, created later if use_ema is set
        self.opt = None
        self.train_op = None
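
Both examples depend on a project-level `generate_embedding_mat` helper. Judging from its call sites (a pretrained `init_mat`, a list of `extra_symbol` tokens, a variable `scope`), a plausible sketch follows the common pattern of keeping a small trainable matrix for the extra symbols (e.g. padding/UNK) while freezing the pretrained rows; the exact row split and variable names below are assumptions:

import numpy as np
import tensorflow as tf  # TF 1.x API

def generate_embedding_mat(dict_size, emb_len, init_mat=None,
                           extra_symbol=None, scope=None):
    # Trainable rows for the extra symbols, frozen pretrained rows for the
    # remaining vocabulary entries (an assumed layout, see note above).
    n_extra = len(extra_symbol) if extra_symbol else 0
    with tf.variable_scope(scope or 'gene_emb_mat'):
        extra_emb = tf.get_variable('emb_mat_extra',
                                    shape=[n_extra, emb_len],
                                    dtype=tf.float32)
        if init_mat is None:
            token_emb = tf.get_variable('emb_mat_token',
                                        shape=[dict_size - n_extra, emb_len],
                                        dtype=tf.float32)
        else:
            token_emb = tf.get_variable(
                'emb_mat_token',
                initializer=np.asarray(init_mat[n_extra:], dtype=np.float32),
                trainable=False)
        return tf.concat([extra_emb, token_emb], axis=0)

Under this sketch, `self.emb_mat` above has shape `[vocab_size, emb_size]`: the first `len(extra_symbol)` rows stay trainable while the pretrained rows are fixed.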