Example #1
class MSTParserLSTM:
    def __init__(self, vocab, options):
        import dynet as dy
        from feature_extractor import FeatureExtractor
        global dy
        self.model = dy.ParameterCollection()
        self.trainer = dy.AdamTrainer(self.model, alpha=options.learning_rate)
        self.activations = {'tanh': dy.tanh, 'sigmoid': dy.logistic, 'relu':
                            dy.rectify, 'tanh3': (lambda x:
                                                  dy.tanh(dy.cwise_multiply(dy.cwise_multiply(x, x), x)))}
        self.activation = self.activations[options.activation]
        self.costaugFlag = options.costaugFlag
        self.feature_extractor = FeatureExtractor(self.model, options, vocab)
        self.labelsFlag = options.labelsFlag
        mlp_in_dims = options.lstm_output_size*2

        self.unlabeled_MLP = biMLP(self.model, mlp_in_dims, options.mlp_hidden_dims,
                                 options.mlp_hidden2_dims, 1, self.activation)
        if self.labelsFlag:
            self.labeled_MLP = biMLP(self.model, mlp_in_dims, options.mlp_hidden_dims,
                               options.mlp_hidden2_dims,len(self.feature_extractor.irels),self.activation)

        self.proj = options.proj


    def __getExpr(self, sentence, i, j, train):
        output = self.unlabeled_MLP(sentence[i].vec, sentence[j].vec)
        return output


    def __evaluate(self, sentence, train):
        exprs = [ [self.__getExpr(sentence, i, j, train) for j in xrange(len(sentence))] for i in xrange(len(sentence)) ]
        scores = np.array([ [output.scalar_value() for output in exprsRow] for exprsRow in exprs ])
        return scores, exprs


    def __evaluateLabel(self, sentence, i, j):
        output = self.labeled_MLP(sentence[i].vec, sentence[j].vec)
        return output.value(), output


    def Save(self, filename):
        self.model.save(filename)


    def Load(self, filename):
        self.model.populate(filename)


    def Predict(self, treebanks, datasplit, options):
        char_map = {}
        if options.char_map_file:
            char_map_fh = codecs.open(options.char_map_file,encoding='utf-8')
            char_map = json.loads(char_map_fh.read())
        # should probably use a namedtuple in get_vocab to make this prettier
        _, test_words, test_chars, _, _, _, test_treebanks, test_langs = utils.get_vocab(treebanks,datasplit,char_map)

        # get external embeddings for the set of words and chars in the
        # test vocab but not in the training vocab
        test_embeddings = defaultdict(lambda: {})
        if options.word_emb_size > 0 and options.ext_word_emb_file:
            new_test_words = \
                    set(test_words) - self.feature_extractor.words.viewkeys()

            print "Number of OOV word types at test time: %i (out of %i)" % (
                len(new_test_words), len(test_words))

            if len(new_test_words) > 0:
                # no point loading embeddings if there are no words to look for
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_word_emb_file,
                        lang=lang,
                        words=new_test_words
                    )
                    test_embeddings["words"].update(embeddings)
                    if len(test_langs) > 1 and test_embeddings["words"]:
                        print "External embeddings found for %i words "\
                                "(out of %i)" % \
                                (len(test_embeddings["words"]), len(new_test_words))

        if options.char_emb_size > 0:
            new_test_chars = \
                    set(test_chars) - self.feature_extractor.chars.viewkeys()
            print "Number of OOV char types at test time: %i (out of %i)" % (
                len(new_test_chars), len(test_chars))

            if len(new_test_chars) > 0:
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_char_emb_file,
                        lang=lang,
                        words=new_test_chars,
                        chars=True
                    )
                    test_embeddings["chars"].update(embeddings)
                    if len(test_langs) > 1 and test_embeddings["chars"]:
                        print "External embeddings found for %i chars "\
                                "(out of %i)" % \
                                (len(test_embeddings["chars"]), len(new_test_chars))

        data = utils.read_conll_dir(treebanks,datasplit,char_map=char_map)
        for iSentence, osentence in enumerate(data,1):
            sentence = deepcopy(osentence)
            self.feature_extractor.Init(options)
            conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
            self.feature_extractor.getWordEmbeddings(conll_sentence, False, options, test_embeddings)

            scores, exprs = self.__evaluate(conll_sentence, True)
            if self.proj:
                heads = decoder.parse_proj(scores)
                #LATTICE solution to multiple roots
                # see https://github.com/jujbob/multilingual-bist-parser/blob/master/bist-parser/bmstparser/src/mstlstm.py
                ## ADD for handling multi-roots problem
                rootHead = [head for head in heads if head == 0]
                if len(rootHead) != 1:
                    print "Sentence has multiple roots; attaching the extra roots to the first root"
                    rootHead = [seq for seq, head in enumerate(heads) if head == 0]
                    for seq in rootHead[1:]:
                        heads[seq] = rootHead[0]
                ## end of multi-root handling

            else:
                heads = chuliu_edmonds_one_root(scores.T)

            for entry, head in zip(conll_sentence, heads):
                entry.pred_parent_id = head
                entry.pred_relation = '_'

            if self.labelsFlag:
                for modifier, head in enumerate(heads[1:]):
                    scores, exprs = self.__evaluateLabel(conll_sentence, head, modifier+1)
                    conll_sentence[modifier+1].pred_relation = self.feature_extractor.irels[max(enumerate(scores), key=itemgetter(1))[0]]

            dy.renew_cg()

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [entry for entry in osentence if isinstance(entry, utils.ConllEntry)]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence

    def Train(self, trainData, options):
        errors = 0
        batch = 0
        eloss = 0.0
        mloss = 0.0
        eerrors = 0
        lerrors = 0
        etotal = 0
        beg = start = time.time()

        random.shuffle(trainData) # in certain cases the data will already have been shuffled after being read from file or while creating dev data

        errs = []
        lerrs = []
        eeloss = 0.0
        self.feature_extractor.Init(options)

        for iSentence, sentence in enumerate(trainData,1):
            if iSentence % 100 == 0 and iSentence != 0:
                loss_message = 'Processing sentence number: %d'%iSentence + \
                        ' Loss: %.3f'%(eloss / etotal)+ \
                        ' Errors: %.3f'%((float(eerrors)) / etotal)+\
                        ' Labeled Errors: %.3f'%(float(lerrors) / etotal)+\
                        ' Time: %.2gs'%(time.time()-start)
                print loss_message
                start = time.time()
                eerrors = 0
                eloss = 0.0
                etotal = 0
                lerrors = 0
                ltotal = 0

            conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
            self.feature_extractor.getWordEmbeddings(conll_sentence, True, options)

            scores, exprs = self.__evaluate(conll_sentence, True)
            gold = [entry.parent_id for entry in conll_sentence]
            if self.proj:
                heads = decoder.parse_proj(scores, gold if self.costaugFlag else None)
            else:
                if self.costaugFlag:
                    #augment the score of non-gold arcs
                    for i in range(len(scores)):
                        for j in range(len(scores)):
                            if gold[j] != i:
                                scores[i][j] += 1.
                heads = chuliu_edmonds_one_root(scores.T)
                heads[0] = -1

            if self.labelsFlag:
                for modifier, head in enumerate(gold[1:]):
                    rscores, rexprs = self.__evaluateLabel(conll_sentence, head, modifier+1)
                    goldLabelInd = self.feature_extractor.rels[conll_sentence[modifier+1].relation]
                    wrongLabelInd = max(((l, scr) for l, scr in enumerate(rscores) if l != goldLabelInd), key=itemgetter(1))[0]
                    if rscores[goldLabelInd] < rscores[wrongLabelInd] + 1:
                        lerrs.append(rexprs[wrongLabelInd] - rexprs[goldLabelInd])
                        lerrors += 1 #not quite right but gives some indication

            e = sum([1 for h, g in zip(heads[1:], gold[1:]) if h != g])
            eerrors += e
            if e > 0:
                loss = [(exprs[h][i] - exprs[g][i]) for i, (h,g) in enumerate(zip(heads, gold)) if h != g]
                eloss += dy.esum(loss).scalar_value()
                mloss += dy.esum(loss).scalar_value()
                errs.extend(loss)

            etotal += len(conll_sentence)

            if iSentence % 1 == 0 or len(errs) > 0 or len(lerrs) > 0:
                eeloss = 0.0

                if len(errs) > 0 or len(lerrs) > 0:
                    eerrs = (dy.esum(errs + lerrs))
                    eerrs.scalar_value()
                    eerrs.backward()
                    self.trainer.update()
                    errs = []
                    lerrs = []

                dy.renew_cg()

        if len(errs) > 0:
            eerrs = (dy.esum(errs + lerrs))
            eerrs.scalar_value()
            eerrs.backward()
            self.trainer.update()

            errs = []
            lerrs = []
            eeloss = 0.0

            dy.renew_cg()

        self.trainer.update()
        print "Loss: ", mloss/iSentence
        print "Total Training Time: %.2gs"%(time.time()-beg)
Example #2
class ArcHybridLSTM:
    def __init__(self, words, pos, rels, cpos, langs, w2i, ch, options):
        """
        0 = LA, 1 = RA, 2 = SH, 3 = SW
        """

        import dynet as dy  # import here so we don't load Dynet if just running parser.py --help for example
        global dy

        self.model = dy.ParameterCollection()
        self.trainer = dy.AdamTrainer(self.model, alpha=options.learning_rate)

        self.activations = {
            'tanh':
            dy.tanh,
            'sigmoid':
            dy.logistic,
            'relu':
            dy.rectify,
            'tanh3':
            (lambda x: dy.tanh(dy.cwise_multiply(dy.cwise_multiply(x, x), x)))
        }
        self.activation = self.activations[options.activation]

        self.oracle = options.oracle
        self.shareMLP = options.shareMLP
        self.config_lembed = options.lembed_config

        #vectors used
        self.headFlag = options.headFlag
        self.rlMostFlag = options.rlMostFlag
        self.rlFlag = options.rlFlag
        self.k = options.k

        #dimensions depending on extended features
        self.nnvecs = (1 if self.headFlag else 0) + (2 if self.rlFlag
                                                     or self.rlMostFlag else 0)
        self.feature_extractor = FeatureExtractor(self.model, words, rels,
                                                  langs, w2i, ch, self.nnvecs,
                                                  options)
        self.irels = self.feature_extractor.irels

        #mlps
        mlp_in_dims = options.lstm_output_size * 2 * self.nnvecs * (self.k + 1)
        if self.config_lembed:
            mlp_in_dims += options.lang_emb_size

        h1 = options.mlp_hidden_dims
        h2 = options.mlp_hidden2_dims
        if not options.multiling or self.shareMLP:
            self.unlabeled_MLP = MLP(self.model, mlp_in_dims, h1, h2, 4,
                                     self.activation)
            self.labeled_MLP = MLP(self.model, mlp_in_dims, h1, h2,
                                   2 * len(rels) + 2, self.activation)
        else:
            self.labeled_mlpdict = {}
            for lang in self.feature_extractor.langs:
                self.labeled_mlpdict[lang] = MLP(self.model, mlp_in_dims, h1,
                                                 h2, 2 * len(rels) + 2,
                                                 self.activation)

            self.unlabeled_mlpdict = {}
            for lang in self.feature_extractor.langs:
                self.unlabeled_mlpdict[lang] = MLP(self.model, mlp_in_dims, h1,
                                                   h2, 4, self.activation)

    def __evaluate(self, stack, buf, train):
        """
        Output: a list of tuples per transition:
        [left arc, right arc, shift, swap]

        output[i] = (rel, transition, score1, score2)
        rel = None for shift and swap

        output[i][j][2] ~= output[i][j][3], except that the latter is a dynet
        expression used in the loss, while the former is used in the rest of training.
        TODO: this is ugly and a headache to debug...
        """

        #feature rep
        lang = buf.roots[0].language_id
        if not self.feature_extractor.multiling or self.feature_extractor.shareWordLookup:
            empty = self.feature_extractor.empty
        else:
            empty = self.feature_extractor.emptyVecs[lang]
        topStack = [
            stack.roots[-i - 1].lstms if len(stack) > i else [empty]
            for i in xrange(self.k)
        ]
        topBuffer = [
            buf.roots[i].lstms if len(buf) > i else [empty] for i in xrange(1)
        ]

        input = dy.concatenate(list(chain(*(topStack + topBuffer))))
        if self.config_lembed:
            langvec = self.feature_extractor.langslookup[
                self.feature_extractor.langs[lang]]
            input = dy.concatenate([input, langvec])
        if not self.feature_extractor.multiling or self.shareMLP:
            routput = self.labeled_MLP(input)
            output = self.unlabeled_MLP(input)
        else:
            routput = self.labeled_mlpdict[lang](input)
            output = self.unlabeled_mlpdict[lang](input)

        scrs, uscrs = routput.value(), output.value()

        #transition conditions
        left_arc_conditions = len(stack) > 0
        right_arc_conditions = len(stack) > 1
        shift_conditions = buf.roots[0].id != 0
        swap_conditions = len(
            stack) > 0 and stack.roots[-1].id < buf.roots[0].id

        if not train:
            # avoid the multiple roots problem: disallow left-arc from the root
            # if the stack has more than one element
            left_arc_conditions = left_arc_conditions and not (
                buf.roots[0].id == 0 and len(stack) > 1)

        uscrs0 = uscrs[0]  #shift
        uscrs1 = uscrs[1]  #swap
        uscrs2 = uscrs[2]  #left-arc
        uscrs3 = uscrs[3]  #right-arc

        if train:
            output0 = output[0]
            output1 = output[1]
            output2 = output[2]
            output3 = output[3]

            ret = [[(rel, 0, scrs[2 + j * 2] + uscrs2,
                     routput[2 + j * 2] + output2)
                    for j, rel in enumerate(self.irels)]
                   if left_arc_conditions else [],
                   [(rel, 1, scrs[3 + j * 2] + uscrs3,
                     routput[3 + j * 2] + output3)
                    for j, rel in enumerate(self.irels)]
                   if right_arc_conditions else [],
                   [(None, 2, scrs[0] + uscrs0,
                     routput[0] + output0)] if shift_conditions else [],
                   [(None, 3, scrs[1] + uscrs1,
                     routput[1] + output1)] if swap_conditions else []]
        else:
            s1, r1 = max(zip(scrs[2::2], self.irels))
            s2, r2 = max(zip(scrs[3::2], self.irels))
            s1 += uscrs2
            s2 += uscrs3
            ret = [[(r1, 0, s1)] if left_arc_conditions else [],
                   [(r2, 1, s2)] if right_arc_conditions else [],
                   [(None, 2, scrs[0] + uscrs0)] if shift_conditions else [],
                   [(None, 3, scrs[1] + uscrs1)] if swap_conditions else []]
        return ret

    def Save(self, filename):
        print 'Saving model to ' + filename
        self.model.save(filename)

    def Load(self, filename):
        print 'Loading model from ' + filename
        self.model.populate(filename)

    def apply_transition(self, best, stack, buf, hoffset):
        if best[1] == 2:
            #SHIFT
            stack.roots.append(buf.roots[0])
            del buf.roots[0]

        elif best[1] == 3:
            #SWAP
            child = stack.roots.pop()
            buf.roots.insert(1, child)

        elif best[1] == 0:
            #LEFT-ARC
            child = stack.roots.pop()
            parent = buf.roots[0]

            #predict rel and label
            child.pred_parent_id = parent.id
            child.pred_relation = best[0]

        elif best[1] == 1:
            #RIGHT-ARC
            child = stack.roots.pop()
            parent = stack.roots[-1]

            child.pred_parent_id = parent.id
            child.pred_relation = best[0]

        #update the representation of head for attaching transitions
        if best[1] == 0 or best[1] == 1:
            #deepest leftmost/rightmost child
            if self.rlMostFlag:
                parent.lstms[best[1] + hoffset] = child.lstms[best[1] +
                                                              hoffset]
            #leftmost/rightmost direct child
            if self.rlFlag:
                parent.lstms[best[1] + hoffset] = child.vec

    def calculate_cost(self, scores, s0, s1, b, beta, stack_ids):
        if len(scores[0]) == 0:
            left_cost = 1
        else:
            left_cost = len(
                s0[0].rdeps) + int(s0[0].parent_id != b[0].id
                                   and s0[0].id in s0[0].parent_entry.rdeps)

        if len(scores[1]) == 0:
            right_cost = 1
        else:
            right_cost = len(
                s0[0].rdeps) + int(s0[0].parent_id != s1[0].id
                                   and s0[0].id in s0[0].parent_entry.rdeps)

        if len(scores[2]) == 0:
            shift_cost = 1
            shift_case = 0
        elif len([
                item for item in beta
                if item.projective_order < b[0].projective_order
                and item.id > b[0].id
        ]) > 0:
            shift_cost = 0
            shift_case = 1
        else:
            shift_cost = len([d for d in b[0].rdeps if d in stack_ids]) + int(
                len(s0) > 0 and b[0].parent_id in stack_ids[:-1]
                and b[0].id in b[0].parent_entry.rdeps)
            shift_case = 2

        if len(scores[3]) == 0:
            swap_cost = 1
        elif s0[0].projective_order > b[0].projective_order:
            swap_cost = 0
            #disable all the others
            left_cost = right_cost = shift_cost = 1
        else:
            swap_cost = 1

        costs = (left_cost, right_cost, shift_cost, swap_cost, 1)
        return costs, shift_case

    def Predict(self, data):
        reached_max_swap = 0
        for iSentence, osentence in enumerate(data, 1):
            sentence = deepcopy(osentence)
            reached_swap_for_i_sentence = False
            max_swap = 2 * len(sentence)
            iSwap = 0
            self.feature_extractor.Init()
            conll_sentence = [
                entry for entry in sentence
                if isinstance(entry, utils.ConllEntry)
            ]
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            self.feature_extractor.getWordEmbeddings(conll_sentence, False)
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)

            hoffset = 1 if self.headFlag else 0

            lang = conll_sentence[1].language_id
            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                if not self.feature_extractor.multiling or self.feature_extractor.shareWordLookup:
                    root.lstms += [
                        self.feature_extractor.paddingVec
                        for _ in range(self.nnvecs - hoffset)
                    ]
                else:
                    root.lstms += [
                        self.feature_extractor.paddingVecs[lang]
                        for _ in range(self.nnvecs - hoffset)
                    ]

            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, False)
                best = max(
                    chain(*(scores if iSwap < max_swap else scores[:3])),
                    key=itemgetter(2))
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    print "reached max swap in %d out of %d sentences" % (
                        reached_max_swap, iSentence)
                self.apply_transition(best, stack, buf, hoffset)
                if best[1] == 3:
                    iSwap += 1

            dy.renew_cg()

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [
                entry for entry in osentence
                if isinstance(entry, utils.ConllEntry)
            ]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence

    def Train(self, trainData):
        mloss = 0.0
        eloss = 0.0
        eerrors = 0
        lerrors = 0
        etotal = 0
        ninf = -float('inf')

        beg = time.time()
        start = time.time()

        random.shuffle(
            trainData
        )  # in certain cases the data will already have been shuffled after being read from file or while creating dev data
        print "Length of training data: ", len(trainData)

        errs = []

        self.feature_extractor.Init()

        for iSentence, sentence in enumerate(trainData, 1):
            if iSentence % 100 == 0:
                loss_message = 'Processing sentence number: %d'%iSentence + \
                ' Loss: %.3f'%(eloss / etotal)+ \
                ' Errors: %.3f'%((float(eerrors)) / etotal)+\
                ' Labeled Errors: %.3f'%(float(lerrors) / etotal)+\
                ' Time: %.2gs'%(time.time()-start)
                print loss_message
                start = time.time()
                eerrors = 0
                eloss = 0.0
                etotal = 0
                lerrors = 0

            sentence = deepcopy(
                sentence
            )  # ensures we are working with a clean copy of sentence and allows memory to be recycled each time round the loop

            conll_sentence = [
                entry for entry in sentence
                if isinstance(entry, utils.ConllEntry)
            ]
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            self.feature_extractor.getWordEmbeddings(conll_sentence, True)
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)
            hoffset = 1 if self.headFlag else 0
            lang = conll_sentence[1].language_id

            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                if not self.feature_extractor.multiling or self.feature_extractor.shareWordLookup:
                    root.lstms += [
                        self.feature_extractor.paddingVec
                        for _ in range(self.nnvecs - hoffset)
                    ]
                else:
                    root.lstms += [
                        self.feature_extractor.paddingVecs[lang]
                        for _ in range(self.nnvecs - hoffset)
                    ]

            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, True)

                #to ensure that we have at least one wrong operation
                scores.append([(None, 4, ninf, None)])

                stack_ids = [sitem.id for sitem in stack.roots]

                s1 = [stack.roots[-2]] if len(stack) > 1 else []
                s0 = [stack.roots[-1]] if len(stack) > 0 else []
                b = [buf.roots[0]] if len(buf) > 0 else []
                beta = buf.roots[1:] if len(buf) > 1 else []

                costs, shift_case = self.calculate_cost(
                    scores, s0, s1, b, beta, stack_ids)

                bestValid = list(
                    (s for s in chain(*scores) if costs[s[1]] == 0 and (
                        s[1] == 2 or s[1] == 3 or s[0] == s0[0].relation)))
                if len(bestValid) < 1:
                    print "===============dropping a sentence==============="
                    break

                bestValid = max(bestValid, key=itemgetter(2))
                bestWrong = max(
                    (s for s in chain(*scores) if costs[s[1]] != 0 or (
                        s[1] != 2 and s[1] != 3 and s[0] != s0[0].relation)),
                    key=itemgetter(2))

                #force swap
                if costs[3] == 0:
                    best = bestValid
                else:
                    #select a transition to follow
                    # + aggressive exploration
                    if bestWrong[1] == 3:
                        best = bestValid
                    else:
                        best = bestValid if (
                            (not self.oracle) or
                            (bestValid[2] - bestWrong[2] > 1.0) or
                            (bestValid[2] > bestWrong[2]
                             and random.random() > 0.1)) else bestWrong

                #updates for the dynamic oracle
                if best[1] == 2:
                    #SHIFT
                    if shift_case == 2:
                        if b[0].parent_entry.id in stack_ids[:-1] and b[
                                0].id in b[0].parent_entry.rdeps:
                            b[0].parent_entry.rdeps.remove(b[0].id)
                        blocked_deps = [
                            d for d in b[0].rdeps if d in stack_ids
                        ]
                        for d in blocked_deps:
                            b[0].rdeps.remove(d)

                elif best[1] == 0 or best[1] == 1:
                    #LA or RA
                    child = s0[0]
                    s0[0].rdeps = []
                    if s0[0].id in s0[0].parent_entry.rdeps:
                        s0[0].parent_entry.rdeps.remove(s0[0].id)

                self.apply_transition(best, stack, buf, hoffset)

                if bestValid[2] < bestWrong[2] + 1.0:
                    loss = bestWrong[3] - bestValid[3]
                    mloss += 1.0 + bestWrong[2] - bestValid[2]
                    eloss += 1.0 + bestWrong[2] - bestValid[2]
                    errs.append(loss)

                #labeled errors
                if best[1] != 2 and best[1] != 3 and (
                        child.pred_parent_id != child.parent_id
                        or child.pred_relation != child.relation):
                    lerrors += 1
                    #attachment error
                    if child.pred_parent_id != child.parent_id:
                        eerrors += 1

                if best[1] == 0 or best[1] == 2:
                    etotal += 1

            #footnote 8 in Eli's original paper
            if len(errs) > 50:  # or True:
                eerrs = dy.esum(errs)
                scalar_loss = eerrs.scalar_value()  #forward
                eerrs.backward()
                self.trainer.update()
                errs = []
                lerrs = []

                dy.renew_cg()
                self.feature_extractor.Init()

        if len(errs) > 0:
            eerrs = (dy.esum(errs))
            eerrs.scalar_value()
            eerrs.backward()
            self.trainer.update()

            errs = []
            lerrs = []

            dy.renew_cg()

        self.trainer.update()
        print "Loss: ", mloss / iSentence
        print "Total Training Time: %.2gs" % (time.time() - beg)
Example #3
class ArcHybridLSTM:
    def __init__(self, vocab, options):

        # import here so we don't load Dynet if just running parser.py --help for example
        from multilayer_perceptron import MLP
        from feature_extractor import FeatureExtractor
        import dynet as dy
        global dy

        global LEFT_ARC, RIGHT_ARC, SHIFT, SWAP
        LEFT_ARC, RIGHT_ARC, SHIFT, SWAP = 0,1,2,3

        self.model = dy.ParameterCollection()
        self.trainer = dy.AdamTrainer(self.model, alpha=options.learning_rate)

        self.activations = {'tanh': dy.tanh, 'sigmoid': dy.logistic, 'relu':
                            dy.rectify, 'tanh3': (lambda x:
                            dy.tanh(dy.cwise_multiply(dy.cwise_multiply(x, x), x)))}
        self.activation = self.activations[options.activation]

        self.oracle = options.oracle


        self.headFlag = options.headFlag
        self.rlMostFlag = options.rlMostFlag
        self.rlFlag = options.rlFlag
        self.k = options.k
        self.recursive_composition = options.use_recursive_composition
        #ugly hack

        #dimensions depending on extended features
        self.nnvecs = (1 if self.headFlag else 0) + (2 if self.rlFlag or self.rlMostFlag else 0) + (1 if self.recursive_composition else 0)
        self.feature_extractor = FeatureExtractor(self.model,options,vocab,self.nnvecs)
        self.irels = self.feature_extractor.irels

        if options.no_bilstms > 0:
            mlp_in_dims = options.lstm_output_size*2*self.nnvecs*(self.k+1)
        else:
            mlp_in_dims = options.lstm_input_size*self.nnvecs*(self.k+1)

        self.unlabeled_MLP = MLP(self.model, 'unlabeled', mlp_in_dims, options.mlp_hidden_dims,
                                 options.mlp_hidden2_dims, 4, self.activation)
        self.labeled_MLP = MLP(self.model, 'labeled' ,mlp_in_dims, options.mlp_hidden_dims,
                               options.mlp_hidden2_dims,2*len(self.irels)+2,self.activation)


    def __evaluate(self, stack, buf, train):
        """
        ret = [left arc,
               right arc,
               shift,
               swap]

        ret[i] = (rel, transition, score1, score2) for shift, l_arc and r_arc
         shift = 2 (==> rel=None) ; l_arc = 0; r_arc = 1

        ret[i][j][2] ~= ret[i][j][3], except that the latter is a dynet
        expression used in the loss, while the former is used in the rest of training
        """

        #feature rep
        empty = self.feature_extractor.empty
        topStack = [ stack.roots[-i-1].lstms if len(stack) > i else [empty] for i in xrange(self.k) ]
        topBuffer = [ buf.roots[i].lstms if len(buf) > i else [empty] for i in xrange(1) ]

        input = dy.concatenate(list(chain(*(topStack + topBuffer))))
        output = self.unlabeled_MLP(input)
        routput = self.labeled_MLP(input)


        #scores, unlabeled scores
        scrs, uscrs = routput.value(), output.value()

        #transition conditions
        left_arc_conditions = len(stack) > 0
        right_arc_conditions = len(stack) > 1
        shift_conditions = buf.roots[0].id != 0
        swap_conditions = len(stack) > 0 and stack.roots[-1].id < buf.roots[0].id

        if not train:
            # avoid the multiple roots problem: disallow left-arc from the root
            # if the stack has more than one element
            left_arc_conditions = left_arc_conditions and not (buf.roots[0].id == 0 and len(stack) > 1)

        uscrs0 = uscrs[0]  #shift
        uscrs1 = uscrs[1]  #swap
        uscrs2 = uscrs[2]  #left-arc
        uscrs3 = uscrs[3]  #right-arc

        if train:
            output0 = output[0]
            output1 = output[1]
            output2 = output[2]
            output3 = output[3]


            ret = [ [ (rel, LEFT_ARC, scrs[2 + j * 2] + uscrs2, routput[2 + j * 2 ] + output2) for j, rel in enumerate(self.irels) ] if left_arc_conditions else [],
                   [ (rel, RIGHT_ARC, scrs[3 + j * 2] + uscrs3, routput[3 + j * 2 ] + output3) for j, rel in enumerate(self.irels) ] if right_arc_conditions else [],
                   [ (None, SHIFT, scrs[0] + uscrs0, routput[0] + output0) ] if shift_conditions else [] ,
                    [ (None, SWAP, scrs[1] + uscrs1, routput[1] + output1) ] if swap_conditions else [] ]
        else:
            s1,r1 = max(zip(scrs[2::2],self.irels))
            s2,r2 = max(zip(scrs[3::2],self.irels))
            s1 += uscrs2
            s2 += uscrs3
            ret = [ [ (r1, LEFT_ARC, s1) ] if left_arc_conditions else [],
                   [ (r2, RIGHT_ARC, s2) ] if right_arc_conditions else [],
                   [ (None, SHIFT, scrs[0] + uscrs0) ] if shift_conditions else [] ,
                    [ (None, SWAP, scrs[1] + uscrs1) ] if swap_conditions else [] ]
        return ret


    def Save(self, filename):
        print 'Saving model to ' + filename
        self.model.save(filename)

    def Load(self, filename):
        print 'Loading model from ' + filename
        self.model.populate(filename)


    def apply_transition(self,best,stack,buf,hoffset):
        if best[1] == SHIFT:
            #replace the lstm rep with the forward ones
            stack.roots.append(buf.roots[0])
            del buf.roots[0]

        elif best[1] == SWAP:
            child = stack.roots.pop()
            buf.roots.insert(1,child)

        elif best[1] == LEFT_ARC:
            child = stack.roots.pop()
            parent = buf.roots[0]

        elif best[1] == RIGHT_ARC:
            child = stack.roots.pop()
            parent = stack.roots[-1]

        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            transition = best[0].split(":")[0]
            #attach
            child.pred_parent_id = parent.id
            child.pred_relation = best[0]
            #update head representation
            if self.rlMostFlag:
                #deepest leftmost/rightmost descendant
                parent.lstms[best[1] + hoffset] = child.lstms[best[1] + hoffset]
            elif self.rlFlag:
                #leftmost/rightmost child
                parent.lstms[best[1] + hoffset] = child.vec
            elif self.recursive_composition and transition == 'aux':
                #ouch
                if self.recursive_composition == 'RecNN':
                    #this code should be out of here: a Layer in a DNN lib
                    composition_input = dy.affine_transform([self.feature_extractor.biasCompos.expr(),\
                                                             self.feature_extractor.dCompos.expr(), child.lstms[1] , \
                                                             self.feature_extractor.hCompos.expr(), parent.lstms[1],\
                                                             self.feature_extractor.rCompos.expr(), \
                                                             self.feature_extractor.deprel_lookup[self.feature_extractor.ideprel_dir[best[0],best[1]]]])
                    composed_rep = self.activation(composition_input)
                else:
                    #TreeLSTM
                    rel_vec = self.feature_extractor.deprel_lookup[self.feature_extractor.ideprel_dir[best[0],best[1]]]
                    vec = dy.concatenate([child.lstms[1],parent.lstms[1],rel_vec])
                    if not parent.lstm:
                        parent.lstm = self.feature_extractor.composLSTM.initial_state()

                    composed_rep = parent.lstm.add_input(vec).output()

                parent.lstms[1]  = composed_rep
                parent.composed_rep = composed_rep.value()

    def calculate_cost(self,scores,s0,s1,b,beta,stack_ids):
        if len(scores[LEFT_ARC]) == 0:
            left_cost = 1
        else:
            left_cost = len(s0[0].rdeps) + int(s0[0].parent_id != b[0].id and s0[0].id in s0[0].parent_entry.rdeps)


        if len(scores[RIGHT_ARC]) == 0:
            right_cost = 1
        else:
            right_cost = len(s0[0].rdeps) + int(s0[0].parent_id != s1[0].id and s0[0].id in s0[0].parent_entry.rdeps)


        if len(scores[SHIFT]) == 0:
            shift_cost = 1
            shift_case = 0
        elif len([item for item in beta if item.projective_order < b[0].projective_order and item.id > b[0].id ])> 0:
            shift_cost = 0
            shift_case = 1
        else:
            shift_cost = len([d for d in b[0].rdeps if d in stack_ids]) + int(len(s0)>0 and b[0].parent_id in stack_ids[:-1] and b[0].id in b[0].parent_entry.rdeps)
            shift_case = 2


        if len(scores[SWAP]) == 0 :
            swap_cost = 1
        elif s0[0].projective_order > b[0].projective_order:
            swap_cost = 0
            #disable all the others
            left_cost = right_cost = shift_cost = 1
        else:
            swap_cost = 1

        costs = (left_cost, right_cost, shift_cost, swap_cost,1)
        return costs,shift_case


    def oracle_updates(self,best,b,s0,stack_ids,shift_case):
        if best[1] == SHIFT:
            if shift_case ==2:
                if b[0].parent_entry.id in stack_ids[:-1] and b[0].id in b[0].parent_entry.rdeps:
                    b[0].parent_entry.rdeps.remove(b[0].id)
                blocked_deps = [d for d in b[0].rdeps if d in stack_ids]
                for d in blocked_deps:
                    b[0].rdeps.remove(d)

        elif best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            s0[0].rdeps = []
            if s0[0].id in s0[0].parent_entry.rdeps:
                s0[0].parent_entry.rdeps.remove(s0[0].id)

    def Predict(self, treebanks, datasplit, options):
        reached_max_swap = 0
        char_map = {}
        if options.char_map_file:
            char_map_fh = codecs.open(options.char_map_file,encoding='utf-8')
            char_map = json.loads(char_map_fh.read())
        # should probably use a namedtuple in get_vocab to make this prettier
        print "Collecting test data vocab"
        _, test_words, test_chars, _, _, _, test_treebanks, test_langs = utils.get_vocab(treebanks,datasplit,char_map)
        # get external embeddings for the set of words and chars in the test vocab but not in the training vocab
        test_embeddings = defaultdict(lambda:{})
        if options.word_emb_size > 0:
            new_test_words = set(test_words) - self.feature_extractor.words.viewkeys()
            print "Number of OOV word types at test time: %i (out of %i)"%(len(new_test_words),len(test_words))
            if len(new_test_words) > 0: # no point loading embeddings if there are no words to look for
                for lang in test_langs:
                    test_embeddings["words"].update(utils.get_external_embeddings(options,lang,new_test_words))
                if len(test_langs) > 1 and test_embeddings["words"]:
                    print "External embeddings found for %i words (out of %i)"%(len(test_embeddings["words"]),len(new_test_words))
        if options.char_emb_size > 0:
            new_test_chars = set(test_chars) - self.feature_extractor.chars.viewkeys()
            print "Number of OOV char types at test time: %i (out of %i)"%(len(new_test_chars),len(test_chars))
            if len(new_test_chars) > 0:
                for lang in test_langs:
                    test_embeddings["chars"].update(utils.get_external_embeddings(options,lang,new_test_chars,chars=True))
                if len(test_langs) > 1 and test_embeddings["chars"]:
                    print "External embeddings found for %i chars (out of %i)"%(len(test_embeddings["chars"]),len(new_test_chars))

        ts = time()
        data = utils.read_conll_dir(treebanks,datasplit,char_map=char_map)
        for iSentence, osentence in enumerate(data,1):
            sentence = deepcopy(osentence)
            reached_swap_for_i_sentence = False
            max_swap = 2*len(sentence)
            iSwap = 0
            self.feature_extractor.Init(options)
            conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            self.feature_extractor.getWordEmbeddings(conll_sentence, False, options, test_embeddings)
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)

            hoffset = 1 if self.headFlag else 0

            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                if not self.recursive_composition:
                    root.lstms += [self.feature_extractor.paddingVec for _ in range(self.nnvecs - hoffset)]
                else:
                    root.lstms += [root.vec]
                    root.lstm = None #only necessary for treeLSTM case
                    root.composed_rep = root.vec.value()

            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, False)
                best = max(chain(*(scores if iSwap < max_swap else scores[:3] )), key = itemgetter(2) )
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    print "reached max swap in %d out of %d sentences"%(reached_max_swap, iSentence)
                self.apply_transition(best,stack,buf,hoffset)
                if best[1] == SWAP:
                    iSwap += 1

            #keep in memory the information we need, not all the vectors
            oconll_sentence = [entry for entry in osentence if isinstance(entry, utils.ConllEntry)]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
                if self.recursive_composition:
                    tok_o.composed_rep = tok.composed_rep
            yield osentence

            dy.renew_cg()

        print "Total prediction time: %.2fs"%(time()-ts)

    def Train(self, trainData, options):
        mloss = 0.0
        eloss = 0.0
        eerrors = 0
        lerrors = 0
        etotal = 0
        ninf = -float('inf')

        ts = time()
        start = ts

        random.shuffle(trainData) # in certain cases the data will already have been shuffled after being read from file or while creating dev data
        print "Length of training data: ", len(trainData)

        errs = []

        self.feature_extractor.Init(options)

        for iSentence, sentence in enumerate(trainData,1):
            if iSentence % 100 == 0:
                loss_message = 'Processing sentence number: %d'%iSentence + \
                ' Loss: %.3f'%(eloss / etotal)+ \
                ' Errors: %.3f'%((float(eerrors)) / etotal)+\
                ' Labeled Errors: %.3f'%(float(lerrors) / etotal)+\
                ' Time: %.2gs'%(time()-start)
                print loss_message
                start = time()
                eerrors = 0
                eloss = 0.0
                etotal = 0
                lerrors = 0

            sentence = deepcopy(sentence) # ensures we are working with a clean copy of sentence and allows memory to be recycled each time round the loop

            conll_sentence = [entry for entry in sentence if isinstance(entry, utils.ConllEntry)]
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]
            self.feature_extractor.getWordEmbeddings(conll_sentence, True, options)
            stack = ParseForest([])
            buf = ParseForest(conll_sentence)
            hoffset = 1 if self.headFlag else 0

            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                #root.lstms += [self.feature_extractor.paddingVec for _ in range(self.nnvecs - hoffset)]
                if not self.recursive_composition:
                    root.lstms += [self.feature_extractor.paddingVec for _ in range(self.nnvecs - hoffset)]
                else:
                    root.lstms += [root.vec]
                    root.lstm = None


            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, True)

                #to ensure that we have at least one wrong operation
                scores.append([(None, 4, ninf ,None)])

                stack_ids = [sitem.id for sitem in stack.roots]

                s1 = [stack.roots[-2]] if len(stack) > 1 else []
                s0 = [stack.roots[-1]] if len(stack) > 0 else []
                b = [buf.roots[0]] if len(buf) > 0 else []
                beta = buf.roots[1:] if len(buf) > 1 else []

                costs, shift_case = self.calculate_cost(scores,s0,s1,b,beta,stack_ids)

                bestValid = list(( s for s in chain(*scores) if costs[s[1]] == 0 and ( s[1] == SHIFT or s[1] == SWAP or  s[0] == s0[0].relation ) ))

                bestValid = max(bestValid, key=itemgetter(2))
                bestWrong = max(( s for s in chain(*scores) if costs[s[1]] != 0 or ( s[1] != SHIFT and s[1] != SWAP and s[0] != s0[0].relation ) ), key=itemgetter(2))

                #force swap
                if costs[SWAP]== 0:
                    best = bestValid
                else:
                    #select a transition to follow
                    # + aggressive exploration
                    #1: might want to experiment with that parameter
                    if bestWrong[1] == SWAP:
                        best = bestValid
                    else:
                        best = bestValid if ( (not self.oracle) or (bestValid[2] - bestWrong[2] > 1.0) or (bestValid[2] > bestWrong[2] and random.random() > 0.1) ) else bestWrong

                if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
                    child = s0[0]

                #updates for the dynamic oracle
                if self.oracle:
                    self.oracle_updates(best,b,s0,stack_ids,shift_case)

                self.apply_transition(best,stack,buf,hoffset)

                if bestValid[2] < bestWrong[2] + 1.0:
                    loss = bestWrong[3] - bestValid[3]
                    mloss += 1.0 + bestWrong[2] - bestValid[2]
                    eloss += 1.0 + bestWrong[2] - bestValid[2]
                    errs.append(loss)

                #labeled errors
                if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
                    if (child.pred_parent_id != child.parent_id or child.pred_relation != child.relation):
                        lerrors += 1
                        #attachment error
                        if child.pred_parent_id != child.parent_id:
                            eerrors += 1

                #??? when did this happen and why?
                if best[1] == LEFT_ARC or best[1] == SHIFT:
                    etotal += 1

            #footnote 8 in Eli's original paper
            if len(errs) > 50: # or True:
                eerrs = dy.esum(errs)
                scalar_loss = eerrs.scalar_value() #forward
                eerrs.backward()
                self.trainer.update()
                errs = []
                lerrs = []

                dy.renew_cg()
                self.feature_extractor.Init(options)

        if len(errs) > 0:
            eerrs = (dy.esum(errs))
            eerrs.scalar_value()
            eerrs.backward()
            self.trainer.update()

            errs = []
            lerrs = []

            dy.renew_cg()

        self.trainer.update()
        print "Loss: ", mloss/iSentence
        print "Total training time: %.2fs"%(time()-ts)
Example #4
class ArcHybridLSTM:
    def __init__(self, vocab, options):

        # import here so we don't load Dynet if just running parser.py --help for example
        from multilayer_perceptron import MLP
        from feature_extractor import FeatureExtractor
        import dynet as dy
        global dy

        global LEFT_ARC, RIGHT_ARC, SHIFT, SWAP
        LEFT_ARC, RIGHT_ARC, SHIFT, SWAP = 0, 1, 2, 3

        self.model = dy.ParameterCollection()
        self.trainer = dy.AdamTrainer(self.model, alpha=options.learning_rate)

        self.activations = {
            'tanh':
            dy.tanh,
            'sigmoid':
            dy.logistic,
            'relu':
            dy.rectify,
            'tanh3':
            (lambda x: dy.tanh(dy.cwise_multiply(dy.cwise_multiply(x, x), x)))
        }
        self.activation = self.activations[options.activation]

        self.oracle = options.oracle

        self.headFlag = options.headFlag
        self.rlMostFlag = options.rlMostFlag
        self.rlFlag = options.rlFlag
        self.k = options.k
        self.distances = 4  # probe looks at distances between tokens ahead, considering distances:
        # normalized by the smallest, among:
        # s0 - b0
        # s0 - b1
        # b0 - closest bi: if < s0-b0, do a Shift
        # closest si - b0: if ~= s0-b0, do a reduce

        #dimensions depending on extended features
        self.nnvecs = (1 if self.headFlag else 0) + (2 if self.rlFlag
                                                     or self.rlMostFlag else 0)
        self.feature_extractor = FeatureExtractor(self.model, options, vocab,
                                                  self.nnvecs)
        self.irels = self.feature_extractor.irels

        if options.no_bilstms > 0:  # number of bilstms
            mlp_in_dims = options.lstm_output_size * 2 * self.nnvecs * (
                self.k + 1)
        else:
            mlp_in_dims = self.feature_extractor.lstm_input_size * self.nnvecs * (
                self.k + 1)

        # use attention
        if options.bert and options.attention:
            # add attention vectors for stack to top buf and viceversa
            attention_size = self.k * 2
            # all layers
            #layers = self.feature_extractor.bert.model.config.num_hidden_layers
            #attention_size = layers * layers * self.k # * 2
            mlp_in_dims += attention_size

        # Sartiano
        if options.distance_probe_conf:
            print('Distance Probe enabled', file=sys.stderr)
            from distance_probe import DistanceProbe
            self.distance_probe = DistanceProbe(options.distance_probe_conf,
                                                options.dynet_seed)
            mlp_in_dims += self.distances
        else:
            self.distance_probe = None

        self.attention_indices = [
            int(x) for x in options.attention.split(',')
        ] if options.attention else []

        self.unlabeled_MLP = MLP(self.model, 'unlabeled', mlp_in_dims,
                                 options.mlp_hidden_dims,
                                 options.mlp_hidden2_dims, SWAP + 1,
                                 self.activation)
        self.labeled_MLP = MLP(self.model, 'labeled', mlp_in_dims,
                               options.mlp_hidden_dims,
                               options.mlp_hidden2_dims,
                               2 * len(self.irels) + 2, self.activation)
        print('MLP size: (%d, %d)' % (mlp_in_dims, options.mlp_hidden_dims),
              file=sys.stderr)

    def __evaluate(self, stack, buf, train, attn_maps, dst_matrix):
        """
        :param attn: a matrix [layers, heads, sent_len, sent_len]
        :param dst_matrix: distance matrix computed by syntax probe.

        ret = [left arc,
               right arc,
               shift,
               swap]

        ret[i] = (rel, transition, score1, score2) for shift, l_arc and r_arc
         shift = 2 (==> rel=None) ; l_arc = 0; r_arc = 1

        ret[i][j][2] ~= ret[i][j][3] except the latter is a dynet
        expression used in the loss, the first is used in rest of training
        """

        #feature rep
        empty = self.feature_extractor.empty
        topStack = [
            stack.roots[-i - 1].lstms if len(stack) > i else [empty]
            for i in range(self.k)
        ]
        topBuffer = [
            buf.roots[i].lstms if len(buf) > i else [empty] for i in range(1)
        ]

        # check if not None otherwise is ambiguous
        if dst_matrix is not None:
            # s0 - b0
            # s0 - b1
            # b0 - closest bi: if < s0-b0, do a Shift
            # closest si - b0: if ~= s0-b0, do a reduce
            if len(stack) and len(buf) > 1:  # root token is at end of buffer
                b0_id = buf.roots[0].id
                # graph distances of b0
                # -1 since tokens are numbered from 1
                b0_distances = dst_matrix[b0_id - 1]
                s0_id = stack.roots[-1].id
                # graph distances of s0
                s0_distances = dst_matrix[s0_id - 1]

                s0_b0 = b0_distances[s0_id - 1]
                s0_b1 = s0_distances[buf.roots[1].id - 1]
                if b0_id < len(
                        s0_distances) - 1:  # at least one token after b0
                    s0_bi = min(s0_distances[b0_id:])
                else:
                    s0_bi = 100
                if len(stack) > 1:
                    si_b0 = min(
                        [b0_distances[tok.id] for tok in stack.roots[:-1]])
                else:
                    si_b0 = 100

                rel_dist = [s0_b0, s0_b1, s0_bi, si_b0]
            else:
                rel_dist = [100] * self.distances
            rel_dist = [dy.inputTensor(rel_dist)]

            # #noClose = [dy.scalarInput(-1, self.device)] * self.feature_extractor.pretrained_embeddings_size
            # topStackDistances = [10000] * self.k
            # if len(buf):        # sanity check: always true
            #     # -1 since tokens are numbered from 1
            #     # graph distances of top k elements on stack to top buffer
            #     topBufferDistances = dst_matrix[buf.roots[0].id-1]
            #     for i, entry in enumerate(stack.roots[:self.k]):
            #         # -1 since token id start from 1
            #         distance = topBufferDistances[entry.id-1]
            #         # backwards
            #         topStackDistances[-i-1] = distance
            #     tbd_np = np.array(topBufferDistances)
            #     tbd_np = tbd_np / tbd_np.min() # normalization
            #     topBufferDistances = [dy.inputTensor(tbd_np)]

            # # lookahead. Sartiano
            # # add token at distance 1 to topStack, appearing after topBuf
            # closeToStack = noClose
            # if len(stack):
            #     topStackId = stack.roots[-1].id
            #     for entry in buf.roots[1:]: # skip topBuf
            #         bufTok = entry.id-1
            #         if dst_matrix[bufTok, topStackId-1] == 1:
            #             closeToStack = buf.roots[bufTok].vecs[self.feature_extractor.pretrained_embeddings]
            #             break
            # closeToBuf = noClose
            # topBufId = buf.roots[0].id
            # for entry in buf.roots[1:]: # skip topBuf
            #     bufTok = entry.id-1
            #     if dst_matrix[bufTok, topBufId-1] == 1:
            #         closeToBuf = buf.roots[bufTok].vecs[self.feature_extractor.pretrained_embeddings]
            #         break
            # # Add the embeddings of closeToStack and closeTobuf:
            # deprels = [ closeToStack, closeToBuf ]

            # input = dy.concatenate(list(chain(*(topStack + topBuffer + deprels + topStackDistances))))
            #input = dy.concatenate(list(chain(*(topStack + topBuffer + topStackDistances))))
            input = dy.concatenate(
                list(chain(*(topStack + topBuffer + rel_dist))))
        elif self.attention_indices:
            # get attention vectors for the stack and buf tokens wrt the top buf
            attentions = [dy.scalarInput(-10.0)] * self.k * 2
            # all layers
            #layers = self.feature_extractor.bert.model.config.num_hidden_layers
            #attentions = [[dy.scalarInput(-10.0)] * layers * layers] * self.k # 2
            if len(buf):  # sanity check: always true
                topBufId = buf.roots[0].id
                # attn of top k elements on stack to top buffer
                for i, entry in enumerate(stack.roots[-self.k:]):
                    # -1 since tokens are numbered from 1
                    layer, head = self.attention_indices
                    attn = attn_maps[layer, head, entry.id - 1, topBufId - 1]
                    attentions[i] = dy.scalarInput(attn)
                    #attn = attn_maps[:, :, topBufId-1, entry.id-1].flatten().cpu().numpy()
                    #attentions[i] = dy.inputTensor(attn) #, self.device)
                # add attention from buf elements to top as well
                for i, entry in enumerate(buf.roots[1:self.k]):  # skip top
                    # -1 since tokens are numbered from 1
                    layer, head = self.attention_indices
                    attn = attn_maps[layer, head, entry.id - 1, topBufId - 1]
                    attentions[self.k + i] = dy.scalarInput(
                        attn)  #, self.device)
            # use all layers:
            # attentions = [[dy.scalarInput(-10.0)] * layers * layers] * self.k # 2
            # if len(buf):        # sanity check: always true
            #     topBufId = buf.roots[0].id
            #     # attn of top k elements on stack to top buffer and viceversa
            #     for i, entry in enumerate(stack.roots[:-self.k]):
            #         # -1 since tokens are numbered from 1
            #         #layer, head = self.attention_indices
            #         #attn = attn_maps[:, :, entry.id-1, topBufId-1].flatten().cpu().numpy()
            #         #attentions[i] = dy.inputTensor(attn)
            #         attn = attn_maps[:, :, topBufId-1, entry.id-1].flatten().cpu().numpy()
            #         attentions[i] = dy.inputTensor(attn) #, self.device)
            input = dy.concatenate(
                list(chain(*(topStack + topBuffer + attentions))))
        else:
            input = dy.concatenate(list(chain(*(topStack + topBuffer))))

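        # score the configuration: the unlabeled MLP outputs one score per
        # transition (SHIFT at 0, SWAP at 1, LEFT_ARC at 2, RIGHT_ARC at 3);
        # the labeled MLP outputs SHIFT/SWAP scores at 0-1 and per-relation
        # LEFT_ARC/RIGHT_ARC scores interleaved at indices 2+2j / 3+2j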
        output = self.unlabeled_MLP(input)
        routput = self.labeled_MLP(input)

        # labeled scores, unlabeled scores
        scrs, uscrs = routput.value(), output.value()

        #transition conditions
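        # LEFT_ARC needs at least one item on the stack, RIGHT_ARC at least
        # two; SHIFT is blocked when the root (id 0) is at the buffer front;
        # SWAP requires a non-empty stack whose top id precedes the buffer front's id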
        left_arc_conditions = len(stack) > 0
        right_arc_conditions = len(stack) > 1
        shift_conditions = buf.roots[0].id != 0
        swap_conditions = len(
            stack) > 0 and stack.roots[-1].id < buf.roots[0].id

        if not train:
            # avoid the multiple-roots problem: disallow a left-arc from the
            # root while the stack still holds more than one element
            left_arc_conditions = left_arc_conditions and not (
                buf.roots[0].id == 0 and len(stack) > 1)

        uscrs0 = uscrs[0]
        uscrs1 = uscrs[1]
        uscrs2 = uscrs[2]
        uscrs3 = uscrs[3]

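        # at training time the DyNet expressions are returned alongside the
        # raw scores so that a margin loss can be built later; at prediction
        # time only the best relation per arc transition is kept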
        if train:
            output0 = output[0]
            output1 = output[1]
            output2 = output[2]
            output3 = output[3]

            ret = [[(rel, LEFT_ARC, scrs[2 + j * 2] + uscrs2,
                     routput[2 + j * 2] + output2)
                    for j, rel in enumerate(self.irels)]
                   if left_arc_conditions else [],
                   [(rel, RIGHT_ARC, scrs[3 + j * 2] + uscrs3,
                     routput[3 + j * 2] + output3)
                    for j, rel in enumerate(self.irels)]
                   if right_arc_conditions else [],
                   [(None, SHIFT, scrs[0] + uscrs0,
                     routput[0] + output0)] if shift_conditions else [],
                   [(None, SWAP, scrs[1] + uscrs1,
                     routput[1] + output1)] if swap_conditions else []]
        else:
            s1, r1 = max(zip(scrs[2::2], self.irels))
            s2, r2 = max(zip(scrs[3::2], self.irels))
            s1 += uscrs2
            s2 += uscrs3
            ret = [[(r1, LEFT_ARC, s1)] if left_arc_conditions else [],
                   [(r2, RIGHT_ARC, s2)] if right_arc_conditions else [],
                   [(None, SHIFT,
                     scrs[0] + uscrs0)] if shift_conditions else [],
                   [(None, SWAP, scrs[1] + uscrs1)] if swap_conditions else []]
        return ret

    def Save(self, filename):
        print('Saving model to ' + filename, file=sys.stderr)
        self.model.save(filename)

    def Load(self, filename):
        print('Loading model from ' + filename, file=sys.stderr)
        self.model.populate(filename)

    def apply_transition(self, best, stack, buf, hoffset):
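        # SHIFT: move the buffer front onto the stack
        # SWAP: pop the stack top and reinsert it as the second buffer item
        # LEFT_ARC: pop the stack top and attach it to the buffer front
        # RIGHT_ARC: pop the stack top and attach it to the new stack top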
        if best[1] == SHIFT:
            stack.roots.append(buf.roots[0])
            del buf.roots[0]

        elif best[1] == SWAP:
            child = stack.roots.pop()
            buf.roots.insert(1, child)

        elif best[1] == LEFT_ARC:
            child = stack.roots.pop()
            parent = buf.roots[0]

        elif best[1] == RIGHT_ARC:
            child = stack.roots.pop()
            parent = stack.roots[-1]

        if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            #attach
            child.pred_parent_id = parent.id
            child.pred_relation = best[0]
            #update head representation
            if self.rlMostFlag:
                #deepest leftmost/rightmost descendant
                parent.lstms[best[1] + hoffset] = child.lstms[best[1] +
                                                              hoffset]
            if self.rlFlag:
                #leftmost/rightmost child
                parent.lstms[best[1] + hoffset] = child.vec

    def calculate_cost(self, scores, s0, s1, b, beta, stack_ids):
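        # dynamic-oracle costs: for each transition, the number of gold arcs
        # that become unreachable if it is taken (0 = still optimal); a
        # zero-cost SWAP disables all other transitions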
        if len(scores[LEFT_ARC]) == 0:
            left_cost = 1
        else:
            left_cost = len(
                s0[0].rdeps) + int(s0[0].parent_id != b[0].id
                                   and s0[0].id in s0[0].parent_entry.rdeps)

        if len(scores[RIGHT_ARC]) == 0:
            right_cost = 1
        else:
            right_cost = len(
                s0[0].rdeps) + int(s0[0].parent_id != s1[0].id
                                   and s0[0].id in s0[0].parent_entry.rdeps)

        if len(scores[SHIFT]) == 0:
            shift_cost = 1
            shift_case = 0
        elif len([
                item for item in beta
                if item.projective_order < b[0].projective_order
                and item.id > b[0].id
        ]) > 0:
            shift_cost = 0
            shift_case = 1
        else:
            shift_cost = len([d for d in b[0].rdeps if d in stack_ids]) + int(
                len(s0) > 0 and b[0].parent_id in stack_ids[:-1]
                and b[0].id in b[0].parent_entry.rdeps)
            shift_case = 2

        if len(scores[SWAP]) == 0:
            swap_cost = 1
        elif s0[0].projective_order > b[0].projective_order:
            swap_cost = 0
            #disable all the others
            left_cost = right_cost = shift_cost = 1
        else:
            swap_cost = 1

        costs = (left_cost, right_cost, shift_cost, swap_cost, 1)
        return costs, shift_case

    def oracle_updates(self, best, b, s0, stack_ids, shift_case):
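        # keep the gold-arc bookkeeping (rdeps) consistent with the applied
        # transition by removing dependents that are no longer reachable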
        if best[1] == SHIFT:
            if shift_case == 2:
                if (b[0].parent_entry.id in stack_ids[:-1]
                        and b[0].id in b[0].parent_entry.rdeps):
                    b[0].parent_entry.rdeps.remove(b[0].id)
                blocked_deps = [d for d in b[0].rdeps if d in stack_ids]
                for d in blocked_deps:
                    b[0].rdeps.remove(d)

        elif best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
            s0[0].rdeps = []
            if s0[0].id in s0[0].parent_entry.rdeps:
                s0[0].parent_entry.rdeps.remove(s0[0].id)

    def Predict(self, treebanks, datasplit, options):
        reached_max_swap = 0
        char_map = {}
        if options.char_map_file:
            char_map_fh = open(options.char_map_file, encoding='utf-8')
            char_map = json.loads(char_map_fh.read())
        # should probably use a namedtuple in get_vocab to make this prettier
        _, test_words, test_chars, _, _, _, test_treebanks, test_langs = utils.get_vocab(
            treebanks, datasplit, char_map)

        # get external embeddings for the set of words and chars in the
        # test vocab but not in the training vocab
        test_embeddings = defaultdict(lambda: {})
        if options.word_emb_size > 0 and options.ext_word_emb_file:
            new_test_words = \
                set(test_words) - self.feature_extractor.words.keys()

            print("Number of OOV word types at test time: %i (out of %i)" %
                  (len(new_test_words), len(test_words)),
                  file=sys.stderr)

            if len(new_test_words) > 0:
                # no point loading embeddings if there are no words to look for
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_word_emb_file,
                        lang=lang,
                        words=new_test_words)
                    test_embeddings["words"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["words"]:
                    print("External embeddings found for %i words "\
                          "(out of %i)" % \
                          (len(test_embeddings["words"]), len(new_test_words)), file=sys.stderr)

        if options.char_emb_size > 0:
            new_test_chars = \
                set(test_chars) - self.feature_extractor.chars.keys()
            print("Number of OOV char types at test time: %i (out of %i)" %
                  (len(new_test_chars), len(test_chars)),
                  file=sys.stderr)

            if len(new_test_chars) > 0:
                for lang in test_langs:
                    embeddings = utils.get_external_embeddings(
                        options,
                        emb_file=options.ext_char_emb_file,
                        lang=lang,
                        words=new_test_chars,
                        chars=True)
                    test_embeddings["chars"].update(embeddings)
                if len(test_langs) > 1 and test_embeddings["chars"]:
                    print("External embeddings found for %i chars "\
                          "(out of %i)" % \
                          (len(test_embeddings["chars"]), len(new_test_chars)), file=sys.stderr)

        data = utils.read_conll_dir(treebanks, datasplit, char_map=char_map)
        for iSentence, osentence in enumerate(data, 1):
            sentence = deepcopy(osentence)
            reached_swap_for_i_sentence = False
            max_swap = 2 * len(sentence)
            iSwap = 0
            self.feature_extractor.Init(options)
            conll_sentence = [
                entry for entry in sentence
                if isinstance(entry, utils.ConllEntry)
            ]
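            # move the root entry to the end of the sentence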
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]

            dst_matrix = None
            if self.distance_probe:
                tokens = [tok.form for tok in conll_sentence]
                # dst_matrix = {'tokens': tokens, 'matrix': self.distance_probe.calc(tokens)}
                dst_matrix = self.distance_probe.calc(tokens)

            # set the embeddings into root.vec of each sentence token
            try:
                sentence = self.feature_extractor.getWordEmbeddings(
                    conll_sentence, False, options, test_embeddings)
            except ValueError as e:
                print(e, file=sys.stderr)
                continue

            attn_maps = None
            if self.attention_indices:
                attn_maps = sentence.attentions

            stack = ParseForest([])
            buf = ParseForest(conll_sentence)

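            # with headFlag the first lstms slot holds the token's own vector;
            # hoffset shifts the left/right child slots written by the arc
            # transitions past it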
            hoffset = 1 if self.headFlag else 0

            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                root.lstms += [root.vec for _ in range(self.nnvecs - hoffset)]
                root.relation = root.relation if root.relation in self.irels else 'runk'

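            # greedy parsing: repeatedly apply the highest-scoring permissible
            # transition until only the root is left in the buffer and the
            # stack is empty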
            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, False, attn_maps,
                                         dst_matrix)
                best = max(
                    chain(*(scores if iSwap < max_swap else scores[:3])),
                    key=itemgetter(2))
                if iSwap == max_swap and not reached_swap_for_i_sentence:
                    reached_max_swap += 1
                    reached_swap_for_i_sentence = True
                    print("reached max swap in %d out of %d sentences" %
                          (reached_max_swap, iSentence),
                          file=sys.stderr)
                self.apply_transition(best, stack, buf, hoffset)
                if best[1] == SWAP:
                    iSwap += 1

            dy.renew_cg()

            # copy the predictions back onto the original sentence so that we
            # keep only the information we need, not all the vectors
            oconll_sentence = [
                entry for entry in osentence
                if isinstance(entry, utils.ConllEntry)
            ]
            oconll_sentence = oconll_sentence[1:] + [oconll_sentence[0]]
            for tok_o, tok in zip(oconll_sentence, conll_sentence):
                tok_o.pred_relation = tok.pred_relation
                tok_o.pred_parent_id = tok.pred_parent_id
            yield osentence

    def Train(self, trainData, options):
        mloss = 0.0
        eloss = 0.0
        eerrors = 0
        lerrors = 0
        etotal = 0
        ninf = -float('inf')

        beg = time.time()
        start = time.time()

        # in certain cases the data will already have been shuffled after
        # being read from file or while creating dev data
        random.shuffle(trainData)
        print("Length of training data: ", len(trainData), file=sys.stderr)

        errs = []

        self.feature_extractor.Init(options)

        for iSentence, sentence in enumerate(trainData, 1):
            if iSentence % 100 == 0:
                loss_message = (
                    'Processing sentence number: %d' % iSentence +
                    ' Loss: %.3f' % (eloss / etotal) +
                    ' Errors: %.3f' % (float(eerrors) / etotal) +
                    ' Labeled Errors: %.3f' % (float(lerrors) / etotal) +
                    ' Time: %.2gs' % (time.time() - start))
                print(loss_message, file=sys.stderr)
                start = time.time()
                eerrors = 0
                eloss = 0.0
                etotal = 0
                lerrors = 0

            # ensure we work on a clean copy of the sentence and allow memory
            # to be recycled each time round the loop
            sentence = deepcopy(sentence)

            conll_sentence = [
                entry for entry in sentence
                if isinstance(entry, utils.ConllEntry)
            ]
            # move the root entry to the end of the sentence
            conll_sentence = conll_sentence[1:] + [conll_sentence[0]]

            dst_matrix = None
            if self.distance_probe:
                tokens = [tok.form for tok in conll_sentence]
                dst_matrix = self.distance_probe.calc(tokens)

            # set the embeddings into root.vec of each sentence token
            try:
                sentence = self.feature_extractor.getWordEmbeddings(
                    conll_sentence, True, options)
            except ValueError as e:
                print(e, file=sys.stderr)
                continue

            attn_maps = None
            if self.attention_indices:
                attn_maps = sentence.attentions

            stack = ParseForest([])
            buf = ParseForest(conll_sentence)
            hoffset = 1 if self.headFlag else 0

            for root in conll_sentence:
                root.lstms = [root.vec] if self.headFlag else []
                root.lstms += [root.vec for _ in range(self.nnvecs - hoffset)]
                root.relation = root.relation if root.relation in self.irels else 'runk'

            while not (len(buf) == 1 and len(stack) == 0):
                scores = self.__evaluate(stack, buf, True, attn_maps,
                                         dst_matrix)

                #to ensure that we have at least one wrong operation
                scores.append([(None, 4, ninf, None)])

                stack_ids = [sitem.id for sitem in stack.roots]

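                # oracle inputs: s1/s0 are the top two stack items, b the
                # buffer front, beta the rest of the buffer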
                s1 = [stack.roots[-2]] if len(stack) > 1 else []
                s0 = [stack.roots[-1]] if len(stack) > 0 else []
                b = [buf.roots[0]] if len(buf) > 0 else []
                beta = buf.roots[1:] if len(buf) > 1 else []

                costs, shift_case = self.calculate_cost(
                    scores, s0, s1, b, beta, stack_ids)

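                # bestValid: highest-scoring zero-cost transition (arcs must
                # also carry the gold relation); bestWrong: highest-scoring
                # transition that violates one of these conditions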
                bestValid = list(
                    (s for s in chain(*scores)
                     if costs[s[1]] == 0 and (s[1] == SHIFT or s[1] == SWAP
                                              or s[0] == s0[0].relation)))

                bestValid = max(bestValid, key=itemgetter(2))
                bestWrong = max(
                    (s for s in chain(*scores)
                     if costs[s[1]] != 0 or (s[1] != SHIFT and s[1] != SWAP
                                             and s[0] != s0[0].relation)),
                    key=itemgetter(2))

                #force swap
                if costs[SWAP] == 0:
                    best = bestValid
                else:
                    #select a transition to follow
                    # + aggressive exploration
                    # the 1.0 margin: might want to experiment with that parameter
                    if bestWrong[1] == SWAP:
                        best = bestValid
                    else:
                        best = bestValid if (
                            (not self.oracle) or
                            (bestValid[2] - bestWrong[2] > 1.0) or
                            (bestValid[2] > bestWrong[2]
                             and random.random() > 0.1)) else bestWrong

                if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
                    child = s0[0]

                #updates for the dynamic oracle
                if self.oracle:
                    self.oracle_updates(best, b, s0, stack_ids, shift_case)

                self.apply_transition(best, stack, buf, hoffset)

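                # margin loss: update whenever the best valid transition does
                # not beat the best wrong one by at least 1.0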
                if bestValid[2] < bestWrong[2] + 1.0:
                    loss = bestWrong[3] - bestValid[3]
                    mloss += 1.0 + bestWrong[2] - bestValid[2]
                    eloss += 1.0 + bestWrong[2] - bestValid[2]
                    errs.append(loss)

                #labeled errors
                if best[1] == LEFT_ARC or best[1] == RIGHT_ARC:
                    if (child.pred_parent_id != child.parent_id
                            or child.pred_relation != child.relation):
                        lerrors += 1
                        #attachment error
                        if child.pred_parent_id != child.parent_id:
                            eerrors += 1

                # etotal only counts LEFT_ARC (0) and SHIFT (2) transitions;
                # unclear why the other transitions are excluded
                if best[1] == 0 or best[1] == 2:
                    etotal += 1

            #footnote 8 in Eli's original paper
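            # batch the accumulated losses and update roughly every 50 errors,
            # then renew the computation graph to keep memory use bounded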
            if len(errs) > 50:  # or True:
                eerrs = dy.esum(errs)
                scalar_loss = eerrs.scalar_value()  #forward
                eerrs.backward()
                self.trainer.update()
                errs = []
                lerrs = []

                dy.renew_cg()
                self.feature_extractor.Init(options)

        if len(errs) > 0:
            eerrs = (dy.esum(errs))
            eerrs.scalar_value()
            eerrs.backward()
            self.trainer.update()

            errs = []
            lerrs = []

            dy.renew_cg()

        self.trainer.update()
        print("Loss: ", mloss / iSentence, file=sys.stderr)
        print("Total Training Time: %.2gs" % (time.time() - beg),
              file=sys.stderr)
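
    # Minimal usage sketch (illustrative only; the names below are
    # placeholders, not part of this module). Assuming `parser` is an already
    # constructed instance of this class, `train_data` a list of training
    # sentences, and `options` the namespace used throughout:
    #
    #   for epoch in range(n_epochs):                  # n_epochs assumed
    #       parser.Train(train_data, options)
    #       parser.Save("parser.model.epoch%d" % epoch)
    #   for sentence in parser.Predict(treebanks, "dev", options):
    #       pass  # write the predicted sentence back out in CoNLL format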