def get_label_scores(self, lstm_outputs, left, right):
    """Return scores for every label of span [left, right], with the
    empty label pinned to score zero.

    ``f_label`` only scores the non-empty labels; a constant zero is
    prepended so that index 0 always denotes the empty label.
    """
    span_encoding = self.get_span_encoding(lstm_outputs, left, right)
    non_empty_scores = self.f_label(span_encoding)
    empty_score = dy.zeros(1)
    return dy.concatenate([empty_score, non_empty_scores])
def train_fake(self, input, targets, epsilon=1e-10):
    """Teacher-forced negative log-likelihood of ``targets`` under the
    LSTM language model, conditioned on ``input`` as the initial state.

    The first symbol seeds the state; each remaining symbol is scored
    against the softmax over the previous state's output.
    """
    # Initial LSTM state: the conditioning vector plus a zero cell.
    state = self.lstm.initial_state([input, dy.zeros(self.dim_lstm)])
    W = dy.parameter(self.h2o)
    b = dy.parameter(self.b)
    # Feed the first target symbol; it is conditioned on, not predicted.
    state = state.add_input(self.lu[targets[0]])
    loss = dy.zeros(1)
    for next_symbol in targets[1:]:
        # NOTE(review): adding the scalar epsilon shifts every logit
        # equally, which softmax is invariant to — kept for identical
        # behavior, but it has no effect; confirm intent.
        logits = W * state.output() + b + epsilon
        loss += dy.pickneglogsoftmax(logits, next_symbol)
        state = state.add_input(self.lu[next_symbol])
    return loss
def _get_loss(self, input, targets, epsilon=1e-10):
    """Multi-label loss with one sampled negative.

    Maximizes log-probability of every gold index in ``targets`` plus
    log(1 - p) for one uniformly sampled non-gold index, and returns
    the negated sum (a quantity to minimize).

    NOTE(review): source indentation was lost; the negative sample is
    reconstructed as drawn once per call (after the positive loop),
    not once per target — confirm against the original repository.
    """
    layers = self.compute_output_layer(input)
    # epsilon guards log(0) when an output unit saturates at exactly 0.
    log_out = dy.log(layers[-1] + epsilon)
    loss = dy.zeros(1)
    for t in targets:
        loss += dy.pick(log_out, t)
    # Rejection-sample a negative index that is not a gold target.
    r = np.random.randint(self.dim_out)
    while r in targets:
        r = np.random.randint(self.dim_out)
    loss += dy.log(1 - dy.pick(layers[-1], r) + epsilon)
    #loss -= dy.pick(log_out, r)
    # Negate: the summed terms are log-likelihoods to be maximized.
    return -loss
def _get_loss_and_prediction(self, input, targets, epsilon=1e-10):
    """Same negative-sampling loss as ``_get_loss``, but also returns
    the predicted label set.

    Returns:
        (loss, res): ``loss`` is the negated log-likelihood expression;
        ``res`` is the set of output *indices* whose activation exceeds
        0.5 (comparable to ``targets``).
    """
    layers = self.compute_output_layer(input)
    output = layers[-1].value()
    # BUG FIX: the original built {i for i in output if i > 0.5}, i.e. a
    # set of activation *values* — but the prediction must be the set of
    # unit *indices* above threshold, so it can be compared with
    # ``targets`` (which are indices).
    res = {i for i, p in enumerate(output) if p > 0.5}
    # epsilon guards log(0) when an output unit saturates at exactly 0.
    log_out = dy.log(layers[-1] + epsilon)
    loss = dy.zeros(1)
    for t in targets:
        loss += dy.pick(log_out, t)
    # Rejection-sample one negative index that is not a gold target.
    # NOTE(review): source indentation was lost; sampling reconstructed
    # as once per call, matching ``_get_loss`` — confirm.
    r = np.random.randint(self.dim_out)
    while r in targets:
        r = np.random.randint(self.dim_out)
    loss += dy.log(1 - dy.pick(layers[-1], r) + epsilon)
    #loss -= dy.pick(log_out, r)
    return -loss, res
