def reconstruction_loss(self, x_input, x_target, x_length, z=None):
    """Reconstruction loss calculation.

    Args:
        x_input: Batch of decoder input sequences for teacher forcing, sized
            `[batch_size, max(x_length), output_depth]`.
        x_target: Batch of expected output sequences to compute loss against,
            sized `[batch_size, max(x_length), output_depth]`.
        x_length: Length of input/output sequences, sized `[batch_size]`.
        z: (Optional) Latent vectors. Required if model is conditional. Sized
            `[n, z_size]`.

    Returns:
        r_loss: The reconstruction loss for each sequence in the batch.
        metric_map: Map from metric name to tf.metrics return values for
            logging.
        truths: Ground truth labels.
        predictions: Predicted labels.
        final_state: The final states of the decoder, or None if using Cudnn.
    """
    batch_size = x_input.shape[0].value

    has_z = z is not None
    z = tf.zeros([batch_size, 0]) if z is None else z
    repeated_z = tf.tile(
        tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

    sampling_probability_static = tensor_util.constant_value(
        self._sampling_probability)
    if sampling_probability_static == 0.0:
        # Use teacher forcing.
        x_input = tf.concat([x_input, repeated_z], axis=2)
        helper = seq2seq.TrainingHelper(x_input, x_length)
    else:
        # Use scheduled sampling.
        helper = seq2seq.ScheduledOutputTrainingHelper(
            inputs=x_input,
            sequence_length=x_length,
            auxiliary_inputs=repeated_z if has_z else None,
            sampling_probability=self._sampling_probability,
            next_inputs_fn=self._sample)

    decoder_outputs, final_state = self._decode(
        z, helper=helper, x_input=x_input)
    flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
    flat_rnn_output = flatten_maybe_padded_sequences(
        decoder_outputs.rnn_output, x_length)
    r_loss, metric_map, truths, predictions = self._flat_reconstruction_loss(
        flat_x_target, flat_rnn_output)

    # Sum loss over sequences.
    cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
    r_losses = []
    for i in range(batch_size):
        b, e = cum_x_len[i], cum_x_len[i + 1]
        r_losses.append(tf.reduce_sum(r_loss[b:e]))
    r_loss = tf.stack(r_losses)

    return r_loss, metric_map, truths, predictions, final_state
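# For reference, below is a minimal sketch of what `_flat_reconstruction_loss`
# might look like for one-hot categorical targets. This is an assumption, not
# the project's actual implementation (which depends on the model's output
# distribution); the function name is hypothetical. It returns a per-timestep
# loss over the flattened, non-padded steps plus tf.metrics values and the
# truth/prediction labels consumed by `reconstruction_loss` above.
import tensorflow as tf


def _flat_reconstruction_loss_sketch(flat_x_target, flat_rnn_output):
    flat_logits = flat_rnn_output
    flat_truth = tf.argmax(flat_x_target, axis=1)
    flat_predictions = tf.argmax(flat_logits, axis=1)
    # Per-timestep cross-entropy; `reconstruction_loss` later sums it per
    # sequence using the cumulative-length slicing loop.
    r_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=flat_x_target, logits=flat_logits)
    metric_map = {
        'metrics/accuracy': tf.metrics.accuracy(flat_truth, flat_predictions),
    }
    return r_loss, metric_map, flat_truth, flat_predictions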
def _build_decoder(self, encoder_states, target_sequence, keep_prob,
                   sampling_prob, attention_mechanism):
    """Define decoder architecture.

    Builds an attention-wrapped, multi-layer decoder that is trained with
    scheduled sampling and unrolled with `seq2seq.dynamic_decode`.
    """
    # Connect each layer sequentially, building a graph that resembles a
    # feed-forward network made of recurrent units.
    decoder_cell = self._multi_cell(num_units=self.num_units,
                                    num_layers=self.num_layers,
                                    keep_prob=keep_prob)

    # Connect attention to the decoder cell.
    attention_layer_size = self.num_units
    attention_cell = seq2seq.AttentionWrapper(
        cell=decoder_cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=attention_layer_size)

    # Decoder start symbol: prepend an all-zero frame and drop the last
    # target step, i.e. shift the targets right by one.
    decoder_raw_seq = target_sequence[:, :-1]
    prefix = tf.fill([tf.shape(target_sequence)[0], 1, self.target_depth], 0.0)
    decoder_input_seq = tf.concat([prefix, decoder_raw_seq], axis=1)

    # The model uses fixed-length target sequences, so tile the defined
    # length across the batch dimension.
    decoder_sequence_length = tf.tile([self.target_length],
                                      [tf.shape(target_sequence)[0]])

    # The scheduled-sampling helper feeds the decoder's own output to the
    # next time step, instead of the ground-truth target values, with
    # probability `sampling_prob` during training.
    helper = seq2seq.ScheduledOutputTrainingHelper(
        inputs=decoder_input_seq,
        sequence_length=decoder_sequence_length,
        sampling_probability=sampling_prob)

    # Output projection layer.
    projection_layer = Dense(units=self.target_depth, use_bias=True)

    # Initialize the attention wrapper state from the final encoder state.
    initial_state = attention_cell.zero_state(
        tf.shape(target_sequence)[0], tf.float32)
    initial_state = initial_state.clone(cell_state=encoder_states)

    # Wrap the cell, helper, and initial state in a BasicDecoder.
    decoder = seq2seq.BasicDecoder(cell=attention_cell,
                                   helper=helper,
                                   initial_state=initial_state,
                                   output_layer=projection_layer)

    # Build the unrolled graph of the recurrent neural network.
    outputs, decoder_state, _sequence_lengths = seq2seq.dynamic_decode(
        decoder=decoder, maximum_iterations=self.target_length)

    return (outputs, decoder_state)
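# A minimal sketch (an assumption, not the project's actual code) of how the
# `attention_mechanism` argument and the `_multi_cell` helper used above might
# be constructed with the TF 1.x contrib APIs: Luong attention over the
# encoder outputs and a dropout-wrapped stack of LSTM cells. The function
# names and parameters here are hypothetical.
import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq


def build_attention_mechanism(num_units, encoder_outputs, source_lengths):
    # Attends over the encoder outputs, masking padded source steps.
    return seq2seq.LuongAttention(
        num_units=num_units,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths)


def multi_cell(num_units, num_layers, keep_prob):
    # Stack of LSTM cells with output dropout applied to each layer.
    cells = [
        rnn.DropoutWrapper(rnn.LSTMCell(num_units),
                           output_keep_prob=keep_prob)
        for _ in range(num_layers)
    ]
    return rnn.MultiRNNCell(cells)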
def build_predict_op(self):
    with tf.variable_scope('predict'):
        decoder_cell = self.decoder_cell
        targets = self.targets
        sequence_lengths = self.training_seq_lens
        # With sampling_probability=1.0 the helper always feeds the model's
        # own output back in, so decoding is fully autoregressive.
        # `next_inputs_fn` applies the output projection to the sampled RNN
        # output before it is reused as the next input.
        predict_helper = seq2seq.ScheduledOutputTrainingHelper(
            targets,
            sequence_lengths,
            sampling_probability=1.0,
            next_inputs_fn=self.projection_layer)
        decoder = seq2seq.BasicDecoder(decoder_cell, predict_helper,
                                       self.get_zero_state())
        output, _, _ = seq2seq.dynamic_decode(decoder, output_time_major=True)
        self.predictions = self.projection_layer(output.rnn_output)
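# A minimal sketch (hypothetical class and shapes, assumed rather than taken
# from the original model) of the attributes the predict and train ops rely
# on: a multi-layer RNN cell, a Dense projection back to the target depth,
# the target/length placeholders, and a zero initial state sized to the batch.
import tensorflow as tf
from tensorflow.contrib import rnn


class _DecoderAttrsSketch(object):

    def __init__(self, num_units, num_layers, target_depth):
        self.decoder_cell = rnn.MultiRNNCell(
            [rnn.LSTMCell(num_units) for _ in range(num_layers)])
        self.projection_layer = tf.layers.Dense(target_depth, use_bias=True)
        self.targets = tf.placeholder(
            tf.float32, [None, None, target_depth], name='targets')
        self.training_seq_lens = tf.placeholder(
            tf.int32, [None], name='seq_lens')

    def get_zero_state(self):
        # Zero state sized to the current batch of targets.
        batch_size = tf.shape(self.targets)[0]
        return self.decoder_cell.zero_state(batch_size, tf.float32)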
def build_train_op(self):
    with tf.variable_scope('training'):
        decoder_cell = self.decoder_cell
        targets = self.targets
        sequence_lengths = self.training_seq_lens
        # Scheduled sampling: with probability `sampling_rate`, feed the
        # projected previous output back in instead of the ground truth.
        # `next_inputs_fn` applies the output projection to the sampled RNN
        # output before it is reused as the next input.
        training_helper = seq2seq.ScheduledOutputTrainingHelper(
            targets,
            sequence_lengths,
            sampling_probability=self.sampling_rate,
            next_inputs_fn=self.projection_layer)
        decoder = seq2seq.BasicDecoder(decoder_cell,
                                       helper=training_helper,
                                       initial_state=self.get_zero_state())
        output, _, _ = seq2seq.dynamic_decode(decoder, output_time_major=True)
        predictions = self.projection_layer.apply(output.rnn_output)
        # `dynamic_decode` returned time-major outputs, so transpose the
        # targets to match before computing the loss.
        targets_time_major = tf.transpose(targets, perm=[1, 0, 2])
        x_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=targets_time_major, logits=predictions, name='CrossEntropy')
        loss = tf.reduce_mean(x_entropy)
        self.train_op = tf.contrib.layers.optimize_loss(
            loss, tf.contrib.framework.get_global_step(), 0.001, 'Adam')
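# A possible schedule for `sampling_rate` (an assumption; the model exposes
# the attribute without defining it here, and the function name below is
# hypothetical). It anneals the scheduled-sampling probability from 0 (pure
# teacher forcing) toward 1 (feeding back model outputs), in the spirit of
# the scheduled-sampling curriculum of Bengio et al. (2015).
import tensorflow as tf


def linear_sampling_rate(global_step, total_steps):
    # Linearly ramps the probability of sampling the model's own output.
    progress = tf.cast(global_step, tf.float32) / float(total_steps)
    return tf.minimum(1.0, progress)


# Example wiring: self.sampling_rate = linear_sampling_rate(
#     tf.train.get_or_create_global_step(), total_steps=50000)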