# Assumed imports for these excerpts (PaddlePaddle 1.x static-graph API);
# `F` is the project-local helper module (unsqueeze, equal, gumbel_softmax, ...):
#   import math
#   import numpy as np
#   import paddle.fluid.layers as layers
#   from paddle.fluid.framework import Variable
#   import plato.modules.functions as F

def _ranking(self, inputs, predictions):
    """ Rerank generated responses with the discriminator head. """
    src_token = inputs["src_token"]
    src_mask = inputs["src_mask"]
    src_pos = inputs["src_pos"]
    src_type = inputs["src_type"]
    src_turn = inputs["src_turn"]
    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)

    batch_size, num_latent, tgt_seq_len = predictions.shape

    # shape: [batch_size, num_latent, seq_len, 1]
    preds_token = F.unsqueeze(predictions, [3])
    preds_mask = F.not_equal(preds_token, self.padding_idx, "int64")
    preds_pos = layers.range(0, tgt_seq_len, 1, dtype="float32")
    preds_pos = F.unsqueeze(preds_pos, [0, 0, 1])
    preds_pos = layers.expand(preds_pos, [batch_size, num_latent, 1, 1])
    preds_pos = layers.cast(preds_pos, "int64")
    preds_type = layers.zeros_like(preds_token)
    preds_turn = layers.zeros_like(preds_token)

    scores = []
    for i in range(num_latent):
        pred_token = preds_token[:, i]
        pred_mask = preds_mask[:, i]
        pred_pos = preds_pos[:, i]
        pred_type = preds_type[:, i]
        pred_turn = preds_turn[:, i]

        input_mask = layers.concat([src_mask, pred_mask], axis=1)
        input_mask.stop_gradient = True

        pred_embed = self.embedder(pred_token, pred_pos, pred_type, pred_turn)
        embed = layers.concat([src_embed, pred_embed], axis=1)
        embed = self.embed_layer_norm(embed)

        # Prepend the special mask embedding; its final hidden state feeds
        # the discriminator below.
        mask_embed = self.mask_embed
        mask_embed = layers.expand(mask_embed, [batch_size, 1, 1])
        mask_embed = self.embed_layer_norm(mask_embed)

        out = layers.concat([mask_embed, embed], axis=1)
        mask = self._create_mask(input_mask, append_head=True)

        for layer in self.layers:
            out = layer(out, mask, None)

        mask_embed = out[:, 0]
        score = self.discriminator(mask_embed)
        scores.append(score[:, 0])
    scores = layers.stack(scores, axis=1)
    return scores
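# --- Illustration (not part of the original source) -------------------------
# `_ranking` returns a [batch_size, num_latent] matrix with one discriminator
# score per candidate response. A minimal numpy sketch of how such scores
# would typically be consumed; the helper name is hypothetical.

import numpy as np

def pick_best_candidate(scores, candidates):
    """scores: [batch, num_latent]; candidates: [batch, num_latent, seq_len]."""
    best = scores.argmax(axis=1)                    # top-scoring candidate index
    return candidates[np.arange(len(best)), best]   # [batch, seq_len]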
def _collect_metrics(self, inputs, outputs):
    """ Calculate the loss function from inputs and outputs. """
    metrics = {}

    # Number of non-padding target tokens (the shifted label drops one token).
    tgt_len = layers.reduce_sum(
        layers.reduce_sum(inputs["tgt_mask"], dim=1) - 1)
    tgt_len.stop_gradient = True

    label = inputs["tgt_token"][:, 1:]
    if self.label_smooth > 0:
        one_hot_label = layers.one_hot(label, self.num_token_embeddings)
        smooth_label = layers.label_smooth(one_hot_label,
                                           epsilon=self.label_smooth,
                                           dtype=self._dtype)
        # Note: `_forward` only populates outputs["dec_probs"]; the original
        # excerpt read outputs["dec_pred"] here, which would raise a KeyError.
        nll = layers.cross_entropy(outputs["dec_probs"], smooth_label,
                                   soft_label=True,
                                   ignore_index=self.padding_idx)
    else:
        nll = layers.cross_entropy(outputs["dec_probs"], label,
                                   ignore_index=self.padding_idx)
    nll = layers.reduce_sum(nll, dim=1)
    token_nll = layers.reduce_sum(nll) / tgt_len
    nll = layers.reduce_mean(nll)
    metrics["nll"] = nll
    metrics["token_nll"] = token_nll
    loss = nll

    if self.num_latent > 0 and self.with_bow:
        # Bag-of-words loss: predict every target token from the latent state.
        bow_probs = F.unsqueeze(outputs["bow_probs"], [1])
        bow_probs = layers.expand(bow_probs, [1, label.shape[1], 1])
        if self.label_smooth > 0:
            bow = layers.cross_entropy(bow_probs, smooth_label,
                                       soft_label=True,
                                       ignore_index=self.padding_idx)
        else:
            bow = layers.cross_entropy(bow_probs, label,
                                       ignore_index=self.padding_idx)
        bow = layers.reduce_sum(bow, dim=1)
        token_bow = layers.reduce_sum(bow) / tgt_len
        bow = layers.reduce_mean(bow)
        metrics["bow"] = bow
        metrics["token_bow"] = token_bow
        loss = loss + bow

    if self.num_latent > 0 and self.use_discriminator:
        dis = 0.0 - (layers.log(outputs["pos_probs"]) +
                     layers.log(1.0 - outputs["neg_probs"]))
        dis = layers.reduce_mean(dis)
        metrics["dis"] = dis
        loss = loss + dis * self.dis_ratio

    metrics["loss"] = loss
    metrics["token_num"] = tgt_len
    return metrics
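# --- Illustration (not part of the original source) -------------------------
# A numpy sketch of the smoothed-label NLL above, following the usual
# definition behind `layers.label_smooth`: each one-hot label is relaxed to
# (1 - epsilon) * onehot + epsilon / num_classes before the cross entropy.
# Shapes and names here are assumptions for the sake of the example.

import numpy as np

def smoothed_nll(probs, labels, epsilon, pad_id):
    """probs: [T, V] predicted distributions; labels: [T] int token ids."""
    T, V = probs.shape
    onehot = np.eye(V)[labels]
    smooth = (1.0 - epsilon) * onehot + epsilon / V
    nll = -(smooth * np.log(probs + 1e-12)).sum(axis=-1)    # [T]
    nll = nll * (labels != pad_id)                          # ignore padding
    return nll.sum() / max((labels != pad_id).sum(), 1)     # token-level NLL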
def repeat(var, times):
    """ Recursively tile every tensor in `var` `times` times along a new
    axis 1, then fold that axis into the batch dimension. """
    if isinstance(var, list):
        return [repeat(x, times) for x in var]
    elif isinstance(var, dict):
        return {k: repeat(v, times) for k, v in var.items()}
    elif isinstance(var, Variable):
        var = F.unsqueeze(var, [1])
        expand_times = [1] * len(var.shape)
        expand_times[1] = times
        dtype = var.dtype
        # Cast around `expand` so integer tensors are handled as well.
        var = layers.cast(var, "float32")
        var = layers.expand(var, expand_times)
        shape = [var.shape[0] * var.shape[1]] + var.shape[2:]
        var = layers.reshape(var, shape)
        var = layers.cast(var, dtype)
        return var
    else:
        return var
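# --- Illustration (not part of the original source) -------------------------
# What `repeat` does, in numpy terms: row i of the input becomes rows
# [i*times, (i+1)*times) of the output, so all beams of one sample stay
# adjacent. This matches the `pos_index` arithmetic in the beam search below.

import numpy as np

x = np.array([[1, 2], [3, 4]])       # [batch=2, d=2]
tiled = np.repeat(x, 3, axis=0)      # [batch*times=6, d=2]
# -> [[1,2],[1,2],[1,2],[3,4],[3,4],[3,4]]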
def _attn(self, query, key, value, mask):
    # shape: [batch_size, num_head, seq_len, seq_len]
    scores = layers.matmul(x=query, y=key, alpha=self.scale)

    if mask is not None:
        mask = F.unsqueeze(mask, [1])
        mask = layers.expand(mask, [1, self.num_heads, 1, 1])
        mask.stop_gradient = True
        # mask == 1 marks a blocked connection: push its score to -1e10 so it
        # vanishes under the softmax.
        scores = (1 - mask) * scores + layers.scale(mask, scale=-1e10)

    attn = layers.softmax(scores, axis=-1)
    attn = F.dropout(attn, self.dropout)

    if mask is not None:
        # Zero out any numerical residue on blocked connections.
        attn = (1 - mask) * attn

    out = layers.matmul(x=attn, y=value)
    return out
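# --- Illustration (not part of the original source) -------------------------
# A numpy sketch of the masking trick in `_attn`: blocked positions get a
# large negative score before the softmax and thus (numerically) zero
# attention; the second multiplication removes any remaining residue.

import numpy as np

def masked_softmax(scores, mask):
    """scores, mask: [seq_len, seq_len]; mask == 1 marks blocked positions."""
    scores = (1 - mask) * scores + mask * -1e10
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = e / e.sum(axis=-1, keepdims=True)
    return (1 - mask) * attn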
def _generation_network(self, input_mask, embed, batch_size, src_len,
                        tgt_len, latent_embed):
    """ Basic generation network implementation. """
    if self.num_latent > 0:
        latent_embed = F.unsqueeze(latent_embed, [1])
        latent_embed = self.embed_layer_norm(latent_embed)
        dec_embed = layers.concat([latent_embed, embed], axis=1)
    else:
        dec_embed = embed

    # Create generation network mask: bidirectional (or autoregressive) over
    # the source, strictly autoregressive over the target.
    src_mask = input_mask[:, :src_len]
    tgt_mask = input_mask[:, src_len:]
    enc_mask = self._create_mask(
        src_mask,
        auto_regressive=not self.bidirectional_context,
        append_head=self.num_latent > 0)
    dec_mask = self._create_mask(tgt_mask, auto_regressive=True)
    mask = self._join_mask(enc_mask, dec_mask)

    for layer in self.layers:
        dec_embed = layer(dec_embed, mask, None)

    if self.num_latent > 0:
        latent_embed = dec_embed[:, 0]
    else:
        latent_embed = None

    dec_embed = dec_embed[:, -tgt_len:]
    if self.two_layer_predictor:
        dec_embed = self.pre_predictor(dec_embed)
    if self.weight_sharing:
        # Tie the output projection to the token embedding matrix.
        token_embedding = self.embedder.token_embedding._w
        dec_logits = layers.matmul(x=dec_embed,
                                   y=token_embedding,
                                   transpose_y=True)
    else:
        dec_logits = self.predictor(dec_embed)

    dec_probs = layers.softmax(dec_logits, axis=-1)
    return latent_embed, dec_probs
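# --- Illustration (not part of the original source) -------------------------
# A numpy sketch of the kind of joint mask `_join_mask` produces, assuming a
# bidirectional context and ignoring padding and the latent head slot: source
# positions see each other, target positions see all source positions plus
# earlier target positions, and no source position can peek at the target.

import numpy as np

def join_mask(src_len, tgt_len):
    n = src_len + tgt_len
    allowed = np.zeros((n, n), dtype="float32")
    allowed[:src_len, :src_len] = 1      # src <-> src, bidirectional
    allowed[src_len:, :src_len] = 1      # tgt -> src
    allowed[src_len:, src_len:] = np.tril(np.ones((tgt_len, tgt_len)))
    return 1 - allowed   # 1 marks a blocked connection, as `_attn` expects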
def __call__(self, step_fn, state):
    """
    Run beam search.

    @param : step_fn : decoding one step
    @type : function

    @param : state : initial state
    @type : dict
    """
    batch_size = state["batch_size"]
    beam_size = self.beam_size

    # shape: [batch_size, 1]
    pos_index = layers.range(0, batch_size, 1, dtype="int64")
    pos_index = layers.scale(pos_index, beam_size)
    pos_index = F.unsqueeze(pos_index, [1])

    # shape: [batch_size, beam_size, 1]
    predictions = layers.fill_constant(shape=[batch_size, beam_size, 1],
                                       dtype="int64",
                                       value=self.bos_id)

    # initial input
    state["pred_token"] = predictions[:, :1]
    # shape: [batch_size, vocab_size]
    scores, state = step_fn(state)

    unk_penalty = np.zeros(self.vocab_size, dtype="float32")
    unk_penalty[self.unk_id] = -1e10
    unk_penalty = layers.assign(unk_penalty)

    eos_penalty = np.zeros(self.vocab_size, dtype="float32")
    eos_penalty[self.eos_id] = -1e10
    eos_penalty = layers.assign(eos_penalty)

    # Once a hypothesis has ended, only [PAD] may follow it.
    scores_after_end = np.full(self.vocab_size, -1e10, dtype="float32")
    scores_after_end[self.pad_id] = 0
    scores_after_end = layers.assign(scores_after_end)

    if self.ignore_unk:
        scores = scores + unk_penalty
    scores = scores + eos_penalty

    # shape: [batch_size, beam_size]
    sequence_scores, preds = layers.topk(scores, self.beam_size)

    predictions = layers.concat([predictions, F.unsqueeze(preds, [2])],
                                axis=2)
    state = repeat(state, beam_size)

    for step in range(2, self.max_gen_len + 1):
        pre_ids = predictions[:, :, -1:]
        state["pred_token"] = layers.reshape(
            pre_ids, shape=[batch_size * beam_size, 1, 1])
        state["pred_mask"] = 1 - F.equal(state["pred_token"], self.pad_id)
        state["pred_pos"] = state["pred_pos"] + 1
        scores, state = step_fn(state)

        # Generate next
        # scores shape: [batch_size * beam_size, vocab_size]
        if self.ignore_unk:
            scores = scores + unk_penalty

        if step <= self.min_gen_len:
            scores = scores + eos_penalty

        scores = layers.reshape(
            scores, shape=[batch_size, beam_size, self.vocab_size])

        # previous token is [PAD] or [EOS]
        pre_eos_mask = F.equal(pre_ids, self.eos_id) + \
            F.equal(pre_ids, self.pad_id)

        scores = scores * (1 - pre_eos_mask) + \
            layers.expand(pre_eos_mask, [1, 1, self.vocab_size]) * scores_after_end

        if self.length_average:
            # Keep the running score an average over generated tokens.
            scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
            sequence_scores = F.unsqueeze(sequence_scores, [2]) * scaled_value
            scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
            scores = scores * scaled_value
        elif self.length_penalty >= 0.0:
            # GNMT-style length penalty, maintained incrementally.
            scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                (math.pow((4 + step) / (5 + step), self.length_penalty))
            sequence_scores = layers.elementwise_mul(scaled_value,
                                                     sequence_scores,
                                                     axis=0)
            scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \
                (math.pow(1 / (5 + step), self.length_penalty))
            scores = scores * scaled_value
        scores = layers.elementwise_add(scores, sequence_scores, axis=0)
        scores = layers.reshape(
            scores, shape=[batch_size, beam_size * self.vocab_size])

        topk_scores, topk_indices = layers.topk(scores, beam_size)
        # Decompose each flat index into (parent beam, token id).
        vocab_size = layers.fill_constant(shape=[1],
                                          dtype="int64",
                                          value=self.vocab_size)
        parent_idx = layers.elementwise_floordiv(topk_indices, vocab_size)
        preds = layers.elementwise_mod(topk_indices, vocab_size)

        # Gather state / sequence_scores
        parent_idx = layers.elementwise_add(parent_idx, pos_index, axis=0)
        parent_idx = layers.reshape(parent_idx, [batch_size * beam_size])
        state = gather(state, parent_idx)
        sequence_scores = topk_scores

        predictions = layers.reshape(predictions,
                                     shape=[batch_size * beam_size, step])
        predictions = gather(predictions, parent_idx)
        predictions = layers.reshape(predictions,
                                     shape=[batch_size, beam_size, step])
        predictions = layers.concat([predictions, F.unsqueeze(preds, [2])],
                                    axis=2)

    # Drop unfinished hypotheses and sort the finished ones by score.
    pre_ids = predictions[:, :, -1]
    pre_eos_mask = F.equal(pre_ids, self.eos_id) + \
        F.equal(pre_ids, self.pad_id)
    sequence_scores = sequence_scores * pre_eos_mask + \
        layers.scale(1 - pre_eos_mask, -1e10)

    # `argsort` is ascending, so index -1 below is the best hypothesis.
    _, indices = layers.argsort(sequence_scores, axis=1)
    indices = indices + pos_index
    indices = layers.reshape(indices, [-1])
    sequence_scores = layers.reshape(sequence_scores,
                                     [batch_size * beam_size])
    predictions = layers.reshape(predictions, [batch_size * beam_size, -1])
    sequence_scores = gather(sequence_scores, indices)
    predictions = layers.gather(predictions, indices)
    sequence_scores = layers.reshape(sequence_scores,
                                     [batch_size, beam_size])
    predictions = layers.reshape(predictions, [batch_size, beam_size, -1])

    results = {
        "preds": predictions[:, -1],
        "scores": sequence_scores[:, -1]
    }
    return results
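# --- Illustration (not part of the original source) -------------------------
# The core beam bookkeeping above, in plain numpy: scores for all beams are
# flattened to [beam_size * vocab_size], one top-k picks the jointly best
# continuations, and each flat index decomposes into (parent beam, token).

import numpy as np

beam_size, vocab_size = 2, 5
scores = np.array([[0.1, 0.9, 0.0, 0.3, 0.2],
                   [0.8, 0.1, 0.7, 0.0, 0.0]])   # [beam_size, vocab_size]
flat = scores.reshape(-1)                        # [beam_size * vocab_size]
topk = np.argsort(flat)[::-1][:beam_size]        # -> [1, 5]
parent_idx = topk // vocab_size                  # beam each token extends
token_id = topk % vocab_size
# -> parent_idx = [0, 1], token_id = [1, 0]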
def __call__(self, step_fn, state):
    """
    Run free-running (sampling-based) generation.

    @param : step_fn : decoding one step
    @type : function

    @param : state : initial state
    @type : dict
    """
    batch_size = state["batch_size"]
    vocab_size = self.vocab_size

    pos_index = layers.range(0, batch_size, 1, dtype="int64")
    pos_index = layers.scale(pos_index, vocab_size)

    # shape: [batch_size, 1]
    predictions = layers.fill_constant(shape=[batch_size, 1],
                                       dtype="int64",
                                       value=self.bos_id)
    sequence_scores = layers.fill_constant(shape=[batch_size],
                                           dtype="float32",
                                           value=0.0)

    unk_penalty = np.zeros(vocab_size, dtype="float32")
    unk_penalty[self.unk_id] = -1e10
    unk_penalty = layers.assign(unk_penalty)

    eos_penalty = np.zeros(vocab_size, dtype="float32")
    eos_penalty[self.eos_id] = -1e10
    eos_penalty = layers.assign(eos_penalty)

    # Once a sequence has ended, only [PAD] may follow it.
    scores_after_end = np.full(vocab_size, -1e10, dtype="float32")
    scores_after_end[self.pad_id] = 0
    scores_after_end = layers.assign(scores_after_end)

    # initial input
    for step in range(1, self.max_gen_len + 1):
        pre_ids = predictions[:, -1:]
        state["pred_token"] = F.unsqueeze(pre_ids, [2])
        if step > 1:
            state["pred_mask"] = 1 - F.equal(state["pred_token"],
                                             self.pad_id)
            state["pred_pos"] = state["pred_pos"] + 1
        scores, state = step_fn(state)

        # Generate next
        # scores shape: [batch_size, vocab_size]
        if self.ignore_unk:
            scores = scores + unk_penalty

        if step <= self.min_gen_len:
            scores = scores + eos_penalty

        # previous token is [PAD] or [EOS]
        # shape: [batch_size, 1]
        pre_eos_mask = F.equal(pre_ids, self.eos_id) + \
            F.equal(pre_ids, self.pad_id)
        scores = scores * (1 - pre_eos_mask) + \
            layers.expand(pre_eos_mask, [1, vocab_size]) * scores_after_end

        scores = scores / self.temperature
        preds = self._sampling(scores)

        predictions = layers.concat([predictions, F.unsqueeze(preds, [1])],
                                    axis=1)

        # Accumulate the score of each sampled token.
        scores = layers.reshape(scores, [batch_size * vocab_size])
        preds = preds + pos_index
        scores = gather(scores, preds)
        sequence_scores = sequence_scores + scores

    results = {"preds": predictions, "scores": sequence_scores}
    return results
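# --- Illustration (not part of the original source) -------------------------
# A numpy sketch of the temperature scaling applied above, assuming the
# scores are log-probabilities (as the additive -1e10 penalties suggest):
# temperature < 1 sharpens the sampling distribution, temperature > 1
# flattens it.

import numpy as np

def sample_with_temperature(log_probs, temperature, rng=np.random):
    """log_probs: [vocab_size] unnormalized log scores."""
    scaled = log_probs / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)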
def _init_state(self, inputs):
    """ Initialize decoding state. """
    state = {}

    src_token = inputs["src_token"]
    src_mask = inputs["src_mask"]
    src_pos = inputs["src_pos"]
    src_type = inputs["src_type"]
    src_turn = inputs["src_turn"]

    batch_size = src_token.shape[0]
    seq_len = src_token.shape[1]

    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
    src_embed = self.embed_layer_norm(src_embed)

    mask = self._create_mask(src_mask, append_head=self.num_latent > 0)

    if self.num_latent > 0:
        # Decode once per latent value: tile the source batch num_latent
        # times and prepend one latent embedding per copy.
        src_embed = F.unsqueeze(src_embed, [1])
        src_embed = layers.expand(src_embed, [1, self.num_latent, 1, 1])
        src_embed = layers.reshape(src_embed,
                                   [-1, seq_len, self.hidden_dim])

        latent_embed = self.latent_embeddings
        latent_embed = F.unsqueeze(latent_embed, [1])
        latent_embed = layers.expand(latent_embed, [batch_size, 1, 1])
        latent_embed = self.embed_layer_norm(latent_embed)

        enc_out = layers.concat([latent_embed, src_embed], axis=1)

        mask = F.unsqueeze(mask, [1])
        mask = layers.expand(mask, [1, self.num_latent, 1, 1])
        mask = layers.reshape(mask, [-1, seq_len + 1, seq_len + 1])
    else:
        enc_out = src_embed

    # Encode the source once, letting each layer fill its incremental cache.
    cache = {}
    for l, layer in enumerate(self.layers):
        cache[f"layer_{l}"] = {}
        enc_out = layer(enc_out, mask, cache[f"layer_{l}"])

    state["cache"] = cache
    # Keep a single query row of the source mask for incremental decoding.
    state["mask"] = mask[:, :1]
    if self.num_latent > 0:
        state["batch_size"] = batch_size * self.num_latent
        shape = [batch_size * self.num_latent, 1, 1]
    else:
        state["batch_size"] = batch_size
        shape = [batch_size, 1, 1]
    state["pred_mask"] = layers.ones(shape, self._dtype)
    state["pred_pos"] = layers.zeros(shape, "int64")
    state["pred_type"] = layers.zeros(shape, "int64")
    state["pred_turn"] = layers.zeros(shape, "int64")

    if "tgt_token" in inputs and self.num_latent > 0:
        tgt_token = inputs["tgt_token"][:, :-1]
        tgt_mask = inputs["tgt_mask"][:, :-1]
        tgt_pos = inputs["tgt_pos"][:, :-1]
        tgt_type = inputs["tgt_type"][:, :-1]
        tgt_turn = inputs["tgt_turn"][:, :-1]

        input_mask = layers.concat([src_mask, tgt_mask], axis=1)
        input_mask.stop_gradient = True

        src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
        tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
        embed = layers.concat([src_embed, tgt_embed], axis=1)
        embed = self.embed_layer_norm(embed)

        batch_size = src_token.shape[0]
        src_len = src_token.shape[1]
        tgt_len = tgt_token.shape[1]

        post_embed, post_probs, post_logits = self._posteriori_network(
            input_mask, embed, batch_size, src_len, tgt_len)
        state["post_probs"] = post_probs

    return state
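# --- Illustration (not part of the original source) -------------------------
# The incremental-decoding cache built above, in miniature: one dict per
# layer, to which each step appends the new token's key/value rows instead of
# re-encoding the whole prefix. The projections here are hypothetical
# stand-ins.

import numpy as np

def cached_step(x_t, layer_cache, w_k, w_v):
    """x_t: [d] new token state; layer_cache accumulates K/V rows."""
    layer_cache.setdefault("k", []).append(x_t @ w_k)
    layer_cache.setdefault("v", []).append(x_t @ w_v)
    # attention for this step runs against all cached rows:
    return np.stack(layer_cache["k"]), np.stack(layer_cache["v"])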
def _forward(self, inputs, is_training):
    """ Actual forward computation of the model in train/test mode. """
    outputs = {}

    src_token = inputs["src_token"]
    src_mask = inputs["src_mask"]
    src_pos = inputs["src_pos"]
    src_type = inputs["src_type"]
    src_turn = inputs["src_turn"]

    tgt_token = inputs["tgt_token"][:, :-1]
    tgt_mask = inputs["tgt_mask"][:, :-1]
    tgt_pos = inputs["tgt_pos"][:, :-1]
    tgt_type = inputs["tgt_type"][:, :-1]
    tgt_turn = inputs["tgt_turn"][:, :-1]

    input_mask = layers.concat([src_mask, tgt_mask], axis=1)
    input_mask.stop_gradient = True

    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)
    tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn)
    embed = layers.concat([src_embed, tgt_embed], axis=1)
    embed = self.embed_layer_norm(embed)

    batch_size = src_token.shape[0]
    src_len = src_token.shape[1]
    tgt_len = tgt_token.shape[1]

    if self.num_latent > 0:
        post_embed, post_probs, post_logits = self._posteriori_network(
            input_mask, embed, batch_size, src_len, tgt_len)
        outputs["post_logits"] = post_logits

        if self.use_discriminator:
            pos_probs, neg_probs = self._discriminator_network(
                input_mask, embed, batch_size, src_len, tgt_len,
                post_embed)
            outputs["pos_probs"] = pos_probs
            outputs["neg_probs"] = neg_probs

        if is_training:
            # Differentiable sample of the discrete latent during training.
            z = F.gumbel_softmax(post_logits, self.tau)
        else:
            # Hard assignment at inference time.
            indices = layers.argmax(post_logits, axis=1)
            z = layers.one_hot(F.unsqueeze(indices, [1]), self.num_latent)
        latent_embeddings = self.latent_embeddings
        latent_embed = layers.matmul(z, latent_embeddings)
        outputs["latent_embed"] = latent_embed
    else:
        latent_embed = None

    latent_embed, dec_probs = self._generation_network(
        input_mask, embed, batch_size, src_len, tgt_len, latent_embed)
    outputs["dec_probs"] = dec_probs

    if self.num_latent > 0 and self.with_bow:
        if self.two_layer_predictor:
            latent_embed = self.pre_bow_predictor(latent_embed)
        bow_logits = self.bow_predictor(latent_embed)
        bow_probs = layers.softmax(bow_logits)
        outputs["bow_probs"] = bow_probs

    return outputs
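# --- Illustration (not part of the original source) -------------------------
# A numpy sketch of the Gumbel-softmax relaxation used above for the discrete
# latent variable: add Gumbel(0, 1) noise to the logits and take a softmax
# with temperature tau. Low tau pushes the sample toward one-hot; at test
# time the code replaces this with a hard argmax.

import numpy as np

def gumbel_softmax(logits, tau, rng=np.random):
    """logits: [num_latent] unnormalized log-probabilities."""
    gumbel = -np.log(-np.log(rng.uniform(1e-10, 1.0, size=logits.shape)))
    y = (logits + gumbel) / tau
    e = np.exp(y - y.max())
    return e / e.sum()   # a differentiable, approximately one-hot sample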