def forward(self, words, feats):
    """Run the ERNIE encoder and gather per-word representations.

    Args:
        words: word id tensor; flattened internally by ``flat_words``.
        feats: auxiliary features (not referenced by this encoder variant).

    Returns:
        Tuple of (gathered word ids, gathered encoder outputs).
    """
    words, position = self.flat_words(words)
    _, encoded = self.ernie(words)
    # Pick the encoder output at each word's position, then restore the
    # (batch, seq_len, hidden) layout from the flat gather result.
    gathered = nn.index_sample(encoded, position)
    x = layers.reshape(
        gathered,
        shape=position.shape[:2] + [encoded.shape[2]],
    )
    words = nn.index_sample(words, position)
    return words, x
def epoch_predict(env, args, model, loader):
    """Predict in one epoch"""
    model.eval()
    arc_seqs, rel_seqs, prob_seqs = [], [], []
    for words, feats in loader():
        # Hide the first token of every sentence: drop it, then re-pad the
        # front with pad_index so it falls outside the valid-token mask.
        shifted = layers.pad(words[:, 1:],
                             paddings=[0, 0, 1, 0],
                             pad_value=args.pad_index)
        mask = shifted != args.pad_index
        # Per-sentence valid lengths, as a plain Python list for split().
        seq_lens = nn.reduce_sum(mask, -1).numpy().tolist()
        s_arc, s_rel = model(words, feats)
        arc_preds, rel_preds = decode(args, s_arc, s_rel, mask)
        arc_seqs.extend(
            layers.split(nn.masked_select(arc_preds, mask), seq_lens))
        rel_seqs.extend(
            layers.split(nn.masked_select(rel_preds, mask), seq_lens))
        if args.prob:
            # Probability assigned to the chosen head of every token.
            arc_probs = nn.index_sample(
                layers.softmax(s_arc, -1),
                layers.unsqueeze(arc_preds, -1))
            flat_probs = nn.masked_select(
                layers.squeeze(arc_probs, axes=[-1]), mask)
            prob_seqs.extend(layers.split(flat_probs, seq_lens))
    arc_seqs = [seq.numpy().tolist() for seq in arc_seqs]
    rel_seqs = [env.REL.vocab[seq.numpy().tolist()] for seq in rel_seqs]
    prob_seqs = [[round(p, 3) for p in seq.numpy().tolist()]
                 for seq in prob_seqs]
    return arc_seqs, rel_seqs, prob_seqs
def loss_function(s_arc, s_rel, arcs, rels, mask):
    """Loss function"""
    # Keep only the positions selected by `mask`.
    gold_arcs = nn.masked_select(arcs, mask)
    gold_rels = nn.masked_select(rels, mask)
    arc_scores = nn.masked_select(s_arc, mask)
    rel_scores = nn.masked_select(s_rel, mask)
    # Relation scores are read off at the gold head of each token.
    rel_scores = nn.index_sample(rel_scores, layers.unsqueeze(gold_arcs, 1))
    arc_loss = layers.cross_entropy(layers.softmax(arc_scores), gold_arcs)
    rel_loss = layers.cross_entropy(layers.softmax(rel_scores), gold_rels)
    return layers.reduce_mean(arc_loss + rel_loss)
def decode(args, s_arc, s_rel, mask):
    """Decode function

    Picks the highest-scoring head for every token; when ``args.tree`` is
    set, sentences whose prediction is not a well-formed tree are
    re-decoded with Eisner's algorithm. Relation labels are then gathered
    at the chosen heads.

    Args:
        args: config namespace; reads ``args.tree``.
        s_arc: arc score tensor (argmax over the last axis gives heads).
        s_rel: relation score tensor.
        mask: boolean tensor marking valid tokens.

    Returns:
        Tuple of (arc_preds, rel_preds) dygraph variables.
    """
    mask = mask.numpy()
    lens = np.sum(mask, -1)
    arc_preds = layers.argmax(s_arc, -1).numpy()
    if args.tree:
        # Only pay for the per-sentence tree check when tree-constrained
        # decoding is requested (original computed it unconditionally).
        bad = [not utils.istree(seq[:i + 1])
               for i, seq in zip(lens, arc_preds)]
        if any(bad):
            # Repair ill-formed sentences; `bad` acts as a boolean index.
            arc_preds[bad] = utils.eisner(s_arc.numpy()[bad], mask[bad])
    arc_preds = dygraph.to_variable(arc_preds, zero_copy=False)
    rel_preds = layers.argmax(s_rel, axis=-1)
    # Select the relation predicted for each token's chosen head.
    rel_preds = nn.index_sample(rel_preds, layers.unsqueeze(arc_preds, -1))
    rel_preds = layers.squeeze(rel_preds, axes=[-1])
    return arc_preds, rel_preds
