def testSetFinal(self):
  """Tests lstm_utils.set_final for batch-major and time-major layouts.

  Builds a (4, 5, 2) float32 batch with lengths [0, 1, 2, 5], then checks
  that `set_final` zeroes all steps at or past each sequence's length and
  writes `final_values` at the last valid step (clamped to index 0 for an
  empty sequence).
  """
  with self.test_session():
    sequences = np.arange(40, dtype=np.float32).reshape(4, 5, 2)
    lengths = np.array([0, 1, 2, 5])
    final_values = np.arange(40, 48, dtype=np.float32).reshape(4, 2)
    expected_result = sequences.copy()
    for i, l in enumerate(lengths):
      # Zero everything from the sequence length onward, then overwrite the
      # final valid step (index max(0, l - 1)) with the final value.
      expected_result[i, l:] = 0.0
      expected_result[i, max(0, l - 1)] = final_values[i]
    self.assertAllEqual(
        expected_result,
        lstm_utils.set_final(
            sequences, lengths, final_values, time_major=False).eval())
    # Same check with the time axis leading; result is the transpose.
    self.assertAllEqual(
        np.transpose(expected_result, [1, 0, 2]),
        lstm_utils.set_final(
            np.transpose(sequences, [1, 0, 2]), lengths, final_values,
            time_major=True).eval())
def testSetFinal(self):
  """Checks that set_final overwrites the last valid step of each sequence.

  NOTE(review): this appears to be a duplicate of another `testSetFinal`
  definition in the same file — if both live in one class, the later one
  shadows the earlier. Confirm and deduplicate.
  """
  with self.test_session():
    sequences = np.arange(40, dtype=np.float32).reshape(4, 5, 2)
    lengths = np.array([0, 1, 2, 5])
    final_values = np.arange(40, 48, dtype=np.float32).reshape(4, 2)
    # Expected: steps at/after each length become zero, then the final
    # valid step (clamped to 0 for empty sequences) takes the final value.
    expected_result = sequences.copy()
    for batch, seq_len in enumerate(lengths):
      expected_result[batch, seq_len:] = 0.0
    expected_result[range(4), np.maximum(0, lengths - 1)] = final_values
    actual_batch_major = lstm_utils.set_final(
        sequences, lengths, final_values, time_major=False).eval()
    self.assertAllEqual(expected_result, actual_batch_major)
    actual_time_major = lstm_utils.set_final(
        np.transpose(sequences, [1, 0, 2]), lengths, final_values,
        time_major=True).eval()
    self.assertAllEqual(
        np.transpose(expected_result, [1, 0, 2]), actual_time_major)
def base_train_fn(embedding, hier_index):
  """Base function for training hierarchical decoder.

  Computes the core decoder's reconstruction loss for one hierarchy segment
  and returns a conditioning vector for the next autoregressive step.

  NOTE(review): this is a closure — `hier_input`, `hier_target`,
  `hier_length`, `hier_control`, `loss_outputs`, `batch_size`, and `self`
  come from the enclosing scope (not visible here); confirm their shapes
  against the enclosing function.

  Args:
    embedding: Embedding passed through to the core decoder's
      `reconstruction_loss`.
    hier_index: Index selecting this segment's input/target/length/control
      from the enclosing `hier_*` collections.

  Returns:
    A re-encoded sample from the level-0 hierarchical encoder if one is
    configured; `None` if autoregression is disabled; otherwise the core
    decoder's flattened final RNN state.
  """
  split_size = self._level_lengths[-1]
  split_input = hier_input[hier_index]
  split_target = hier_target[hier_index]
  split_length = hier_length[hier_index]
  # Control signals are optional; index only when present.
  split_control = (
      hier_control[hier_index] if hier_control is not None else None)

  res = self._core_decoder.reconstruction_loss(
      split_input, split_target, split_length, embedding, split_control)
  # Side effect: accumulate this segment's loss tuple for the caller.
  loss_outputs.append(res)
  # By convention here, the last element of the loss tuple is the decode
  # results namedtuple/struct.
  decode_results = res[-1]

  if self._hierarchical_encoder:
    # Get the approximate "sample" from the model.
    # Start with the inputs the RNN saw (excluding the start token).
    samples = decode_results.rnn_input[:, 1:]
    # Pad to be the max length.
    samples = tf.pad(
        samples,
        [(0, 0), (0, split_size - tf.shape(samples)[1]), (0, 0)])
    samples.set_shape([batch_size, split_size, self._output_depth])
    # Set the final value based on the target, since the scheduled sampling
    # helper does not sample the final value.
    samples = lstm_utils.set_final(
        samples,
        split_length,
        lstm_utils.get_final(split_target, split_length, time_major=False),
        time_major=False)
    # Return the re-encoded sample.
    return self._hierarchical_encoder.level(0).encode(
        sequence=samples, sequence_length=split_length)
  elif self._disable_autoregression:
    # No conditioning signal when autoregression is turned off.
    return None
  else:
    # Condition on the core decoder's final state directly.
    return tf.concat(tf.nest.flatten(decode_results.final_state), axis=-1)