예제 #1
0
    def get_label_scores(self, lstm_outputs, left, right):
        """Return label scores for a span, with the empty label pinned to zero.

        The span encoding obtained from ``self.get_span_encoding`` is passed
        through ``self.f_label`` to score the non-empty labels. A constant
        zero is then prepended, so index 0 of the result is the (fixed) score
        of the empty label and the remaining entries are the learned scores.
        """
        span_encoding = self.get_span_encoding(lstm_outputs, left, right)
        scored_labels = self.f_label(span_encoding)
        empty_label_score = dy.zeros(1)
        return dy.concatenate([empty_label_score, scored_labels])
예제 #2
0
    def train_fake(self, input, targets, epsilon=1e-10):
        """Accumulate the next-token negative log-likelihood of ``targets``.

        The LSTM is initialised from ``[input, zeros(self.dim_lstm)]``, fed
        the embedding of ``targets[0]``, and then, for every following target,
        the current output is projected with ``self.h2o`` / ``self.b`` and
        scored against that target with ``dy.pickneglogsoftmax`` before the
        target's own embedding is fed in as the next input.

        Returns the summed loss as a DyNet expression. ``targets`` must be
        non-empty, since ``targets[0]`` is read unconditionally.
        """
        rnn_state = self.lstm.initial_state([input, dy.zeros(self.dim_lstm)])

        total_loss = dy.zeros(1)
        out_weights = dy.parameter(self.h2o)
        out_bias = dy.parameter(self.b)

        # Prime the recurrent state with the first target; it is never scored.
        first_target = targets[0]
        rnn_state = rnn_state.add_input(self.lu[first_target])

        for gold in targets[1:]:
            # Score the gold item from the state *before* it is consumed.
            logits = out_weights * rnn_state.output() + out_bias + epsilon
            total_loss += dy.pickneglogsoftmax(logits, gold)

            rnn_state = rnn_state.add_input(self.lu[gold])

        return total_loss
예제 #3
0
    def _get_loss(self, input, targets, epsilon=1e-10):
        """Return the negated log-likelihood objective for one example.

        The objective sums ``log(p[t] + epsilon)`` over every index ``t`` in
        ``targets`` and adds ``log(1 - p[r] + epsilon)`` for a single class
        ``r`` drawn uniformly at random among the non-target classes, where
        ``p`` is the last layer returned by ``self.compute_output_layer``.

        If every class in ``range(self.dim_out)`` is a target there is no
        negative class to sample, so the negative term is skipped (the
        rejection loop would otherwise never terminate).

        Returns the negated objective as a DyNet expression.
        """
        layers = self.compute_output_layer(input)
        output = layers[-1]

        log_out = dy.log(output + epsilon)

        loss = dy.zeros(1)
        for t in targets:
            loss += dy.pick(log_out, t)

        # Rejection-sample one non-target class; set membership keeps each
        # rejection test O(1) and the size check guarantees termination.
        target_set = set(targets)
        if len(target_set) < self.dim_out:
            r = np.random.randint(self.dim_out)
            while r in target_set:
                r = np.random.randint(self.dim_out)
            loss += dy.log(1 - dy.pick(output, r) + epsilon)

        return -loss