def get_embeddings(self, word_inds, tag_inds, is_train=False, train_bert_embedding=None):
    """Build one input vector per token: [tag ; char-BiLSTM ; word], plus
    an optional BERT vector when ``train_bert_embedding`` is given.

    During training, non-special words are stochastically replaced by
    UNK (index 0) with probability unk_param / (unk_param + freq), so
    the model learns a useful UNK representation.
    """
    if is_train:
        self.char_lstm.set_dropout(self.dropout)
    else:
        self.char_lstm.disable_dropout()
    embeddings = []
    for idx, (w, t) in enumerate(zip(word_inds, tag_inds)):
        # Indices <= 2 are reserved symbols (presumably UNK/START/STOP
        # — confirm against Vocabulary); only real words get UNK-dropout.
        if w > 2:
            count = self.vocab.word_freq_list[w]
            # Zero/missing frequency always maps to UNK; otherwise drop
            # rare words to UNK with prob unk_param / (unk_param + count).
            if not count or (is_train and np.random.rand() < self.unk_param / (self.unk_param + count)):
                w = 0
        tag_embedding = self.tag_embeddings[t]
        # Real words are split into characters; special symbols are kept
        # as a single pseudo-character token.
        chars = list(self.vocab.i2w[w]) if w > 2 else [self.vocab.i2w[w]]
        char_lstm_outputs = self.char_lstm.transduce([
            self.char_embeddings[self.vocab.c2i[char]]
            for char in [Vocabulary.START] + chars + [Vocabulary.STOP]
        ])
        # Bidirectional summary: forward half of the last state,
        # backward half of the first state.
        char_embedding = dy.concatenate([
            char_lstm_outputs[-1][:self.char_lstm_dim],
            char_lstm_outputs[0][self.char_lstm_dim:]
        ])
        word_embedding = self.word_embeddings[w]
        embs = [tag_embedding, char_embedding, word_embedding]
        if train_bert_embedding is not None:
            # UNK-replaced words get a zero BERT vector (768 = BERT-base
            # hidden size).
            if w != 0:
                embs.append(dy.inputTensor(train_bert_embedding[idx]))
            else:
                embs.append(dy.zeros(768))
        embeddings.append(dy.concatenate(embs))
    return embeddings
def parse(self, data, is_train=False):
    """CKY-style chart parse of one sentence.

    Runs the sentence BiLSTM, then fills a chart bottom-up over span
    lengths. At train time it additionally decodes the gold (oracle)
    tree and returns a margin-style loss; at test time it returns only
    the predicted tree.

    Returns:
        (loss, tree, 1) when ``is_train``; otherwise just ``tree``.
    """
    if is_train:
        self.lstm.set_dropout(self.dropout)
    else:
        self.lstm.disable_dropout()
    word_indices = data['w']
    tag_indices = data['t']
    gold_tree = data['tree']
    sentence = gold_tree.sentence
    embeddings = self.get_embeddings(word_indices, tag_indices, is_train)
    lstm_outputs = self.lstm.transduce(embeddings)

    def helper(force_gold):
        # force_gold=True decodes the oracle tree (train-time only);
        # force_gold=False decodes the model's best tree.
        if force_gold:
            assert is_train
        chart = {}
        # Bottom-up over span lengths; spans are inclusive [left, right].
        for length in range(1, len(sentence) + 1):
            for left in range(0, len(sentence) + 1 - length):
                right = left + length - 1
                label_scores = self.f_label(
                    self.get_span_encoding(lstm_outputs, left, right))
                if is_train:
                    oracle_label, oracle_label_index, crossing = self.get_oracle_label(
                        gold_tree, left, right)
                if force_gold:
                    label = oracle_label
                    label_score = label_scores[oracle_label_index]
                else:
                    if is_train:
                        # Loss-augmented inference: penalize non-oracle
                        # labels so decoding seeks margin violations.
                        label_scores = self.augment(
                            label_scores, oracle_label_index, crossing)
                    argmax_label, argmax_label_index = self.predict_label(
                        label_scores, gold_tree, left, right)
                    label = argmax_label
                    label_score = label_scores[argmax_label_index]
                if length == 1:
                    # Length-1 spans are leaves; no split point needed.
                    tree = self.gen_leaf_tree(left, label)
                    chart[left, right] = [tree], label_score
                    continue
                if force_gold:
                    # Any gold split yields the gold tree; min() is just a
                    # deterministic choice.
                    oracle_splits = gold_tree.span_splits(left, right)
                    oracle_split = min(oracle_splits)
                    best_split = oracle_split
                else:
                    # Choose the split maximizing total subtree score;
                    # .value() forces scalar evaluation of each subscore.
                    best_split = max(
                        range(left + 1, right + 1),
                        key=lambda split: chart[left, split - 1][1].value(
                        ) + chart[split, right][1].value())
                left_trees, left_score = chart[left, best_split - 1]
                right_trees, right_score = chart[best_split, right]
                childrens = left_trees + right_trees
                # May wrap children under `label`, or pass them through
                # unchanged when the span is unlabeled.
                childrens = self.gen_nonleaf_tree(childrens, label)
                chart[left, right] = (childrens, label_score + left_score + right_score)
        # The full-sentence span must have collapsed to a single root.
        childrens, score = chart[0, len(sentence) - 1]
        assert len(childrens) == 1
        return childrens[0], score

    tree, score = helper(False)
    tree.propagate_sentence(sentence)
    if is_train:
        oracle_tree, oracle_score = helper(True)
        oracle_tree.propagate_sentence(sentence)
        # Oracle decode must reproduce the gold tree exactly.
        assert str(oracle_tree) == str(gold_tree)
        correct = (str(tree) == str(oracle_tree))
        # Margin loss: zero when prediction already matches the oracle.
        loss = dy.zeros(1) if correct else score - oracle_score
        return loss, tree, 1
    else:
        return tree