Example #1
File: embedding.py Project: baidu/DDParser
    def forward(self, words, feats):
        # flat_words flattens the subword pieces of each sentence; position
        # holds, for every word, the index of the piece that represents it
        words, position = self.flat_words(words)
        _, encoded = self.ernie(words)
        # gather one ERNIE hidden vector per word, then restore the
        # (batch, seq_len, hidden) layout
        x = layers.reshape(
            nn.index_sample(encoded, position),
            shape=position.shape[:2] + [encoded.shape[2]],
        )
        words = nn.index_sample(words, position)
        return words, x
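The key step above is the nn.index_sample(encoded, position) gather:
flat_words splits each sentence into subword pieces, and position records,
for every word, the piece whose hidden state will represent it. A minimal
NumPy sketch of the gather semantics (shapes and values are illustrative
assumptions, not DDParser's API):

import numpy as np

# hypothetical: 2 sentences, 5 subword pieces each, hidden size 4,
# keeping 3 representative positions per sentence
encoded = np.random.rand(2, 5, 4)            # (batch, n_pieces, hidden)
position = np.array([[0, 1, 3],
                     [0, 2, 4]])             # (batch, seq_len)

# equivalent of nn.index_sample(encoded, position) plus the reshape
x = np.take_along_axis(encoded, position[..., None], axis=1)
assert x.shape == (2, 3, 4)                  # (batch, seq_len, hidden)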
Example #2
def epoch_predict(env, args, model, loader):
    """Predict in one epoch"""
    model.eval()

    arcs, rels, probs = [], [], []
    for words, feats in loader():
        # ignore the first token of each sentence
        tmp_words = layers.pad(words[:, 1:],
                               paddings=[0, 0, 1, 0],
                               pad_value=args.pad_index)
        mask = tmp_words != args.pad_index
        lens = nn.reduce_sum(mask, -1)
        s_arc, s_rel = model(words, feats)
        arc_preds, rel_preds = decode(args, s_arc, s_rel, mask)
        arcs.extend(
            layers.split(nn.masked_select(arc_preds, mask),
                         lens.numpy().tolist()))
        rels.extend(
            layers.split(nn.masked_select(rel_preds, mask),
                         lens.numpy().tolist()))
        if args.prob:
            arc_probs = nn.index_sample(layers.softmax(s_arc, -1),
                                        layers.unsqueeze(arc_preds, -1))
            probs.extend(
                layers.split(
                    nn.masked_select(layers.squeeze(arc_probs, axes=[-1]),
                                     mask),
                    lens.numpy().tolist()))
    arcs = [seq.numpy().tolist() for seq in arcs]
    rels = [env.REL.vocab[seq.numpy().tolist()] for seq in rels]
    probs = [[round(p, 3) for p in seq.numpy().tolist()] for seq in probs]

    return arcs, rels, probs
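The padding trick at the top of the loop is easy to miss: slicing off
column 0 and left-padding with pad_index makes the first position of every
sentence fail the != pad_index test, so the leading pseudo-root token is
excluded from the mask along with real padding. The same move in NumPy
(token ids are made up):

import numpy as np

pad_index = 0
words = np.array([[101, 7, 8, 9, 0],         # 101 = hypothetical BOS id
                  [101, 4, 5, 0, 0]])        # 0   = padding

tmp_words = np.pad(words[:, 1:], ((0, 0), (1, 0)),
                   constant_values=pad_index)
mask = tmp_words != pad_index                # [[F,T,T,T,F], [F,T,T,F,F]]
lens = mask.sum(-1)                          # [3, 2] tokens per sentence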
Example #3
def loss_function(s_arc, s_rel, arcs, rels, mask):
    """Loss function"""
    arcs = nn.masked_select(arcs, mask)
    rels = nn.masked_select(rels, mask)
    s_arc = nn.masked_select(s_arc, mask)
    s_rel = nn.masked_select(s_rel, mask)
    # keep only the label scores at each token's gold head
    s_rel = nn.index_sample(s_rel, layers.unsqueeze(arcs, 1))
    arc_loss = layers.cross_entropy(layers.softmax(s_arc), arcs)
    rel_loss = layers.cross_entropy(layers.softmax(s_rel), rels)
    loss = layers.reduce_mean(arc_loss + rel_loss)

    return loss
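s_rel scores every (dependent, head, label) triple, so before the label
loss can be computed the scores are narrowed to each token's gold head;
that is what index_sample with the unsqueezed arcs does. A NumPy sketch of
the same selection (shapes are illustrative; masked_select has already
flattened the batch dimension at this point):

import numpy as np

n_tokens, n_heads, n_rels = 4, 4, 3
s_rel = np.random.rand(n_tokens, n_heads, n_rels)
arcs = np.array([0, 2, 0, 1])                # gold head of each token

# keep only the label scores at each token's gold head
s_rel_at_head = np.take_along_axis(
    s_rel, arcs[:, None, None], axis=1).squeeze(1)
assert s_rel_at_head.shape == (n_tokens, n_rels)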
Example #4
def decode(args, s_arc, s_rel, mask):
    """Decode function"""
    mask = mask.numpy()
    lens = np.sum(mask, -1)
    # greedy argmax heads; any non-tree output (self-loops, cycles,
    # multiple roots) is repaired by Eisner decoding below
    arc_preds = layers.argmax(s_arc, -1).numpy()
    bad = [not utils.istree(seq[:i + 1]) for i, seq in zip(lens, arc_preds)]
    if args.tree and any(bad):
        arc_preds[bad] = utils.eisner(s_arc.numpy()[bad], mask[bad])
    arc_preds = dygraph.to_variable(arc_preds, zero_copy=False)
    rel_preds = layers.argmax(s_rel, axis=-1)
    # batch_size, seq_len, _ = rel_preds.shape
    rel_preds = nn.index_sample(rel_preds, layers.unsqueeze(arc_preds, -1))
    rel_preds = layers.squeeze(rel_preds, axes=[-1])
    return arc_preds, rel_preds
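utils.istree asks whether the greedy argmax heads already form a valid
dependency tree (exactly one attachment to the root, no cycles); only the
sentences that fail are re-decoded with the slower Eisner algorithm. A
rough stand-in for that check in plain Python (istree_sketch is a
hypothetical illustration, not DDParser's implementation):

def istree_sketch(heads):
    # heads[i] is the predicted head of token i; position 0 is the root
    if sum(1 for h in heads[1:] if h == 0) != 1:
        return False                         # need exactly one root attachment
    for i in range(1, len(heads)):
        seen, node = set(), i
        while node != 0:                     # walk up toward the root
            if node in seen:
                return False                 # revisited a node: cycle
            seen.add(node)
            node = heads[node]
    return True

assert istree_sketch([0, 0, 1, 2])           # chain rooted at token 1
assert not istree_sketch([0, 0, 3, 2])       # tokens 2 and 3 form a cycle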
Example #5
File: embedding.py Project: baidu/DDParser
    def forward(self, words, feats):
        words, position = self.flat_words(words)
        word_embed = self.word_embed(words)
        # word_embed = self.embed_dropout(word_embed)
        # feats is unused in this variant, so embed is the word
        # representation alone: embed.size = (batch, seq_len, n_embed)
        embed = word_embed
        mask = words != self.args.pad_index
        x = self.lstm(embed, mask)
        x = layers.reshape(nn.index_sample(x, position),
                           shape=position.shape[:2] + [x.shape[2]])
        words = nn.index_sample(words, position)
        x = self.lstm_dropout(x)

        return words, x
Example #6
def epoch_predict(env, args, model, loader):
    """Predict in one epoch"""
    connections, deprels, probabilities = [], [], []
    pad_index = args.pad_index
    bos_index = args.bos_index
    eos_index = args.eos_index
    for batch, inputs in enumerate(loader(), start=1):
        if args.encoding_model.startswith("ernie"):
            words = inputs[0]
            connection_prob, deprel_prob, words = model(words)
        else:
            words, feats = inputs
            connection_prob, deprel_prob, words = model(words, feats)
        # keep only real tokens: mask out pad, BOS and EOS positions
        mask = layers.logical_and(
            layers.logical_and(words != pad_index, words != bos_index),
            words != eos_index,
        )
        lens = nn.reduce_sum(mask, -1)
        connection_predicts, deprel_predicts = decode(args, connection_prob,
                                                      deprel_prob, mask)
        connections.extend(
            layers.split(nn.masked_select(connection_predicts, mask),
                         lens.numpy().tolist()))
        deprels.extend(
            layers.split(nn.masked_select(deprel_predicts, mask),
                         lens.numpy().tolist()))
        if args.prob:
            arc_probs = nn.index_sample(
                layers.softmax(connection_prob, -1),
                layers.unsqueeze(connection_predicts, -1))
            probabilities.extend(
                layers.split(
                    nn.masked_select(layers.squeeze(arc_probs, axes=[-1]),
                                     mask),
                    lens.numpy().tolist(),
                ))
    connections = [seq.numpy().tolist() for seq in connections]
    deprels = [env.REL.vocab[seq.numpy().tolist()] for seq in deprels]
    probabilities = [[round(p, 3) for p in seq.numpy().tolist()]
                     for seq in probabilities]

    return connections, deprels, probabilities
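Compared with Example #2, the mask here is built directly from the token
ids: the nested logical_and keeps a position only if it is neither padding
nor a BOS/EOS marker. The same expression in NumPy (index values are made
up):

import numpy as np

pad_index, bos_index, eos_index = 0, 1, 2
words = np.array([[1, 7, 8, 2, 0],           # BOS w  w  EOS PAD
                  [1, 4, 2, 0, 0]])          # BOS w  EOS PAD PAD

mask = (words != pad_index) & (words != bos_index) & (words != eos_index)
lens = mask.sum(-1)                          # [2, 1] content tokens per row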