class MultiLabelRnnNadeDecoder(BaseLstmDecoder):
    """LSTM decoder with multi-label output provided by a NADE.

    The output projection is sized `nade.num_hidden + output_depth`; each
    projected row is split into the NADE encoder bias (first `num_hidden`
    columns) and the NADE decoder bias (remaining `output_depth` columns).
    """

    def build(self, hparams, output_depth, is_training=False):
        """Creates the NADE and swaps in a NADE-sized output projection.

        Args:
          hparams: Hyperparameters. `nade_num_hidden` is read here; the object
            is also forwarded to the base class `build`.
          output_depth: Number of output dimensions modeled by the NADE.
          is_training: Forwarded unchanged to the base class `build`.
        """
        # The NADE is created before the base class builds itself.
        self._nade = Nade(
            output_depth, hparams.nade_num_hidden, name='decoder/nade')
        super(MultiLabelRnnNadeDecoder, self).build(
            hparams, output_depth, is_training)

        # Overwrite output layer for NADE parameterization: the projection
        # must emit both bias vectors, hence the combined width.
        projection_units = self._nade.num_hidden + output_depth
        self._output_layer = layers_core.Dense(
            projection_units, name='output_projection')

    def _split_nade_biases(self, rnn_output):
        """Splits projected RNN output along axis 1 into `(b_enc, b_dec)`."""
        # NOTE(review): `self._output_depth` is presumably set by the base
        # class `build` -- confirm.
        split_sizes = [self._nade.num_hidden, self._output_depth]
        return tf.split(rnn_output, split_sizes, axis=1)

    def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
        """Computes the NADE reconstruction loss and evaluation metrics.

        Args:
          flat_x_target: Target values scored by the NADE; also cast to bool
            to serve as ground truth for the metrics.
          flat_rnn_output: Projected RNN output holding the NADE biases.

        Returns:
          r_loss: The negated log-likelihood returned by the NADE.
          metric_map: Dict of `tf.metrics` results keyed by
            'metrics/accuracy', 'metrics/recall' and 'metrics/precision'.
          flat_truth: `flat_x_target` cast to bool.
          flat_predictions: Bool tensor, true where the NADE conditional
            probability is at least 0.5.
        """
        enc_bias, dec_bias = self._split_nade_biases(flat_rnn_output)
        log_likelihood, cond_probs = self._nade.log_prob(
            flat_x_target, b_enc=enc_bias, b_dec=dec_bias)
        r_loss = -log_likelihood

        flat_truth = tf.cast(flat_x_target, tf.bool)
        flat_predictions = tf.greater_equal(cond_probs, 0.5)

        # A row counts as accurate only if every label in it matches.
        rows_all_correct = tf.reduce_all(
            tf.equal(flat_truth, flat_predictions), axis=-1)
        metric_map = {}
        metric_map['metrics/accuracy'] = tf.metrics.mean(rows_all_correct)
        metric_map['metrics/recall'] = tf.metrics.recall(
            flat_truth, flat_predictions)
        metric_map['metrics/precision'] = tf.metrics.precision(
            flat_truth, flat_predictions)

        return r_loss, metric_map, flat_truth, flat_predictions

    def _sample(self, rnn_output, temperature=None):
        """Sample from NADE, returning the argmax if no temperature is provided."""
        enc_bias, dec_bias = self._split_nade_biases(rnn_output)
        # Only the sample itself is returned; the second value from the NADE
        # is discarded.
        drawn, _ = self._nade.sample(
            b_enc=enc_bias, b_dec=dec_bias, temperature=temperature)
        return drawn
