def inference_step(self, X_data):
    """Predict a class for every sample by scoring it against all classes.

    Each sample in ``X_data`` is paired with every class id
    ``0 .. self.config.num_class - 1`` (the input is expanded automatically to
    that many copies). The discriminator (``self.infer_discriminator``) scores
    all (sample, class) pairs in a single session run, and the
    highest-scoring class is returned per sample.

    :param X_data: iterable of samples; each sample must reshape into
        ``self.config.input_dim`` values (assumed to be feature vectors —
        TODO confirm against callers).
    :return: numpy array of predicted class indices, one per sample.
    """
    num_class = self.config.num_class

    # Candidate labels: the full range of class ids, once per sample,
    # flattened to shape (n_samples * num_class,).
    candidate_labels = np.reshape([list(range(num_class)) for _ in X_data], [-1])

    # Repeat each sample num_class times so row k lines up with candidate_labels[k].
    repeated_inputs = np.reshape(
        [[sample] * num_class for sample in X_data],
        [-1, self.config.input_dim],
    )

    feed = {
        self.X_input: repeated_inputs,
        self.Y: one_hot(candidate_labels, num_class),
    }
    scores = self.sess.run([self.infer_discriminator], feed_dict=feed)

    # One row per sample, one column per candidate class.
    scores = np.reshape(scores, [-1, num_class])
    return np.argmax(scores, axis=-1)
def next_train_data(self):
    """Build one padded training batch from ``self.next_batch()``.

    For every instance, the words and POS tags are padded to
    ``self.max_sequence_length`` and mapped to ids via ``self.word_id`` /
    ``self.pos_taggings_id``. Relative-position features are then built
    around the single candidate (mark ``'B'``) and the first trigger
    (mark ``'T'``).

    :return: tuple ``(x, t, c, y_one_hot, pos_c, pos_t, pos_tag)``:
        ``x`` – word-id sequences;
        ``t`` / ``c`` – the trigger / candidate word id repeated across the
        sequence length;
        ``y_one_hot`` – ``one_hot(labels, self.label_id, len(self.all_labels))``;
        ``pos_c`` / ``pos_t`` – position offsets relative to the candidate /
        trigger index;
        ``pos_tag`` – POS-tag id sequences.
    """

    def relative_positions(anchor, length):
        # Offsets -anchor .. -1, then 0 .. length - anchor - 1.
        return list(range(-anchor, 0)) + list(range(length - anchor))

    word_id_seqs, trigger_seqs, candidate_seqs = [], [], []
    candidate_pos_seqs, trigger_pos_seqs, pos_tag_seqs = [], [], []
    labels = []

    for instance in self.next_batch():
        words = instance['words']
        tags = instance['pos_taggings']
        marks = instance['marks']
        label = instance['label']

        # Exactly one candidate is expected per instance.
        candidate_hits = find_candidates(marks, ['B'])
        assert len(candidate_hits) == 1
        # NOTE(review): only the first trigger is used. If no 'T' mark is
        # found, indexing [0] below fails.
        trigger_hits = find_candidates(marks, ['T'])
        labels.append(label)

        # Pad every per-token sequence up to the fixed model length.
        max_len = self.max_sequence_length
        marks = marks + ['A'] * (max_len - len(marks))
        words = words + ['<eos>'] * (max_len - len(words))
        tags = tags + ['*'] * (max_len - len(tags))

        tag_ids = [self.pos_taggings_id[tag] for tag in tags]
        pos_tag_seqs.append(tag_ids)
        word_ids = [self.word_id[word] for word in words]
        word_id_seqs.append(word_ids)

        candidate_pos = relative_positions(candidate_hits[0], max_len)
        candidate_pos_seqs.append(candidate_pos)
        trigger_pos = relative_positions(trigger_hits[0], max_len)
        trigger_pos_seqs.append(trigger_pos)

        trigger_seqs.append([word_ids[trigger_hits[0]]] * max_len)
        candidate_seqs.append([word_ids[candidate_hits[0]]] * max_len)

        # All per-token sequences of this instance must share one length.
        assert len({
            len(words), len(marks), len(tag_ids),
            len(word_ids), len(candidate_pos), len(trigger_pos),
        }) == 1

    # Every batch-level list received exactly one entry per instance.
    assert len({
        len(labels), len(word_id_seqs), len(trigger_seqs), len(candidate_seqs),
        len(candidate_pos_seqs), len(trigger_pos_seqs), len(pos_tag_seqs),
    }) == 1

    return (
        word_id_seqs,
        trigger_seqs,
        candidate_seqs,
        one_hot(labels, self.label_id, len(self.all_labels)),
        candidate_pos_seqs,
        trigger_pos_seqs,
        pos_tag_seqs,
    )