def forward(self, words, feats):
    """Encode words with the LSTM and gather per-word representations.

    Args:
        words: word id tensor; flattened internally by ``flat_words``.
        feats: auxiliary features (not referenced by this encoder variant).

    Returns:
        Tuple of (gathered word ids, dropout-applied LSTM outputs).
    """
    words, position = self.flat_words(words)
    word_embed = self.word_embed(words)
    # NOTE: embedding dropout is currently disabled.
    # word_embed = self.embed_dropout(word_embed)
    # Only the word embedding is used here; embed.size = (batch, seq_len,
    # n_embed). (The old comment claimed a word+feat concat of n_embed*2,
    # which this code does not do.)
    embed = word_embed
    mask = words != self.args.pad_index
    x = self.lstm(embed, mask)
    # Gather the LSTM output at each word's position and restore the
    # (batch, seq_len, hidden) layout.
    x = layers.reshape(
        nn.index_sample(x, position),
        shape=position.shape[:2] + [x.shape[2]])
    # Use nn.index_sample for consistency with the sibling encoders
    # (the original mixed in paddle.index_sample here only).
    words = nn.index_sample(words, position)
    x = self.lstm_dropout(x)
    return words, x
def epoch_predict(env, args, model, loader):
    """Predict in one epoch

    Runs the model over every batch from ``loader`` and collects, per
    sentence, the predicted head indices (connections), relation labels
    (deprels) and — when ``args.prob`` is set — the probability of each
    chosen head.

    Returns:
        Tuple of (connections, deprels, probabilities) as Python lists.
    """
    # NOTE(review): unlike the sibling epoch_predict, this one never calls
    # model.eval() — confirm the caller disables dropout before predicting.
    connections, deprels, probabilities = [], [], []
    pad_index = args.pad_index
    bos_index = args.bos_index
    eos_index = args.eos_index
    # The enumerate() batch counter was unused; iterate directly.
    for inputs in loader():
        if args.encoding_model.startswith("ernie"):
            # ERNIE models take only word ids.
            words = inputs[0]
            connection_prob, deprel_prob, words = model(words)
        else:
            words, feats = inputs
            connection_prob, deprel_prob, words = model(words, feats)
        # Valid positions: everything except PAD/BOS/EOS tokens.
        mask = layers.logical_and(
            layers.logical_and(words != pad_index, words != bos_index),
            words != eos_index,
        )
        # Per-sentence valid lengths, computed once for all split() calls.
        lens = nn.reduce_sum(mask, -1).numpy().tolist()
        connection_predicts, deprel_predicts = decode(
            args, connection_prob, deprel_prob, mask)
        connections.extend(
            layers.split(nn.masked_select(connection_predicts, mask), lens))
        deprels.extend(
            layers.split(nn.masked_select(deprel_predicts, mask), lens))
        if args.prob:
            # Probability assigned to the selected head of every token.
            arc_probs = nn.index_sample(
                layers.softmax(connection_prob, -1),
                layers.unsqueeze(connection_predicts, -1))
            probabilities.extend(
                layers.split(
                    nn.masked_select(layers.squeeze(arc_probs, axes=[-1]),
                                     mask),
                    lens,
                ))
    connections = [seq.numpy().tolist() for seq in connections]
    deprels = [env.REL.vocab[seq.numpy().tolist()] for seq in deprels]
    probabilities = [[round(p, 3) for p in seq.numpy().tolist()]
                     for seq in probabilities]
    return connections, deprels, probabilities