class MultiLabelRnnNadeDecoder(BaseLstmDecoder):
    """LSTM decoder with multi-label output provided by a NADE.

    The output projection is sized `nade.num_hidden + output_depth`; each
    projected row is split into the NADE encoder bias (first `num_hidden`
    columns) and the NADE decoder bias (remaining `output_depth` columns).
    """

    def build(self, hparams, output_depth, is_training=False):
        """Creates the NADE and swaps in a NADE-sized output projection.

        Args:
          hparams: Hyperparameters. `nade_num_hidden` is read here; the object
            is also forwarded to the base class `build`.
          output_depth: Number of output dimensions modeled by the NADE.
          is_training: Forwarded unchanged to the base class `build`.
        """
        # The NADE is created before the base class builds itself.
        self._nade = Nade(
            output_depth, hparams.nade_num_hidden, name='decoder/nade')
        super(MultiLabelRnnNadeDecoder, self).build(
            hparams, output_depth, is_training)

        # Overwrite output layer for NADE parameterization: the projection
        # must emit both bias vectors, hence the combined width.
        projection_units = self._nade.num_hidden + output_depth
        self._output_layer = layers_core.Dense(
            projection_units, name='output_projection')

    def _split_nade_biases(self, rnn_output):
        """Splits projected RNN output along axis 1 into `(b_enc, b_dec)`."""
        # NOTE(review): `self._output_depth` is presumably set by the base
        # class `build` -- confirm.
        split_sizes = [self._nade.num_hidden, self._output_depth]
        return tf.split(rnn_output, split_sizes, axis=1)

    def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
        """Computes the NADE reconstruction loss and evaluation metrics.

        Args:
          flat_x_target: Target values scored by the NADE; also cast to bool
            to serve as ground truth for the metrics.
          flat_rnn_output: Projected RNN output holding the NADE biases.

        Returns:
          r_loss: The negated log-likelihood returned by the NADE.
          metric_map: Dict of `tf.metrics` results keyed by
            'metrics/accuracy', 'metrics/recall' and 'metrics/precision'.
          flat_truth: `flat_x_target` cast to bool.
          flat_predictions: Bool tensor, true where the NADE conditional
            probability is at least 0.5.
        """
        enc_bias, dec_bias = self._split_nade_biases(flat_rnn_output)
        log_likelihood, cond_probs = self._nade.log_prob(
            flat_x_target, b_enc=enc_bias, b_dec=dec_bias)
        r_loss = -log_likelihood

        flat_truth = tf.cast(flat_x_target, tf.bool)
        flat_predictions = tf.greater_equal(cond_probs, 0.5)

        # A row counts as accurate only if every label in it matches.
        rows_all_correct = tf.reduce_all(
            tf.equal(flat_truth, flat_predictions), axis=-1)
        metric_map = {}
        metric_map['metrics/accuracy'] = tf.metrics.mean(rows_all_correct)
        metric_map['metrics/recall'] = tf.metrics.recall(
            flat_truth, flat_predictions)
        metric_map['metrics/precision'] = tf.metrics.precision(
            flat_truth, flat_predictions)

        return r_loss, metric_map, flat_truth, flat_predictions

    def _sample(self, rnn_output, temperature=None):
        """Sample from NADE, returning the argmax if no temperature is provided."""
        enc_bias, dec_bias = self._split_nade_biases(rnn_output)
        # Only the sample itself is returned; the second value from the NADE
        # is discarded.
        drawn, _ = self._nade.sample(
            b_enc=enc_bias, b_dec=dec_bias, temperature=temperature)
        return drawn
