Example #1
    sents.append(sent)

    unknown_words = 0
    unknown_tags = 0
    unknown_rels = 0

    fout = open("result.conllu", "w")

    num_sents = len(sents)
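    # Progress bar: print one marker for roughly every 10% of the test sentences.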
    sents_per_interval = num_sents // 10
    next_dot = sents_per_interval
    current_sent = 0
    print("testing: [", end="", flush=True)

    word_vectorizer.reset_counters()
    tag_vectorizer.reset_counters()

    for sent in sents:
        word_embeddings = []
        tag_embeddings = []
        targets_heads = []
        targets_rels = []

        words = []
        tags = []

        for dep, rel, head in sent:
            word, tag = dep
            words.append(word)
            tags.append(tag)
Example #2
import math
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from scipy.stats import hmean  # harmonic mean; assumed source of 'hmean' used below

# Project-local dependencies referenced below (module paths not shown in this example):
# Vectorizer, ZeroFiller, Tagger, SyntaxParser, TagScore, ParserScore,
# try_build_tree, try_expand_tree.


class Model:
    def __init__(self, name, sents, vectorizer_words, vectorizer_forms, embedding_size,
                 tag_sents, tag_embedding_size, context_size,
                 lrs=(0.1, 0.1, 0.1), lr_decrease_factor=0.5, epochs_per_decrease=10):
        ######################################################################
        # Model's parameters.
        # 'sents' is a list of sentences; each sentence is a list of
        # ((form, word, tag), rel, head) tuples.
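        # Hypothetical illustration of one sentence (head indices are 1-based, 0 = root):
        #   [(("Cats", "cat", "NOUN"), "nsubj", 2), (("sleep", "sleep", "VERB"), "root", 0)]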
        self.name = name
        self.sents = sents
        self.embedding_size = embedding_size
        self.context_size = context_size

        ######################################################################
        # Load or create indices.
        # Common
        self.path_base = "internal"
        self.num_words = 0
        self.root_tag = "root"

        # CUDA flag
        self.is_cuda_available = torch.cuda.is_available()

        # For POS tags:
        self.tags = set()
        self.num_tags = 0
        self.tag2index = {}
        self.index2tag = {}

        # For chunk tags:
        self.chunks = set()
        self.num_chunks = 0
        self.chunk2index = {}
        self.index2chunk = {}

        # For relation tags:
        self.rels = set()
        self.num_rels = 0
        self.rel2index = {}
        self.index2rel = {}

        # Create or load the tag/chunk/rel indices
        self.create_or_load_indices()
        if self.num_words == 0:
            self.num_words = self.get_num_words(self.sents)

        ######################################################################
        # Logic.
        # Learning rate controls
        self.lrs = lrs
        self.lr_decrease_factor = lr_decrease_factor
        self.epochs_per_decrease = epochs_per_decrease

        # Define machines
        self.vectorizer = Vectorizer(vectorizer_words, vectorizer_forms, name,
                                     embedding_size, filler=ZeroFiller(embedding_size),
                                     ce_enabled=True)

        # self.vectorizer = FastTextVectorizer(name, embedding_size * 2, "ft_sg_syntagrus.bin")

        self.tag_vectorizer = Vectorizer(tag_sents, None, name + "_pos",
                                         tag_embedding_size, filler=ZeroFiller(tag_embedding_size),
                                         ce_enabled=False, tf_enabled=False)

        # Tag embeddings (H).
        # The chunker receives the linear combination I = H^T * p as part of its input,
        # where p is the vector of tag probabilities produced by the tagger.
        self.tag_embeddings = []
        for i in range(self.num_tags):
            tag = self.index2tag[i].lower()
            self.tag_embeddings.append(self.tag_vectorizer(tag, tag))
        self.tag_embeddings = torch.stack(self.tag_embeddings)
        if self.is_cuda_available:
            self.tag_embeddings = self.tag_embeddings.cuda()
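        # Shape check (illustration with assumed sizes): with num_tags = 17 and
        # tag_embedding_size = 32, H is (17, 32); for a 10-token sentence the tagger's
        # probabilities p are (10, 17), so p.mm(H) in loop() yields a (10, 32) input block.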

        # Vector size is 1 (TF) + 100 (word embedding) + 100 (char-gram embedding) = 201
        # in the reference setup.
        self.vector_size = self.vectorizer.get_vector_size()

        self.tag_size = self.tag_vectorizer.get_vector_size()

        # Chunk size (input size for the chunker/parser stage):
        # 2 * embedding_size (tagger hidden) + vector_size (word embedding) + tag_size
        # (tag probabilities projected through H); e.g. 200 + 201 + tag embedding size.
        self.chunk_size = 2 * embedding_size + self.vector_size + self.tag_size

        # Parse size -- input size for parser.
        # When chunking is not available, parse size is equal to chunk size
        self.parse_size = self.chunk_size

        self.log("tagger input size: {}".format(self.vector_size))
        self.log("chunker input size: {}".format(self.chunk_size))
        self.log("parser input size: {}".format(self.parse_size))

        self.tagger = Tagger(self.vector_size, self.num_tags, "GRU", embedding_size)
        # self.chunker = Tagger(self.chunk_size, self.num_chunks, "LSTM", embedding_size)
        self.parser = SyntaxParser(0, 0, 0, 0, self.parse_size, embedding_size, self.num_rels)

        self.is_tagger_trained = False
        # self.is_chunker_trained = False
        self.is_parser_trained = False

        self.tagger_name = "pos tagging"
        # self.chunker_name = "chunking"
        self.parser_name = "parsing"

        # Try to load from file
        self.tagger_path = "{}/model_pos_{}.pt".format(self.path_base, self.name)
        # self.chunker_path = "{}/model_chunk_{}.pt".format(self.path_base, self.name)
        self.parser_path = "{}/model_parse_{}.pt".format(self.path_base, self.name)

        if os.path.exists(self.tagger_path):
            self.log("Loading POS tagger")
            self.tagger = torch.load(self.tagger_path)
            self.tagger.unit.flatten_parameters()
            self.is_tagger_trained = True
            self.log("Done")

        # if os.path.exists(self.chunker_path):
        #     self.log("Loading chunker")
        #     self.chunker = torch.load(self.chunker_path)
        #     self.chunker.unit.flatten_parameters()
        #     self.is_chunker_trained = True
        #     self.log("Done")

        if os.path.exists(self.parser_path):
            self.log("Loading parser")
            self.parser = torch.load(self.parser_path)
            self.parser.unit.flatten_parameters()
            self.is_parser_trained = True
            self.log("Done")

    ##########################################################################
    def train(self, sents, num_epochs, machines):
        ######################################################################
        # Define optimizers
        tag_optimizer = optim.SGD(self.tagger.parameters(), lr=self.lrs[0])
        # chunk_optimizer = optim.SGD(self.chunker.parameters(), lr=self.lrs[1])
        chunk_optimizer = None

        # Parameters for both machines
        params = list(self.tagger.parameters()) + list(self.parser.parameters())
        parse_optimizer = optim.SGD(params, lr=self.lrs[2])
        tag_loss_function = nn.NLLLoss()
        chunk_loss_function = nn.NLLLoss()
        parse_loss_function = nn.NLLLoss()

        ######################################################################
        # Run loop
        start_time = time.time()
        for epoch in range(num_epochs):
            print("epoch #{}: ".format(epoch), end="", flush=True)
            optimizers = [tag_optimizer, chunk_optimizer, parse_optimizer]
            self.loop(sents, optimizers, [tag_loss_function, chunk_loss_function, parse_loss_function],
                      [None, None, None])
            self.decrease_lr(optimizers, epoch + 1, self.lr_decrease_factor,
                             self.epochs_per_decrease)
        print("elapsed: {} s".format(int(time.time() - start_time)))

        # Out misses
        # self.print_vectorizer_misses()

        # Save model
        torch.save(self.tagger, self.tagger_path)
        # torch.save(self.chunker, self.chunker_path)
        torch.save(self.parser, self.parser_path)

        self.log("Done")

    ##########################################################################
    def test(self, sents):
        # Collect statistics
        tag_score = TagScore(self.tags)
        # chunk_score = ChunkScore(self.chunks)
        chunk_score = None
        parse_score = ParserScore()

        num_correct_tags = 0
        num_correct_chunks = 0
        num_words = 0

        start_time = time.time()
        self.loop(sents, [None, None, None], [nn.NLLLoss(), nn.NLLLoss(), nn.NLLLoss()],
                  [tag_score, chunk_score, parse_score])
        print("elapsed: {} s".format(int(time.time() - start_time)))

        # Print statistics
        print("POS Tagging:")
        f1_s = []
        has_zero = False
        for tag in sorted(self.tags):
            stat = tag_score.stats[tag]

            if stat.num_gold_predicted == 0 or stat.num_gold == 0 or stat.num_predicted == 0:
                print("\tskipped: {:>5} ({} items)".format(tag, stat.num_gold))
                continue

            num_words += stat.num_gold
            num_correct_tags += stat.num_gold_predicted

            precision = stat.num_gold_predicted / max(stat.num_predicted, 1.0)
            recall = stat.num_gold_predicted / max(stat.num_gold, 1.0)

            f1 = 0.0
            if math.isclose(precision, 0.0) or math.isclose(recall, 0.0):
                has_zero = True
            else:
                f1 = hmean([precision, recall])

            f1_s.append(f1)

            print("\t{:>5}: P = {:4.2f}%, R = {:4.2f}%, F1 = {:4.2f}% ({} items)".format(
                tag, precision * 100, recall * 100, f1 * 100, stat.num_gold))

            # ratio = 0
            # if stat[1] != 0:
            #     ratio = stat[0] / stat[1]
            # ratio *= 100
            #
            # print("\t{:>4}: {:4} / {:4} = {:4.2f}%".format(tag, stat[0], stat[1], ratio))

        # print("Chunking:")
        # for chunk in sorted(self.chunks):
        #     stat = chunk_score.stats[chunk]
        #     num_correct_chunks += stat[0]
        #
        #     ratio = 0
        #     if stat[1] != 0:
        #         ratio = stat[0] / stat[1]
        #     ratio *= 100
        #
        #     print("\t{:>4}: {:4} / {:4} = {:4.2f}%".format(chunk, stat[0], stat[1], ratio))

        # self.print_vectorizer_misses()

        # POS aggregated
        print("Total words:", num_words)
        print("Correct POS tags:", num_correct_tags, "({:4.2f}%)".format(
            num_correct_tags / num_words * 100.0))

        average_f1 = 0.0
        if not has_zero:
            average_f1 = hmean(f1_s)
        print("Average F1 = {:4.2f}%".format(average_f1 * 100))

        # Chunks aggregated
        # print("Correct chunk tags:", num_correct_chunks, "({:4.2f}%)".format(
        #     num_correct_chunks / num_words * 100.0))
        # precision = chunk_score.num_retrieved_relevant / chunk_score.num_retrieved
        # recall = chunk_score.num_retrieved_relevant / chunk_score.num_relevant
        # f1_score = hmean([precision, recall])
        # print("\tPrecision: {:f}".format(precision))
        # print("\tRecall: {:f}".format(recall))
        # print("\tF1 score: {:f}".format(f1_score))

        print("Parsing:")
        print("\tUAS: {} from {} ({:4.2f}%)".format(parse_score.num_unlabeled_arcs, num_words,
                                                    parse_score.num_unlabeled_arcs / num_words * 100.0))
        print("\tLAS: {} from {} ({:4.2f}%)".format(
            parse_score.num_labeled_arcs, num_words,
            parse_score.num_labeled_arcs / num_words * 100.0
        ))
        print("\tMUAS: {} from {} ({:4.2f}%)".format(
            parse_score.num_modified_unlabeled_arcs, num_words,
            parse_score.num_modified_unlabeled_arcs / num_words * 100.0
        ))
        print("\tCRel: {} from {} ({:4.2f}%)".format(
            parse_score.num_labels, num_words,
            parse_score.num_labels / num_words * 100.0))
        print("\tUEM: {} from {} ({:4.2f}%)".format(
            parse_score.num_unlabeled_trees, len(sents),
            parse_score.num_unlabeled_trees / len(sents) * 100.0
        ))
        print("\tLEM: {} from {} ({:4.2f}%)".format(
            parse_score.num_labeled_trees, len(sents),
            parse_score.num_labeled_trees / len(sents) * 100.0
        ))
        print("\tMUEM: {} from {} ({:4.2f}%)".format(
            parse_score.num_modified_unlabeled_trees, len(sents),
            parse_score.num_modified_unlabeled_trees / len(sents) * 100.0
        ))

        self.log("Done")

    ##########################################################################
    # Lose control and decrease the pace
    # Warped and bewitched
    # It's time to erase
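    # Step decay: every 'epoch_interval' epochs the learning rate of every optimizer
    # is multiplied by 'factor' (e.g. with factor=0.5 and epoch_interval=10 the lr is
    # halved at epochs 10, 20, 30, ...).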
    @staticmethod
    def decrease_lr(optimizers, epoch, factor, epoch_interval):
        if epoch % epoch_interval != 0:
            return

        print("lr is multiplied by {:f}".format(factor))

        for optimizer in optimizers:
            if optimizer is not None:
                for param_group in optimizer.param_groups:
                    param_group["lr"] *= factor

    ##########################################################################
    # Used by both 'train' and 'test'.
    # During testing no optimizers are provided (all None); during training
    # no score collectors are provided.
    def loop(self, sents, optimizers, loss_functions, scores):
        tag_optimizer = optimizers[0]
        chunk_optimizer = optimizers[1]
        parse_optimizer = optimizers[2]

        tag_loss_function = loss_functions[0]
        chunk_loss_function = loss_functions[1]
        parse_loss_function = loss_functions[2]

        tag_scores = scores[0]
        chunk_scores = scores[1]
        parse_scores = scores[2]

        # Total number of words
        num_words = self.get_num_words(sents)
        words_per_interval = num_words // 10

        # Average losses
        tag_loss = 0
        chunk_loss = 0
        parse_loss = 0

        # Reset vectorizer
        self.vectorizer.reset_counters()

        # Reset progress bar
        print("progress [", end="", flush=True)
        current_word = 0
        next_interval = words_per_interval

        # Dump
        fout = open("result.conllu", "w")

        for sent in sents:
            # Reset optimizers and machines
            # if tag_optimizer is not None:
            #     tag_optimizer.zero_grad()
            # if chunk_optimizer is not None:
            #     chunk_optimizer.zero_grad()
            if parse_optimizer is not None:
                parse_optimizer.zero_grad()
            self.tagger.reset()
            # self.chunker.reset()

            ##############################################################
            # POS Tagger.
            # Prepare input for tagger and targets for chunker
            sequence = []
            tag_targets = []
            # chunk_targets = []
            parse_head_targets = []
            parse_rel_targets = []
            for dep, rel, head in sent:
                form, word, tag = dep
                sequence.append(self.vectorizer(form, word))
                tag_targets.append(self.tag2index[tag])
                # chunk_targets.append(self.chunk2index[chunk])
                parse_head_targets.append(head)
                parse_rel_targets.append(self.rel2index[rel])
            sequence = Variable(torch.stack(sequence, dim=0))
            tag_targets = Variable(torch.LongTensor(tag_targets))
            # chunk_targets = Variable(torch.LongTensor(chunk_targets))
            parse_head_targets = Variable(torch.LongTensor(parse_head_targets))
            parse_rel_targets = Variable(torch.LongTensor(parse_rel_targets))

            if self.is_cuda_available:
                tag_targets = tag_targets.cuda()
                # chunk_targets = chunk_targets.cuda()
                parse_head_targets = parse_head_targets.cuda()
                parse_rel_targets = parse_rel_targets.cuda()

            # Optimize tagger
            tag_output = self.tagger(sequence.view((len(sent), 1, -1)))
            current_loss = tag_loss_function(tag_output, tag_targets)
            tag_loss += current_loss.data[0]
            # if tag_optimizer is not None:
            #     current_loss.backward()
            #     tag_optimizer.step()

            ##############################################################
            # Chunking.
            # Prepare input for chunker
            sequence = Variable(sequence.data)
            probabilities = torch.exp(tag_output)
            probabilities = probabilities.mm(Variable(self.tag_embeddings))
            tagger_hidden = self.tagger.last_output.view((len(sent), -1))
            sequence = torch.cat((tagger_hidden, sequence, probabilities), dim=1)
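            # Per-token input is now tagger hidden state (2 * embedding_size) + word
            # embedding (vector_size) + projected tag probabilities (tag_size),
            # i.e. self.chunk_size columns in total.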
            # sequence = Variable(torch.cat((sequence, probabilities), dim=1))
            # sequence = Variable(sequence)

            # Optimize chunker
            # chunk_output = self.chunker(sequence.view((len(sent), 1, -1)))
            # current_loss = chunk_loss_function(chunk_output, chunk_targets)
            # chunk_loss += current_loss.data[0]
            # if chunk_optimizer is not None:
            #     current_loss.backward()
            #     chunk_optimizer.step()

            # Optimize parser
            parse_output_heads, parse_output_rels = self.parser(sequence.view(len(sent), 1, -1))
            current_parser_loss = parse_loss_function(parse_output_heads, parse_head_targets)
            current_parser_loss += parse_loss_function(parse_output_rels, parse_rel_targets)
            parse_loss += current_parser_loss.data[0]

            current_loss += current_parser_loss

            if parse_optimizer is not None:
                current_loss.backward()
                parse_optimizer.step()

            ##################################################################
            # Collect stats if necessary
            actual_chunks = []
            is_sent_correct = True
            is_labeled_sent_correct = True
            heads = [0 for i in range(len(sent))]
            rels = ["" for i in range(len(sent))]
            probabilities = [[] for i in range(len(sent))]
            forms = []
            words = []
            tags = []

            if tag_scores is not None or chunk_scores is not None or parse_scores is not None:
                # Output is SEQ_LEN x NUM_TAGS
                for i in range(len(sent)):
                    forms.append(sent[i][0][0])
                    words.append(sent[i][0][1])

                    if tag_scores is not None:
                        maximum, indices = tag_output[i].max(0)
                        predicted = indices.data[0]
                        expected = tag_targets.data[i]

                        tag = sent[i][0][2]
                        stat = tag_scores.stats[tag]
                        stat.num_gold += 1

                        predicted_tag = self.index2tag[predicted]
                        tags.append(predicted_tag)

                        if predicted == expected:
                            stat.num_gold_predicted += 1

                        tag_scores.stats[predicted_tag].num_predicted += 1

                    # if chunk_scores is not None:
                    #     maximum, indices = chunk_output[i].max(0)
                    #     predicted = indices.data[0]
                    #     expected = chunk_targets.data[i]
                    #
                    #     chunk = sent[i][2]
                    #     stat = chunk_scores.stats[chunk]
                    #     stat[1] += 1
                    #
                    #     if chunk[0] == "B":
                    #         chunk_scores.num_relevant += 1
                    #
                    #     actual_chunk = self.index2chunk[predicted]
                    #     actual_chunks.append(actual_chunk)
                    #     if actual_chunk[0] == "B":
                    #         chunk_scores.num_retrieved += 1
                    #
                    #     if predicted == expected:
                    #         stat[0] += 1

                    if parse_scores is not None:
                        for j in range(len(sent) + 1):
                            probabilities[i].append((j, parse_output_heads.data[i][j]))
                        probabilities[i].sort(key=lambda pair: pair[1], reverse=True)

                        maximum, indices = parse_output_heads[i].max(0)
                        predicted = indices.data[0]
                        expected = parse_head_targets.data[i]
                        head = predicted
                        heads[i] = head
                        is_head_correct = expected == predicted
                        if is_head_correct:
                            parse_scores.num_unlabeled_arcs += 1
                        else:
                            is_sent_correct = False
                            is_labeled_sent_correct = False

                        maximum, indices = parse_output_rels[i].max(0)
                        predicted = indices.data[0]
                        expected = parse_rel_targets.data[i]
                        rel = self.index2rel[predicted]
                        rels[i] = rel
                        if expected == predicted:
                            parse_scores.num_labels += 1
                            if is_head_correct:
                                parse_scores.num_labeled_arcs += 1
                        else:
                            is_labeled_sent_correct = False

            # Parser-specific: tree post-processing and CoNLL-U output
            if parse_scores is not None:
                fout.write("#text = {}\n".format(" ".join(forms)))

                if is_sent_correct:
                    parse_scores.num_unlabeled_trees += 1
                if is_labeled_sent_correct:
                    parse_scores.num_labeled_trees += 1

                parse_output_heads = parse_output_heads.data
                # Try to turn the predicted (possibly ill-formed) graph into a well-formed tree.
                # Step 1. Find a root.

                outliers = set()
                roots = set()
                unvisited = set()
                maximum = parse_output_heads[0][0] - 1.0
                index = -1
                for i in range(len(sent)):
                    if parse_output_heads[i][0] > maximum:
                        maximum = parse_output_heads[i][0]
                        index = i
                    if heads[i] == 0 or rels[i] == self.root_tag:
                        roots.add(i)
                    else:
                        unvisited.add(i)
                roots.add(index)
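                # Try each candidate root and keep the spanning attempt that leaves
                # the fewest outliers (options are sorted by the size of the returned
                # outlier set).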

                if len(roots) > 0:
                    options = []
                    for node in roots:
                        tmp_outliers = outliers.copy()
                        tmp_outliers.update(roots)
                        tmp_outliers.remove(node)
                        options.append(try_build_tree(node, heads, unvisited.copy(), tmp_outliers))
                    options.sort(key=lambda t: len(t[2]))
                    root, visited, outliers = options[0]
                else:
                    raise RuntimeError("unreachable branch")
                    # maximum = parse_output_heads[0][0] - 1.0
                    # index = -1
                    # for i in range(len(sent)):
                    #     if parse_output_heads[i][0] > maximum:
                    #         maximum = parse_output_heads[i][0]
                    #         index = i
                    # root = index
                    # if root in outliers:
                    #     outliers.remove(root)
                    # if root in unvisited:
                    #     unvisited.remove(root)
                    # root, visited, outliers = try_build_tree(root, heads, unvisited, outliers)
                heads[root] = 0
                rels[root] = self.root_tag

                # The remaining 'outliers' are nodes whose heads are still unresolved.
                # Attach them greedily, at each step picking the expansion that leaves
                # the fewest outliers.
                while len(outliers) > 0:
                    options = []
                    for node in outliers:
                        options.append(try_expand_tree(node, probabilities, heads, visited, outliers))
                    options.sort(key=lambda t: len(t[3]))
                    index, head, visited, outliers = options[0]
                    heads[index] = head

                is_modified_sent_correct = True
                for i in range(len(sent)):
                    if heads[i] == parse_head_targets.data[i]:
                        parse_scores.num_modified_unlabeled_arcs += 1
                    else:
                        is_modified_sent_correct = False
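                    # CoNLL-U columns: ID, FORM, LEMMA, UPOS, XPOS(_), FEATS(_),
                    # HEAD, DEPREL, DEPS(head:rel), MISC(_).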
                    fout.write("{}\t{}\t{}\t{}\t_\t_\t{}\t{}\t{}:{}\t_\n".format(
                        i + 1, forms[i], words[i], tags[i], heads[i], rels[i], heads[i], rels[i]))
                fout.write("\n")

                if is_modified_sent_correct:
                    parse_scores.num_modified_unlabeled_trees += 1


            # Chunker-specific: determine the number of retrieved relevant chunks
            # if chunk_scores is not None:
            #     gold_chunks = [chunk for word, tag, chunk in sent]
            #     num_retrieved_relevant = 0
            #     i = 0
            #     while i < len(gold_chunks):
            #         gold_chunk = gold_chunks[i]
            #         actual_chunk = actual_chunks[i]
            #
            #         if gold_chunk[0] == "B":
            #             is_correct = True
            #             while True:
            #                 if gold_chunk != actual_chunk:
            #                     is_correct = False
            #
            #                 i += 1
            #                 if i == len(gold_chunks):
            #                     break
            #
            #                 gold_chunk = gold_chunks[i]
            #                 actual_chunk = actual_chunks[i]
            #
            #                 if gold_chunk[0] != "I":
            #                     if actual_chunk[0] == "I":
            #                         is_correct = False
            #                     break
            #
            #             if is_correct:
            #                 num_retrieved_relevant += 1
            #         else:
            #             i += 1
            #     chunk_scores.num_retrieved_relevant += num_retrieved_relevant

            # Emulate progress bar
            current_word += len(sent)
            if current_word >= next_interval:
                next_interval += words_per_interval
                print('💪', end="", flush=True)

        # Per-epoch log: average tag (ATL), chunk (ACL) and parse (APL) loss per sentence
        print("], ATL: {:10.8f}, ACL: {:10.8f}, APL: {:10.8f}".format(
            tag_loss / len(sents), chunk_loss / len(sents), parse_loss / len(sents)))
        fout.close()

    ##########################################################################
    def print_vectorizer_misses(self):
        print("unknown words: {} from {} ({:4.2f}%)".format(
            self.vectorizer.num_word_misses, self.vectorizer.num_words,
            self.vectorizer.num_word_misses / self.vectorizer.num_words * 100.0
        ))
        print("unknown grams: {} from {} ({:4.2f}%)".format(
            self.vectorizer.num_char_misses, self.vectorizer.num_grams,
            self.vectorizer.num_char_misses / self.vectorizer.num_grams * 100.0
        ))

    ##########################################################################
    @staticmethod
    def get_num_words(sents):
        num_words = 0
        for sent in sents:
            num_words += len(sent)
        return num_words

    ##########################################################################
    def create_or_load_indices(self):
        tag_path = "{}/{}_tags.txt".format(self.path_base, self.name)
        chunk_path = "{}/{}_chunks.txt".format(self.path_base, self.name)
        rel_path = "{}/{}_rels.txt".format(self.path_base, self.name)

        create_tag_index = False
        create_chunk_index = False
        create_rel_index = False

        # Try load
        if os.path.exists(tag_path):
            # Load from existing data base
            self.log("Loading POS tag index from file")

            for line in open(tag_path):
                tag, index = line.split()
                index = int(index)
                self.tags.add(tag)
                self.tag2index[tag] = index
                self.index2tag[index] = tag

            self.num_tags = len(self.tags)
        else:
            # Create from scratch
            self.log("Creating POS tag index")
            create_tag_index = True

        # Try load
        if os.path.exists(chunk_path):
            # Load chunk index from file
            self.log("Loading chunk index from file")

            for line in open(chunk_path):
                chunk, index = line.split()
                index = int(index)
                self.chunks.add(chunk)
                self.chunk2index[chunk] = index
                self.index2chunk[index] = chunk

            self.num_chunks = len(self.chunks)
        else:
            # Create from scratch
            self.log("Creating chunk tag index")
            create_chunk_index = True

        # Try load
        if os.path.exists(rel_path):
            # Load rel index from file
            self.log("Loading rel index from file")

            for line in open(rel_path):
                rel, index = line.split()
                index = int(index)
                self.rels.add(rel)
                self.rel2index[rel] = index
                self.index2rel[index] = rel

            self.num_rels = len(self.rels)
        else:
            # Create from scratch
            self.log("Creating rel tag index")
            create_rel_index = True

        # Create if necessary
        if create_tag_index or create_chunk_index or create_rel_index:
            # Collect data
            for sent in self.sents:
                self.num_words += len(sent)
                for dep, rel, head in sent:
                    form, word, tag = dep
                    if create_tag_index:
                        self.tags.add(tag)
                    if create_rel_index:
                        self.rels.add(rel)
                    # if create_chunk_index:
                    #     self.chunks.add(chunk)

            # Create POS tag database
            if create_tag_index:
                file_tags = open(tag_path, "w")
                self.num_tags = len(self.tags)
                for index, tag in enumerate(self.tags):
                    self.index2tag[index] = tag
                    self.tag2index[tag] = index
                    file_tags.write("{} {}\n".format(tag, index))
                file_tags.close()

            # Create chunk tag database
            if create_chunk_index:
                file_chunks = open(chunk_path, "w")
                self.num_chunks = len(self.chunks)
                for index, chunk in enumerate(self.chunks):
                    self.index2chunk[index] = chunk
                    self.chunk2index[chunk] = index
                    file_chunks.write("{} {}\n".format(chunk, index))
                file_chunks.close()

            # Create rel tag database
            if create_rel_index:
                file_rels = open(rel_path, "w")
                self.num_rels = len(self.rels)
                for index, rel in enumerate(self.rels):
                    self.index2rel[index] = rel
                    self.rel2index[rel] = index
                    file_rels.write("{} {}\n".format(rel, index))
                file_rels.close()

    ##########################################################################
    def log(self, message):
        print("Model [{}]:".format(self.name), message)