예제 #4
0
    def _get_loss_and_prediction(self, input, targets, epsilon=1e-10):
        """Return the loss for one example together with the predicted classes.

        The prediction is the set of class indices whose output value exceeds
        0.5. (Previously the set held the output *values* themselves, which
        cannot be compared against the index-valued ``targets``.)

        The loss is the negation of: the sum of ``log(p[t] + epsilon)`` over
        ``targets`` plus ``log(1 - p[r] + epsilon)`` for one class ``r`` drawn
        uniformly at random among the non-target classes, where ``p`` is the
        last layer returned by ``self.compute_output_layer``. When every class
        in ``range(self.dim_out)`` is a target, no negative class exists and
        the negative term is skipped instead of looping forever.

        Returns:
            ``(loss, res)`` where ``loss`` is a DyNet expression and ``res``
            is a set of predicted class indices.
        """
        layers = self.compute_output_layer(input)
        output_expr = layers[-1]
        output = output_expr.value()
        # Keep the positions (class indices), not the activations.
        res = {i for i, score in enumerate(output) if score > 0.5}

        log_out = dy.log(output_expr + epsilon)

        loss = dy.zeros(1)
        for t in targets:
            loss += dy.pick(log_out, t)

        # Rejection-sample one non-target class; the size check guarantees
        # that such a class exists so the loop terminates.
        target_set = set(targets)
        if len(target_set) < self.dim_out:
            r = np.random.randint(self.dim_out)
            while r in target_set:
                r = np.random.randint(self.dim_out)
            loss += dy.log(1 - dy.pick(output_expr, r) + epsilon)

        return -loss, res
예제 #5
0
    def get_embeddings(self,
                       word_inds,
                       tag_inds,
                       is_train=False,
                       train_bert_embedding=None):
        """Build one input embedding per token.

        Each token embedding is the concatenation of its tag embedding, a
        character-LSTM embedding, its word embedding and, when
        ``train_bert_embedding`` is given, a per-token external vector.

        Args:
            word_inds: word indices, one per token. Indices ``<= 2`` are
                treated as special symbols: they are never replaced and are
                spelled as a single character-level symbol.
            tag_inds: tag indices, aligned with ``word_inds``.
            is_train: enables character-LSTM dropout and stochastic
                replacement of words by index 0.
            train_bert_embedding: optional sequence indexed by token position
                holding one vector per token. Tokens mapped to index 0 get an
                all-zero vector of the same length instead.

        Returns:
            A list of DyNet expressions, one per token.
        """
        if is_train:
            self.char_lstm.set_dropout(self.dropout)
        else:
            self.char_lstm.disable_dropout()

        embeddings = []
        for idx, (w, t) in enumerate(zip(word_inds, tag_inds)):
            if w > 2:
                count = self.vocab.word_freq_list[w]
                # Words with no recorded frequency always map to index 0
                # (presumably UNK); during training, seen words are also
                # replaced with probability unk_param / (unk_param + count).
                if not count or (is_train
                                 and np.random.rand() < self.unk_param /
                                 (self.unk_param + count)):
                    w = 0

            tag_embedding = self.tag_embeddings[t]

            # Regular words are spelled out character by character; special
            # indices contribute their whole string as one symbol.
            chars = list(self.vocab.i2w[w]) if w > 2 else [self.vocab.i2w[w]]
            char_lstm_outputs = self.char_lstm.transduce([
                self.char_embeddings[self.vocab.c2i[char]]
                for char in [Vocabulary.START] + chars + [Vocabulary.STOP]
            ])
            # First char_lstm_dim components of the last output joined with
            # the remaining components of the first output (presumably the
            # forward and backward halves of a bidirectional LSTM -- confirm).
            char_embedding = dy.concatenate([
                char_lstm_outputs[-1][:self.char_lstm_dim],
                char_lstm_outputs[0][self.char_lstm_dim:]
            ])

            word_embedding = self.word_embeddings[w]
            embs = [tag_embedding, char_embedding, word_embedding]
            if train_bert_embedding is not None:
                bert_vector = train_bert_embedding[idx]
                if w != 0:
                    embs.append(dy.inputTensor(bert_vector))
                else:
                    # Size the placeholder from the supplied vector rather
                    # than assuming 768, so other widths stay consistent.
                    embs.append(dy.zeros(len(bert_vector)))
            embeddings.append(dy.concatenate(embs))

        return embeddings