class RnnNade(object):
    """RNN-NADE [2], a NADE parameterized by an RNN.

    The NADE's bias parameters are given by the output of the RNN.

    [2]: https://arxiv.org/abs/1206.6392

    Args:
      rnn_cell: The tf.contrib.rnn.RnnCell to use.
      num_dims: The number of binary dimensions for each observation.
      num_hidden: The number of hidden units in the NADE.
    """

    def __init__(self, rnn_cell, num_dims, num_hidden):
        self._num_dims = num_dims
        self._rnn_cell = rnn_cell
        # A single dense projection of the RNN output carries both NADE bias
        # vectors, so its width is the sum of their sizes.
        self._fc_layer = tf_layers_core.Dense(units=num_dims + num_hidden)
        self._nade = Nade(num_dims, num_hidden)

    def _get_rnn_zero_state(self, batch_size):
        """Return a tensor or tuple of tensors for an initial rnn state."""
        return self._rnn_cell.zero_state(batch_size, tf.float32)

    class SampleNadeLayer(tf_layers_base.Layer):
        """Layer that computes samples from a NADE."""

        def __init__(self, nade, name=None, **kwargs):
            super(RnnNade.SampleNadeLayer, self).__init__(name=name, **kwargs)
            self._nade = nade
            # Zero-row placeholder with `nade.num_dims` columns; not read
            # anywhere within this class.
            self._empty_result = tf.zeros([0, nade.num_dims])

        def call(self, inputs):
            """Splits `inputs` into NADE biases and returns the sample only."""
            split_sizes = [self._nade.num_hidden, self._nade.num_dims]
            enc_bias, dec_bias = tf.split(inputs, split_sizes, axis=1)
            return self._nade.sample(enc_bias, dec_bias)[0]

    def _get_state(self, inputs, lengths=None, initial_state=None):
        """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

        Args:
          inputs: A batch of sequences to compute the state from, sized
            `[batch_size, max(lengths), num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]`. When
            None, every sequence is given the size of dimension 1 of `inputs`.
          initial_state: An RnnNadeStateTuple, the initial state of the
            RNN-NADE, or None if the zero state should be used.

        Returns:
          final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
            Its bias tensors come from the RNN outputs after the time
            dimension has been flattened.
        """
        # NOTE(review): dimension 1 is read as the sequence length below, so a
        # rank-3 input appears to be required -- confirm.
        # The static batch dimension is used; presumably it must be known at
        # graph-construction time -- confirm.
        batch_size = inputs.shape[0].value

        if lengths is None:
            lengths = tf.tile(tf.shape(inputs)[1:2], [batch_size])

        if initial_state is None:
            initial_rnn_state = self._get_rnn_zero_state(batch_size)
        else:
            initial_rnn_state = initial_state.rnn_state

        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=inputs, sequence_length=lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self._rnn_cell,
            helper=helper,
            initial_state=initial_rnn_state,
            output_layer=self._fc_layer)

        # Only the first two decode results (outputs and final RNN state) are
        # used; anything after them is ignored.
        decode_results = tf.contrib.seq2seq.dynamic_decode(decoder)
        final_outputs, final_rnn_state = decode_results[0:2]

        # Flatten time dimension.
        flat_outputs = magenta.common.flatten_maybe_padded_sequences(
            final_outputs.rnn_output, lengths)

        split_sizes = [self._nade.num_hidden, self._nade.num_dims]
        enc_bias, dec_bias = tf.split(flat_outputs, split_sizes, axis=1)
        return RnnNadeStateTuple(enc_bias, dec_bias, final_rnn_state)

    def log_prob(self, sequences, lengths=None):
        """Computes the log probability of a sequence of values.

        Flattens the time dimension.

        Args:
          sequences: A batch of sequences to compute the log probabilities of,
            sized `[batch_size, max(lengths), num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]` or None if
            all are equal.

        Returns:
          log_prob: The log probability of each sequence value, sized
            `[sum(lengths), 1]`.
          cond_prob: The conditional probabilities at each non-padded value for
            every batch, sized `[sum(lengths), num_dims]`.
        """
        assert sequences.shape[2].value == self._num_dims

        # Shift the sequences one step to the right: drop the last value and
        # prepend a zero-padded step, so the value at step t is scored using
        # only the values before it.
        shifted = sequences[:, 0:-1, :]
        shifted = tf.pad(shifted, [[0, 0], [1, 0], [0, 0]])

        state = self._get_state(shifted, lengths=lengths)

        # Flatten time dimension.
        flat_labels = magenta.common.flatten_maybe_padded_sequences(
            sequences, lengths)

        return self._nade.log_prob(flat_labels, state.b_enc, state.b_dec)

    def steps(self, inputs, state):
        """Computes the new RNN-NADE state from a batch of inputs.

        Args:
          inputs: A batch of sequences to advance the state with, sized
            `[batch_size, length, num_dims]`.
          state: An RnnNadeStateTuple containing the RNN-NADE for each value,
            sized `([batch_size, self._nade.num_hidden],
            [batch_size, num_dims], [batch_size, self._rnn_cell.state_size])`.

        Returns:
          new_state: The updated RNN-NADE state tuple given the new inputs.
        """
        return self._get_state(inputs, initial_state=state)

    def sample_single(self, state):
        """Computes a sample and its probability from each of a batch of states.

        Args:
          state: An RnnNadeStateTuple containing the state of the RNN-NADE for
            each sample, sized `([batch_size, self._nade.num_hidden],
            [batch_size, num_dims], [batch_size, self._rnn_cell.state_size])`.

        Returns:
          sample: A sample for each input state, sized `[batch_size, num_dims]`.
          log_prob: The log probability of each sample, sized `[batch_size, 1]`.
        """
        drawn, drawn_log_prob = self._nade.sample(state.b_enc, state.b_dec)
        return drawn, drawn_log_prob

    def zero_state(self, batch_size):
        """Create an RnnNadeStateTuple of zeros.

        Args:
          batch_size: batch size.

        Returns:
          An RnnNadeStateTuple of zeros.
        """
        with tf.name_scope('RnnNadeZeroState', values=[batch_size]):
            rnn_zero_state = self._get_rnn_zero_state(batch_size)
            enc_bias_zeros = tf.zeros(
                (batch_size, self._nade.num_hidden), name='b_enc')
            dec_bias_zeros = tf.zeros(
                (batch_size, self._num_dims), name='b_dec')
            return RnnNadeStateTuple(
                enc_bias_zeros, dec_bias_zeros, rnn_zero_state)
