def decode_step(self, vars):
    # Attend over the encoder states to get a context vector for this step
    context_vector = self.attention.compute_context_vector(
        vars.state, vars.encoder_states,
        precomputed_values=vars.precomputed_values,
        mask=vars.input_mask)
    # Feed the context vector together with the embedded feedback token
    decoder_input = T.concat([context_vector, vars.feedback])
    # Advance the decoder LSTM by one step
    new_state, new_cell = self.decoder_rnn.compute_step(
        vars.state, lstm_cell=vars.cell, input=decoder_input)
    vars.state = new_state
    vars.cell = new_cell
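# A minimal NumPy sketch of what `attention.compute_context_vector` is
# assumed to do here (the explicit version appears in the second
# `decode_step` listing below): score each encoder state against the
# decoder state, normalize with a softmax, and take the weighted sum
# over time. The dot-product scoring function is an assumption; only
# the weighted-sum step is confirmed by the code in this listing.
import numpy as np

def context_vector_sketch(state, encoder_states, mask=None):
    # state: (batch, hidden), encoder_states: (batch, time, hidden)
    scores = np.einsum("bh,bth->bt", state, encoder_states)
    if mask is not None:
        # Masked source positions get a large negative score
        scores = np.where(mask > 0, scores, -1e9)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    # Weighted sum over time, matching
    # T.sum(align_weights[:, :, None] * encoder_states, axis=1)
    return np.sum(weights[:, :, None] * encoder_states, axis=1)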
def decode(self, encoder_outputs, target_vars, input_mask=None,
           sampling=False, extra_outputs=None):
    """Build the decoding graph."""
    encoder_states = encoder_outputs.encoder_states
    batch_size = encoder_states.shape[0]
    # Shift the target sequence right and prepend the start token (id 1)
    feedbacks = T.concat(
        [T.ones((batch_size, 1), dtype="int32"), target_vars[:, :-1]],
        axis=1)
    feedback_embeds = self.lookup_feedback(feedbacks)
    # Initialize decoder states, taking "init_*" entries from the encoder
    # outputs when available and zeros otherwise
    decoder_outputs = {"t": T.constant(0, dtype="int32")}
    for state_name, size in zip(self._decoder_states,
                                self._decoder_state_sizes):
        if "init_{}".format(state_name) in encoder_outputs:
            decoder_outputs[state_name] = encoder_outputs[
                "init_{}".format(state_name)]
        else:
            decoder_outputs[state_name] = T.zeros((batch_size, size))
    if extra_outputs:
        decoder_outputs.update(extra_outputs)
    # Collect non-sequences: everything from the encoder except init states
    non_sequences = {"input_mask": input_mask}
    for k, val in encoder_outputs.items():
        if not k.startswith("init_"):
            non_sequences[k] = val
    # Scan over time: sequences are time-major, so swap batch and time axes
    loop = D.graph.loop(
        sequences={
            "feedback_token": feedbacks.dimshuffle((1, 0)),
            "feedback": feedback_embeds.dimshuffle((1, 0, 2))
        },
        outputs=decoder_outputs,
        non_sequences=non_sequences)
    with loop as vars:
        if sampling:
            self.sample_step(vars)
        else:
            self.decode_step(vars)
        vars.t += 1
    # Transpose the collected outputs back to batch-major layout
    output_map = MapDict()
    for state_name in decoder_outputs:
        if loop.outputs[state_name].ndim == 2:
            output_map[state_name] = loop.outputs[state_name].dimshuffle(
                (1, 0))
        elif loop.outputs[state_name].ndim == 3:
            output_map[state_name] = loop.outputs[state_name].dimshuffle(
                (1, 0, 2))
        else:
            output_map[state_name] = loop.outputs[state_name]
    if loop.updates:
        self._decoder_updates.extend(loop.updates)
    return output_map
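# Hedged usage sketch: how `encode` and `decode` are presumably wired
# together when building the training graph, e.g. inside a hypothetical
# `compute_cost` method. `src_vars`, `tgt_vars`, and `src_mask` are
# hypothetical symbolic variables; wrapping the encoder's return dict in
# MapDict is an assumption, made because `decode` reads
# `encoder_outputs.encoder_states` by attribute.
encoder_outputs = MapDict(self.encode(src_vars, input_mask=src_mask))
decoder_outputs = self.decode(encoder_outputs, tgt_vars,
                              input_mask=src_mask)
# decoder_outputs.state then holds the batch-major decoder states, one
# per target position, ready for the output projection and softmax.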
def decode_step(self, vars):
    # Compute attention weights over the source positions
    align_weights = self.attention.compute_alignments(
        vars.state, vars.precomputed_values, vars.input_mask)
    # Context vector: weighted sum of encoder states under the alignments
    context_vector = T.sum(
        align_weights[:, :, None] * vars.encoder_states, axis=1)
    decoder_input = T.concat([context_vector, vars.feedback])
    new_state, new_cell = self.decoder_rnn.compute_step(
        vars.state, lstm_cell=vars.cell, input=decoder_input)
    vars.state = new_state
    vars.cell = new_cell
    if self.test_exporting:
        # Record attention weights
        vars.state += D.debug.record(align_weights, "att")
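# Quick NumPy check of the context-vector arithmetic above: the
# broadcasted form align_weights[:, :, None] * encoder_states summed
# over axis 1 equals a batched vector-matrix product. Shapes here are
# illustrative only.
import numpy as np

batch, time, hidden = 2, 5, 4
align_weights = np.random.rand(batch, time)
align_weights /= align_weights.sum(axis=1, keepdims=True)
encoder_states = np.random.rand(batch, time, hidden)

broadcast_sum = np.sum(align_weights[:, :, None] * encoder_states, axis=1)
batched_matmul = np.einsum("bt,bth->bh", align_weights, encoder_states)
assert np.allclose(broadcast_sum, batched_matmul)  # (batch, hidden)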
def encode(self, input_vars, input_mask=None):
    input_embeds = self.src_embed_layer.compute(input_vars,
                                                mask=input_mask)
    # Bidirectional encoder: reverse the backward pass on the time axis
    # so both directions are aligned per source position
    forward_rnn_var = self.forward_encoder.compute(input_embeds,
                                                   mask=input_mask)
    backward_rnn_var = T.reverse(
        self.backward_encoder.compute(input_embeds, mask=input_mask,
                                      backward=True),
        axis=1)
    encoder_states = T.concat([forward_rnn_var, backward_rnn_var], axis=2)
    # Precompute the attention keys once for the whole source sequence
    precomputed_att_values = self.attention.precompute(encoder_states)
    return {
        "encoder_states": encoder_states,
        # Initialize the decoder from the backward state at the first token
        "init_state": self.first_state_nn.compute(backward_rnn_var[:, 0]),
        "precomputed_values": precomputed_att_values
    }
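# Shape sketch of the bidirectional encoder above, assuming batch-major
# (batch, time, hidden) tensors and that `backward=True` makes the layer
# process the sequence right to left, returning outputs in processing
# order. After T.reverse, backward_rnn_var[:, 0] is the backward state
# at the first source token, i.e. the state that has read the whole
# sentence, which is why it is used to initialize the decoder
# (Bahdanau-style). The random data here only demonstrates shapes.
import numpy as np

batch, time, hidden = 2, 5, 3
forward = np.random.rand(batch, time, hidden)   # reads left to right
backward = np.random.rand(batch, time, hidden)  # reads right to left
backward_aligned = backward[:, ::-1]            # T.reverse(..., axis=1)
encoder_states = np.concatenate([forward, backward_aligned], axis=2)
assert encoder_states.shape == (batch, time, 2 * hidden)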