def calc_attention(self, state: dy.Expression) -> dy.Expression:
  V = dy.parameter(self.pV)
  U = dy.parameter(self.pU)

  WI = self.WI
  curr_sent_mask = self.curr_sent.mask
  if self.truncate_dec_batches:
    if curr_sent_mask:
      state, WI, curr_sent_mask = batchers.truncate_batches(state, WI, curr_sent_mask)
    else:
      state, WI = batchers.truncate_batches(state, WI)
  # MLP attention: score each source position, then normalize with a softmax.
  h = dy.tanh(dy.colwise_add(WI, V * state))
  scores = dy.transpose(U * h)
  if curr_sent_mask is not None:
    # Push masked (padded) positions toward zero attention weight.
    scores = curr_sent_mask.add_to_tensor_expr(scores, multiplicator=-100.0)
  normalized = dy.softmax(scores)
  self.attention_vecs.append(normalized)
  return normalized
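# For reference, a minimal NumPy sketch of the same MLP attention computation outside of
# DyNet. The shapes and parameter names (enc_states, W, V, U, mask) are illustrative
# assumptions, not the toolkit's API: columns of enc_states are encoder outputs, and
# masked positions receive a large negative score so their softmax weight is near zero.
import numpy as np

def mlp_attention(enc_states, dec_state, W, V, U, mask=None):
  # enc_states: (input_dim, src_len), dec_state: (state_dim,)
  # W: (hidden_dim, input_dim), V: (hidden_dim, state_dim), U: (1, hidden_dim)
  WI = W @ enc_states                          # precomputable once per sentence (self.WI above)
  h = np.tanh(WI + (V @ dec_state)[:, None])   # column-wise add of the decoder term
  scores = (U @ h).ravel()                     # one score per source position
  if mask is not None:
    scores = scores + mask * -100.0            # suppress padded positions
  e = np.exp(scores - scores.max())
  return e / e.sum()                           # softmax over source positions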
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
  event_trigger.start_sent(src)
  if isinstance(src, batchers.CompoundBatch):
    src = src.batches[0]

  # Encode the source sentence and initialize the decoder.
  initial_state = self._encode_src(src)
  dec_state = initial_state
  trg_mask = trg.mask if batchers.is_batched(trg) else None
  cur_losses = []
  seq_len = trg.sent_len()

  if settings.CHECK_VALIDITY and batchers.is_batched(src):
    for j, single_trg in enumerate(trg):
      assert single_trg.sent_len() == seq_len  # assert consistent length
      assert 1 == len([i for i in range(seq_len)
                       if (trg_mask is None or trg_mask.np_arr[j, i] == 0)
                       and single_trg[i] == vocabs.Vocab.ES])  # assert exactly one unmasked ES token

  # Teacher-forced decoding: feed the reference word from the previous step and
  # accumulate the per-timestep loss.
  input_word = None
  for i in range(seq_len):
    ref_word = DefaultTranslator._select_ref_words(trg, i, truncate_masked=self.truncate_dec_batches)
    if self.truncate_dec_batches and batchers.is_batched(ref_word):
      dec_state.rnn_state, ref_word = batchers.truncate_batches(dec_state.rnn_state, ref_word)

    if input_word is not None:
      dec_state = self.decoder.add_input(dec_state, self.trg_embedder.embed(input_word))
    rnn_output = dec_state.rnn_state.output()
    dec_state.context = self.attender.calc_context(rnn_output)
    word_loss = self.decoder.calc_loss(dec_state, ref_word)

    if not self.truncate_dec_batches and batchers.is_batched(src) and trg_mask is not None:
      # Zero out losses at padded target positions.
      word_loss = trg_mask.cmult_by_timestep_expr(word_loss, i, inverse=True)
    cur_losses.append(word_loss)
    input_word = ref_word

  if self.truncate_dec_batches:
    loss_expr = dy.esum([dy.sum_batches(wl) for wl in cur_losses])
  else:
    loss_expr = dy.esum(cur_losses)
  return loss_expr
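# A small NumPy sketch of the per-timestep masking used above (the inverse-mask multiply),
# under the assumption that the target mask is 1.0 at padded positions and 0.0 at real
# tokens. Names such as masked_nll are illustrative, not part of the toolkit.
import numpy as np

def masked_nll(per_step_losses, trg_mask):
  # per_step_losses: (batch, seq_len) word-level losses; trg_mask: (batch, seq_len)
  keep = 1.0 - trg_mask                        # inverse mask: 1.0 on real tokens
  return (per_step_losses * keep).sum(axis=1)  # one summed loss per sentence

losses = np.array([[2.1, 1.3, 0.4],
                   [1.7, 0.9, 0.0]])
mask = np.array([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0]])             # second sentence is one token shorter
print(masked_nll(losses, mask))                # -> [3.8 2.6]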
def add_input(self, mlp_dec_state: AutoRegressiveDecoderState, trg_embedding: dy.Expression) -> AutoRegressiveDecoderState:
  """Add an input and update the state.

  Args:
    mlp_dec_state: An object containing the current state.
    trg_embedding: The embedding of the word to input.
  Returns:
    The updated decoder state.
  """
  inp = trg_embedding
  if self.input_feeding:
    inp = dy.concatenate([inp, mlp_dec_state.context])
  rnn_state = mlp_dec_state.rnn_state
  if self.truncate_dec_batches:
    rnn_state, inp = batchers.truncate_batches(rnn_state, inp)
  return AutoRegressiveDecoderState(rnn_state=rnn_state.add_input(inp), context=mlp_dec_state.context)
def add_input(self, dec_state: AutoRegressiveDecoderState, trg_word: Any) -> AutoRegressiveDecoderState:
  """Add an input and return a *new* updated state.

  Args:
    dec_state: An object containing the current state.
    trg_word: The word to input.
  Returns:
    The updated decoder state.
  """
  trg_embedding = self.embedder.embed(trg_word)
  inp = trg_embedding
  if self.input_feeding:
    inp = dy.concatenate([inp, dec_state.context])
  rnn_state = dec_state.rnn_state
  if self.truncate_dec_batches:
    rnn_state, inp = batchers.truncate_batches(rnn_state, inp)
  return AutoRegressiveDecoderState(rnn_state=rnn_state.add_input(inp), context=dec_state.context)
def calc_context(self, state: dy.Expression) -> dy.Expression:
  attention = self.calc_attention(state)
  I = self.curr_sent.as_tensor()
  if self.truncate_dec_batches:
    I, attention = batchers.truncate_batches(I, attention)
  # Context vector: attention-weighted sum of the encoder states.
  return I * attention
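# A toy NumPy example of the final step: the context vector is the attention-weighted sum
# of the encoder states (the columns of I above). Sizes here are made up for illustration.
import numpy as np

enc_states = np.random.randn(4, 3)       # 3 source positions, hidden size 4
attention = np.array([0.2, 0.7, 0.1])    # normalized weights from calc_attention
context = enc_states @ attention         # (4,) weighted sum of encoder state columns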