class RnnNade(object):
    """RNN-NADE [2], a NADE parameterized by an RNN.

    The NADE's bias parameters are given by the output of the RNN.

    [2]: https://arxiv.org/abs/1206.6392

    Args:
      rnn_cell: The tf.contrib.rnn.RnnCell to use.
      num_dims: The number of binary dimensions for each observation.
      num_hidden: The number of hidden units in the NADE.
    """

    def __init__(self, rnn_cell, num_dims, num_hidden):
        self._num_dims = num_dims
        self._rnn_cell = rnn_cell
        # A single dense projection of the RNN output carries both NADE bias
        # vectors, so its width is the sum of their sizes.
        self._fc_layer = tf_layers_core.Dense(units=num_dims + num_hidden)
        self._nade = Nade(num_dims, num_hidden)

    def _get_rnn_zero_state(self, batch_size):
        """Return a tensor or tuple of tensors for an initial rnn state."""
        return self._rnn_cell.zero_state(batch_size, tf.float32)

    class SampleNadeLayer(tf_layers_base.Layer):
        """Layer that computes samples from a NADE."""

        def __init__(self, nade, name=None, **kwargs):
            super(RnnNade.SampleNadeLayer, self).__init__(name=name, **kwargs)
            self._nade = nade
            # Zero-row placeholder with `nade.num_dims` columns; not read
            # anywhere within this class.
            self._empty_result = tf.zeros([0, nade.num_dims])

        def call(self, inputs):
            """Splits `inputs` into NADE biases and returns the sample only."""
            split_sizes = [self._nade.num_hidden, self._nade.num_dims]
            enc_bias, dec_bias = tf.split(inputs, split_sizes, axis=1)
            return self._nade.sample(enc_bias, dec_bias)[0]

    def _get_state(self, inputs, lengths=None, initial_state=None):
        """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

        Args:
          inputs: A batch of sequences to compute the state from, sized
            `[batch_size, max(lengths), num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]`. When
            None, every sequence is given the size of dimension 1 of `inputs`.
          initial_state: An RnnNadeStateTuple, the initial state of the
            RNN-NADE, or None if the zero state should be used.

        Returns:
          final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
            Its bias tensors come from the RNN outputs after the time
            dimension has been flattened.
        """
        # NOTE(review): dimension 1 is read as the sequence length below, so a
        # rank-3 input appears to be required -- confirm.
        # The static batch dimension is used; presumably it must be known at
        # graph-construction time -- confirm.
        batch_size = inputs.shape[0].value

        if lengths is None:
            lengths = tf.tile(tf.shape(inputs)[1:2], [batch_size])

        if initial_state is None:
            initial_rnn_state = self._get_rnn_zero_state(batch_size)
        else:
            initial_rnn_state = initial_state.rnn_state

        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=inputs, sequence_length=lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self._rnn_cell,
            helper=helper,
            initial_state=initial_rnn_state,
            output_layer=self._fc_layer)

        # Only the first two decode results (outputs and final RNN state) are
        # used; anything after them is ignored.
        decode_results = tf.contrib.seq2seq.dynamic_decode(decoder)
        final_outputs, final_rnn_state = decode_results[0:2]

        # Flatten time dimension.
        flat_outputs = magenta.common.flatten_maybe_padded_sequences(
            final_outputs.rnn_output, lengths)

        split_sizes = [self._nade.num_hidden, self._nade.num_dims]
        enc_bias, dec_bias = tf.split(flat_outputs, split_sizes, axis=1)
        return RnnNadeStateTuple(enc_bias, dec_bias, final_rnn_state)

    def log_prob(self, sequences, lengths=None):
        """Computes the log probability of a sequence of values.

        Flattens the time dimension.

        Args:
          sequences: A batch of sequences to compute the log probabilities of,
            sized `[batch_size, max(lengths), num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]` or None if
            all are equal.

        Returns:
          log_prob: The log probability of each sequence value, sized
            `[sum(lengths), 1]`.
          cond_prob: The conditional probabilities at each non-padded value for
            every batch, sized `[sum(lengths), num_dims]`.
        """
        assert sequences.shape[2].value == self._num_dims

        # Shift the sequences one step to the right: drop the last value and
        # prepend a zero-padded step, so the value at step t is scored using
        # only the values before it.
        shifted = sequences[:, 0:-1, :]
        shifted = tf.pad(shifted, [[0, 0], [1, 0], [0, 0]])

        state = self._get_state(shifted, lengths=lengths)

        # Flatten time dimension.
        flat_labels = magenta.common.flatten_maybe_padded_sequences(
            sequences, lengths)

        return self._nade.log_prob(flat_labels, state.b_enc, state.b_dec)

    def steps(self, inputs, state):
        """Computes the new RNN-NADE state from a batch of inputs.

        Args:
          inputs: A batch of sequences to advance the state with, sized
            `[batch_size, length, num_dims]`.
          state: An RnnNadeStateTuple containing the RNN-NADE for each value,
            sized `([batch_size, self._nade.num_hidden],
            [batch_size, num_dims], [batch_size, self._rnn_cell.state_size])`.

        Returns:
          new_state: The updated RNN-NADE state tuple given the new inputs.
        """
        return self._get_state(inputs, initial_state=state)

    def sample_single(self, state):
        """Computes a sample and its probability from each of a batch of states.

        Args:
          state: An RnnNadeStateTuple containing the state of the RNN-NADE for
            each sample, sized `([batch_size, self._nade.num_hidden],
            [batch_size, num_dims], [batch_size, self._rnn_cell.state_size])`.

        Returns:
          sample: A sample for each input state, sized `[batch_size, num_dims]`.
          log_prob: The log probability of each sample, sized `[batch_size, 1]`.
        """
        drawn, drawn_log_prob = self._nade.sample(state.b_enc, state.b_dec)
        return drawn, drawn_log_prob

    def zero_state(self, batch_size):
        """Create an RnnNadeStateTuple of zeros.

        Args:
          batch_size: batch size.

        Returns:
          An RnnNadeStateTuple of zeros.
        """
        with tf.name_scope('RnnNadeZeroState', values=[batch_size]):
            rnn_zero_state = self._get_rnn_zero_state(batch_size)
            enc_bias_zeros = tf.zeros(
                (batch_size, self._nade.num_hidden), name='b_enc')
            dec_bias_zeros = tf.zeros(
                (batch_size, self._num_dims), name='b_dec')
            return RnnNadeStateTuple(
                enc_bias_zeros, dec_bias_zeros, rnn_zero_state)