def forward(self, x):
    """Encode each word from its character ids via the char LSTM.

    Flattens the padded batch down to the word level, runs the LSTM over
    every word's characters, then re-pads the per-word vectors back into
    per-sentence sequences.
    """
    # A word slot is real when any of its char ids differs from pad.
    word_mask = layers.reduce_any(x != self.pad_index, -1)
    sent_lens = nn.reduce_sum(word_mask, -1)
    # Keep only real words so the LSTM never sees padding rows.
    flat_chars = nn.masked_select(x, word_mask)
    char_emb = self.embed(flat_chars)
    char_lens = nn.reduce_sum(flat_chars != self.pad_index, -1)
    # Final hidden states of the LSTM; unstack + concat merges the stacked
    # states (presumably layers/directions — confirm LSTM config) into one
    # vector per word.
    _, (last_h, _) = self.lstm(char_emb, sequence_length=char_lens)
    word_repr = layers.concat(layers.unstack(last_h), axis=-1)
    # Regroup the flat word vectors into padded per-sentence sequences.
    return nn.pad_sequence_paddle(
        layers.split(word_repr, sent_lens.numpy().tolist(), dim=0),
        self.pad_index)
def epoch_predict(env, args, model, loader):
    """Predict arcs, relations and (optionally) arc probabilities for one epoch.

    Returns three parallel lists of per-sentence sequences: head indices,
    relation labels (decoded through ``env.REL.vocab``) and, when
    ``args.prob`` is set, the probability of each chosen head rounded to
    three decimals.
    """
    model.eval()
    arc_out, rel_out, prob_out = [], [], []
    for words, feats in loader():
        # ignore the first token of each sentence: shift ids left and refill
        # the vacated tail slot with padding so it drops out of the mask
        shifted = layers.pad(words[:, 1:],
                             paddings=[0, 0, 1, 0],
                             pad_value=args.pad_index)
        mask = shifted != args.pad_index
        seg_lens = nn.reduce_sum(mask, -1).numpy().tolist()
        s_arc, s_rel = model(words, feats)
        arc_pred, rel_pred = decode(args, s_arc, s_rel, mask)
        arc_out.extend(
            layers.split(nn.masked_select(arc_pred, mask), seg_lens))
        rel_out.extend(
            layers.split(nn.masked_select(rel_pred, mask), seg_lens))
        if args.prob:
            # probability assigned to the predicted head of every token
            chosen = nn.index_sample(layers.softmax(s_arc, -1),
                                     layers.unsqueeze(arc_pred, -1))
            flat = nn.masked_select(layers.squeeze(chosen, axes=[-1]), mask)
            prob_out.extend(layers.split(flat, seg_lens))
    arc_out = [seq.numpy().tolist() for seq in arc_out]
    rel_out = [env.REL.vocab[seq.numpy().tolist()] for seq in rel_out]
    prob_out = [[round(p, 3) for p in seq.numpy().tolist()]
                for seq in prob_out]
    return arc_out, rel_out, prob_out
def forward(self, x):
    """Encode the non-pad tokens with the transformer and re-pad per sentence."""
    # Rows whose ids are all pad are dropped before encoding.
    keep = layers.reduce_any(x != self.pad_index, -1)
    seq_lens = nn.reduce_sum(keep, -1)
    tokens = nn.masked_select(x, keep)
    hidden, _ = self.transformer(tokens)
    # Split the flat hidden states back into sentences and pad to a
    # common length.
    return nn.pad_sequence_paddle(
        layers.split(hidden, seq_lens.numpy().tolist(), dim=0),
        self.pad_index)
def flat_words(self, words):
    """Flatten a padded [batch, seq, subword] id tensor to the word level.

    Returns:
        flat_words: the non-pad subword ids regrouped per sentence and
            re-padded to a common length.
        position: per-word index of its last subword inside ``flat_words``,
            clipped to the padded length.
    """
    pad_index = self.args.pad_index
    # Number of real subword pieces per word.
    lens = nn.reduce_sum(words != pad_index, dim=-1)
    # Cumulative end offset of each word; words with zero pieces borrow
    # the previous offset (the +1 for lens == 0 cancels the -1).
    position = layers.cumsum(
        lens + layers.cast((lens == 0), "int32"), axis=1) - 1
    flat_words = nn.masked_select(words, words != pad_index)
    # FIX: pad_index was previously passed as layers.split's third
    # positional argument (its `dim` parameter) instead of as the pad
    # value of pad_sequence_paddle — correct only by accident when
    # pad_index == 0. Mirror the split(..., dim=0) + pad value pattern
    # used by the other feature encoders in this file.
    flat_words = nn.pad_sequence_paddle(
        layers.split(flat_words,
                     layers.reduce_sum(lens, -1).numpy().tolist(),
                     dim=0),
        pad_index)
    max_len = flat_words.shape[1]
    # Clamp positions that would index past the padded length.
    position = nn.mask_fill(position, position >= max_len, max_len - 1)
    return flat_words, position
def epoch_predict(env, args, model, loader):
    """Predict head connections, dependency relations and (optionally)
    head probabilities for one epoch.

    Relation ids are decoded through ``env.REL.vocab``; probabilities are
    rounded to three decimals and only collected when ``args.prob`` is set.
    """
    heads, rels, head_probs = [], [], []
    pad_index = args.pad_index
    bos_index = args.bos_index
    eos_index = args.eos_index
    for inputs in loader():
        # ernie-style encoders take only word ids; others also take feats
        if args.encoding_model.startswith("ernie"):
            arc_scores, rel_scores, words = model(inputs[0])
        else:
            words, feats = inputs
            arc_scores, rel_scores, words = model(words, feats)
        # Keep real tokens only: not pad, not BOS, not EOS.
        not_pad = words != pad_index
        not_bos = words != bos_index
        not_eos = words != eos_index
        mask = layers.logical_and(
            layers.logical_and(not_pad, not_bos), not_eos)
        tok_lens = nn.reduce_sum(mask, -1).numpy().tolist()
        head_pred, rel_pred = decode(args, arc_scores, rel_scores, mask)
        heads.extend(
            layers.split(nn.masked_select(head_pred, mask), tok_lens))
        rels.extend(
            layers.split(nn.masked_select(rel_pred, mask), tok_lens))
        if args.prob:
            # probability assigned to the predicted head of every token
            picked = nn.index_sample(layers.softmax(arc_scores, -1),
                                     layers.unsqueeze(head_pred, -1))
            flat = nn.masked_select(layers.squeeze(picked, axes=[-1]), mask)
            head_probs.extend(layers.split(flat, tok_lens))
    heads = [seq.numpy().tolist() for seq in heads]
    rels = [env.REL.vocab[seq.numpy().tolist()] for seq in rels]
    head_probs = [[round(p, 3) for p in seq.numpy().tolist()]
                  for seq in head_probs]
    return heads, rels, head_probs
def forward(self, x, seq_mask):
    """Run the LSTM over ``x``; ``seq_mask`` supplies the true lengths."""
    lengths = nn.reduce_sum(seq_mask, -1)
    output, _ = self.lstm(x, sequence_length=lengths)
    return output