예제 #6
0
    def parse(self, data, is_train=False):
        """Parse one sentence with a bottom-up chart over spans.

        Args:
            data: mapping with keys ``'w'`` (word indices), ``'t'`` (tag
                indices) and ``'tree'`` (the gold tree, which also supplies
                ``sentence``).
            is_train: when true, LSTM dropout is enabled, label scores are
                passed through ``self.augment`` during prediction, and a
                second, gold-forced pass is run to compute the loss.

        Returns:
            When ``is_train`` is true, the tuple ``(loss, tree, 1)`` where
            ``loss`` is ``dy.zeros(1)`` if the predicted tree prints the same
            as the oracle tree and ``score - oracle_score`` otherwise.
            When ``is_train`` is false, only the predicted ``tree``.
        """
        if is_train:
            self.lstm.set_dropout(self.dropout)
        else:
            self.lstm.disable_dropout()

        word_indices = data['w']
        tag_indices = data['t']
        gold_tree = data['tree']
        sentence = gold_tree.sentence

        embeddings = self.get_embeddings(word_indices, tag_indices, is_train)
        lstm_outputs = self.lstm.transduce(embeddings)

        def helper(force_gold):
            # Fill the chart once. With force_gold the oracle label and the
            # smallest gold split are followed; otherwise the best-scoring
            # label and split are chosen. Returns (tree, score).
            if force_gold:
                assert is_train

            # chart[left, right] -> (list of trees, score expression), with
            # `right` inclusive.
            chart = {}

            # Shorter spans first, so both halves of any split already exist.
            for length in range(1, len(sentence) + 1):
                for left in range(0, len(sentence) + 1 - length):
                    right = left + length - 1

                    label_scores = self.f_label(
                        self.get_span_encoding(lstm_outputs, left, right))
                    # Oracle information is only available (and only needed)
                    # in training mode.
                    if is_train:
                        oracle_label, oracle_label_index, crossing = self.get_oracle_label(
                            gold_tree, left, right)
                    if force_gold:
                        label = oracle_label
                        label_score = label_scores[oracle_label_index]
                    else:
                        if is_train:
                            # NOTE(review): presumably a margin /
                            # loss-augmentation of the scores -- confirm
                            # against self.augment.
                            label_scores = self.augment(
                                label_scores, oracle_label_index, crossing)

                        argmax_label, argmax_label_index = self.predict_label(
                            label_scores, gold_tree, left, right)

                        label = argmax_label
                        label_score = label_scores[argmax_label_index]

                    # Single-token spans have no split point.
                    if length == 1:
                        tree = self.gen_leaf_tree(left, label)
                        chart[left, right] = [tree], label_score
                        continue

                    if force_gold:
                        oracle_splits = gold_tree.span_splits(left, right)
                        oracle_split = min(oracle_splits)
                        best_split = oracle_split
                    else:
                        # Pick the split maximising the numeric sum of the two
                        # child scores; .value() forces evaluation of each
                        # candidate's score.
                        best_split = max(
                            range(left + 1, right + 1),
                            key=lambda split: chart[left, split - 1][1].value(
                            ) + chart[split, right][1].value())

                    left_trees, left_score = chart[left, best_split - 1]
                    right_trees, right_score = chart[best_split, right]

                    # Merge the children and let gen_nonleaf_tree build the
                    # entry for this span from them and the chosen label.
                    childrens = left_trees + right_trees
                    childrens = self.gen_nonleaf_tree(childrens, label)

                    chart[left,
                          right] = (childrens,
                                    label_score + left_score + right_score)

            # The entry covering the whole sentence must hold exactly one tree.
            childrens, score = chart[0, len(sentence) - 1]
            assert len(childrens) == 1
            return childrens[0], score

        tree, score = helper(False)
        tree.propagate_sentence(sentence)

        if is_train:
            oracle_tree, oracle_score = helper(True)
            oracle_tree.propagate_sentence(sentence)

            # Sanity check: the gold-forced pass must reproduce the gold tree.
            assert str(oracle_tree) == str(gold_tree)

            # No loss when the prediction already matches the oracle tree;
            # otherwise the score gap between prediction and oracle.
            correct = (str(tree) == str(oracle_tree))
            loss = dy.zeros(1) if correct else score - oracle_score

            return loss, tree, 1
        else:
            return tree