Exemplo n.º 1
0
class MultiLabelRnnNadeDecoder(BaseLstmDecoder):
    """LSTM decoder whose multi-label output distribution is a NADE.

    Each projected RNN output is split into the NADE's encoder bias (`b_enc`)
    and decoder bias (`b_dec`), so the RNN supplies the NADE biases per row.
    """

    def build(self, hparams, output_depth, is_training=False):
        """Creates the NADE, builds the base decoder, then swaps the projection.

        Args:
          hparams: Hyperparameters. `hparams.nade_num_hidden` sets the NADE
              hidden size; the object is also forwarded to the base `build`.
          output_depth: Number of output label dimensions.
          is_training: Forwarded unchanged to the base `build`.
        """
        # The NADE must exist before the projection below reads its size.
        self._nade = Nade(output_depth, hparams.nade_num_hidden,
                          name='decoder/nade')
        super(MultiLabelRnnNadeDecoder, self).build(
            hparams, output_depth, is_training)
        # Overwrite the output layer so every projected row carries both NADE
        # biases: `num_hidden` values for b_enc, `output_depth` for b_dec.
        projection_units = self._nade.num_hidden + output_depth
        self._output_layer = layers_core.Dense(
            projection_units, name='output_projection')

    def _split_nade_biases(self, rnn_output):
        """Splits projected RNN output columns into NADE `(b_enc, b_dec)`."""
        return tf.split(
            rnn_output, [self._nade.num_hidden, self._output_depth], axis=1)

    def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
        """Computes the NADE negative log-likelihood and label metrics.

        Args:
          flat_x_target: Target labels, one row per flattened step.
          flat_rnn_output: Projected RNN output holding the concatenated
              NADE biases for each row.

        Returns:
          A tuple `(r_loss, metric_map, flat_truth, flat_predictions)` where
          `r_loss` is the negated NADE log-likelihood, `metric_map` maps metric
          names to `tf.metrics` results, `flat_truth` is the boolean target and
          `flat_predictions` is the conditional probability thresholded at 0.5.
        """
        b_enc, b_dec = self._split_nade_biases(flat_rnn_output)
        log_likelihood, cond_probs = self._nade.log_prob(
            flat_x_target, b_enc=b_enc, b_dec=b_dec)
        r_loss = -log_likelihood

        flat_truth = tf.cast(flat_x_target, tf.bool)
        # Hard label per dimension: predicted on when its probability >= 0.5.
        flat_predictions = tf.greater_equal(cond_probs, 0.5)

        # A row counts as accurate only when every label in it matches.
        row_all_correct = tf.reduce_all(
            tf.equal(flat_truth, flat_predictions), axis=-1)
        metric_map = {
            'metrics/accuracy': tf.metrics.mean(row_all_correct),
            'metrics/recall': tf.metrics.recall(flat_truth, flat_predictions),
            'metrics/precision': tf.metrics.precision(
                flat_truth, flat_predictions),
        }

        return r_loss, metric_map, flat_truth, flat_predictions

    def _sample(self, rnn_output, temperature=None):
        """Sample from NADE, returning the argmax if no temperature is provided."""
        b_enc, b_dec = self._split_nade_biases(rnn_output)
        drawn, _ = self._nade.sample(
            b_enc=b_enc, b_dec=b_dec, temperature=temperature)
        return drawn
Exemplo n.º 2
0
class MultiLabelRnnNadeDecoder(BaseLstmDecoder):
  """LSTM decoder with multi-label output provided by a NADE.

  The output projection emits, for every row, the NADE encoder bias followed
  by the NADE decoder bias; both methods below split it the same way.
  """

  def build(self, hparams, output_depth, is_training=False):
    """Builds the NADE and the base decoder, then replaces the projection.

    Args:
      hparams: Hyperparameters; `nade_num_hidden` sizes the NADE hidden layer.
      output_depth: Number of output label dimensions.
      is_training: Passed through to the base class `build`.
    """
    nade_hidden_size = hparams.nade_num_hidden
    self._nade = Nade(output_depth, nade_hidden_size, name='decoder/nade')
    super(MultiLabelRnnNadeDecoder, self).build(
        hparams, output_depth, is_training)
    # Overwrite output layer for NADE parameterization: width is
    # num_hidden (b_enc) + output_depth (b_dec).
    self._output_layer = layers_core.Dense(
        units=self._nade.num_hidden + output_depth, name='output_projection')

  def _flat_reconstruction_loss(self, flat_x_target, flat_rnn_output):
    """Returns NADE reconstruction loss, metrics, truth and predictions.

    Args:
      flat_x_target: Target labels, one row per flattened step.
      flat_rnn_output: Projected RNN output; columns are `[b_enc | b_dec]`.

    Returns:
      `(r_loss, metric_map, flat_truth, flat_predictions)`: the negative NADE
      log-likelihood, a dict of `tf.metrics` results (accuracy, recall,
      precision), the targets cast to bool, and the conditional
      probabilities thresholded at 0.5.
    """
    bias_sizes = [self._nade.num_hidden, self._output_depth]
    enc_bias, dec_bias = tf.split(flat_rnn_output, bias_sizes, axis=1)
    ll, cond_probs = self._nade.log_prob(
        flat_x_target, b_enc=enc_bias, b_dec=dec_bias)
    r_loss = -ll

    flat_truth = tf.cast(flat_x_target, tf.bool)
    flat_predictions = tf.greater_equal(cond_probs, 0.5)
    # Exact-match accuracy: all labels of a row must agree with the target.
    exact_match = tf.reduce_all(
        tf.equal(flat_truth, flat_predictions), axis=-1)

    accuracy = tf.metrics.mean(exact_match)
    recall = tf.metrics.recall(flat_truth, flat_predictions)
    precision = tf.metrics.precision(flat_truth, flat_predictions)
    metric_map = {
        'metrics/accuracy': accuracy,
        'metrics/recall': recall,
        'metrics/precision': precision,
    }

    return r_loss, metric_map, flat_truth, flat_predictions

  def _sample(self, rnn_output, temperature=None):
    """Sample from NADE, returning the argmax if no temperature is provided."""
    bias_sizes = [self._nade.num_hidden, self._output_depth]
    enc_bias, dec_bias = tf.split(rnn_output, bias_sizes, axis=1)
    nade_sample, _ = self._nade.sample(
        b_enc=enc_bias, b_dec=dec_bias, temperature=temperature)
    return nade_sample
Exemplo n.º 3
0
class RnnNade(object):
    """RNN-NADE [2], a NADE parameterized by an RNN.

    The NADE's bias parameters are given by the output of the RNN.

    [2]: https://arxiv.org/abs/1206.6392

    Args:
      rnn_cell: The tf.contrib.rnn.RnnCell to use.
      num_dims: The number of binary dimensions for each observation.
      num_hidden: The number of hidden units in the NADE.
    """

    def __init__(self, rnn_cell, num_dims, num_hidden):
        self._num_dims = num_dims
        self._rnn_cell = rnn_cell
        # One projection yields both NADE bias vectors per step; `_get_state`
        # splits it into b_enc (num_hidden units) and b_dec (num_dims units).
        self._fc_layer = tf_layers_core.Dense(units=num_dims + num_hidden)
        self._nade = Nade(num_dims, num_hidden)

    def _get_rnn_zero_state(self, batch_size):
        """Return a tensor or tuple of tensors for an initial rnn state."""
        return self._rnn_cell.zero_state(batch_size, tf.float32)

    class SampleNadeLayer(tf_layers_base.Layer):
        """Layer that computes samples from a NADE."""

        def __init__(self, nade, name=None, **kwargs):
            super(RnnNade.SampleNadeLayer, self).__init__(name=name, **kwargs)
            self._nade = nade
            # NOTE(review): not read anywhere in this class; kept in case
            # other code depends on it.
            self._empty_result = tf.zeros([0, nade.num_dims])

        def call(self, inputs):
            # `inputs` holds the concatenated [b_enc, b_dec] NADE biases.
            b_enc, b_dec = tf.split(
                inputs, [self._nade.num_hidden, self._nade.num_dims], axis=1)
            return self._nade.sample(b_enc, b_dec)[0]

    def _get_state(self, inputs, lengths=None, initial_state=None):
        """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

        Args:
          inputs: A batch of sequences to compute the state from, sized
              `[batch_size, max(lengths), num_dims]`. A time axis is required
              because the sequence is unrolled with a `TrainingHelper`.
          lengths: The length of each sequence, sized `[batch_size]`, or None
              if every sequence spans the full time axis of `inputs`.
          initial_state: An RnnNadeStateTuple, the initial state of the RNN-NADE, or
              None if the zero state should be used.

        Returns:
          final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
          Its bias tensors are flattened over the time dimension.
        """
        batch_size = inputs.shape[0].value
        if batch_size is None:
            # The batch size is not statically known: use the dynamic size
            # instead of passing `None` to `tf.tile` and `zero_state`.
            batch_size = tf.shape(inputs)[0]

        lengths = (tf.tile(tf.shape(inputs)[1:2], [batch_size])
                   if lengths is None else lengths)
        initial_rnn_state = (self._get_rnn_zero_state(batch_size)
                             if initial_state is None else
                             initial_state.rnn_state)

        # The helper feeds `inputs` to the cell step by step; `_fc_layer`
        # projects every RNN output to the concatenated NADE biases.
        helper = tf.contrib.seq2seq.TrainingHelper(inputs=inputs,
                                                   sequence_length=lengths)

        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell=self._rnn_cell,
            helper=helper,
            initial_state=initial_rnn_state,
            output_layer=self._fc_layer)

        final_outputs, final_rnn_state = tf.contrib.seq2seq.dynamic_decode(
            decoder)[0:2]

        # Flatten time dimension.
        final_outputs_flat = magenta.common.flatten_maybe_padded_sequences(
            final_outputs.rnn_output, lengths)

        b_enc, b_dec = tf.split(final_outputs_flat,
                                [self._nade.num_hidden, self._nade.num_dims],
                                axis=1)

        return RnnNadeStateTuple(b_enc, b_dec, final_rnn_state)

    def log_prob(self, sequences, lengths=None):
        """Computes the log probability of a sequence of values.

        Flattens the time dimension.

        Args:
          sequences: A batch of sequences to compute the log probabilities of,
              sized `[batch_size, max(lengths), num_dims]`.
          lengths: The length of each sequence, sized `[batch_size]` or None if
              all are equal.

        Returns:
          log_prob: The log probability of each sequence value, sized
              `[sum(lengths), 1]`.
          cond_prob: The conditional probabilities at each non-padded value for
              every batch, sized `[sum(lengths), num_dims]`.

        Raises:
          ValueError: If the static last dimension of `sequences` is not
              `num_dims`.
        """
        # Explicit check instead of `assert`, which is stripped under `-O`.
        seq_dims = sequences.shape[2].value
        if seq_dims != self._num_dims:
            raise ValueError(
                'Expected sequences with %d dims in axis 2, got %s.' %
                (self._num_dims, seq_dims))

        # Remove last value from input sequences.
        inputs = sequences[:, 0:-1, :]

        # Add initial padding value to input sequences. With the slice above
        # this shifts the inputs right by one step, so the biases used for
        # step t are computed only from values before t.
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])

        state = self._get_state(inputs, lengths=lengths)

        # Flatten time dimension.
        labels_flat = magenta.common.flatten_maybe_padded_sequences(
            sequences, lengths)

        return self._nade.log_prob(labels_flat, state.b_enc, state.b_dec)

    def steps(self, inputs, state):
        """Computes the new RNN-NADE state from a batch of inputs.

        Args:
          inputs: A batch of values to advance the RNN with, sized
              `[batch_size, length, num_dims]`.
          state: An RnnNadeStateTuple containing the RNN-NADE for each value, sized
              `([batch_size, self._nade.num_hidden], [batch_size, num_dims],
                [batch_size, self._rnn_cell.state_size]`).

        Returns:
          new_state: The updated RNN-NADE state tuple given the new inputs.
        """
        return self._get_state(inputs, initial_state=state)

    def sample_single(self, state):
        """Computes a sample and its probability from each of a batch of states.

        Args:
          state: An RnnNadeStateTuple containing the state of the RNN-NADE for each
              sample, sized
              `([batch_size, self._nade.num_hidden], [batch_size, num_dims],
                [batch_size, self._rnn_cell.state_size]`).

        Returns:
          sample: A sample for each input state, sized `[batch_size, num_dims]`.
          log_prob: The log probability of each sample, sized `[batch_size, 1]`.
        """
        sample, log_prob = self._nade.sample(state.b_enc, state.b_dec)

        return sample, log_prob

    def zero_state(self, batch_size):
        """Create an RnnNadeStateTuple of zeros.

        Args:
          batch_size: batch size.

        Returns:
          An RnnNadeStateTuple of zeros.
        """
        with tf.name_scope('RnnNadeZeroState', values=[batch_size]):
            zero_state = self._get_rnn_zero_state(batch_size)
            return RnnNadeStateTuple(
                tf.zeros((batch_size, self._nade.num_hidden), name='b_enc'),
                tf.zeros((batch_size, self._num_dims), name='b_dec'),
                zero_state)
class RnnNade(object):
  """RNN-NADE [2], a NADE parameterized by an RNN.

  The NADE's bias parameters are given by the output of the RNN.

  [2]: https://arxiv.org/abs/1206.6392

  Args:
    rnn_cell: The tf.contrib.rnn.RnnCell to use.
    num_dims: The number of binary dimensions for each observation.
    num_hidden: The number of hidden units in the NADE.
  """

  def __init__(self, rnn_cell, num_dims, num_hidden):
    self._num_dims = num_dims
    self._rnn_cell = rnn_cell
    # One projection yields both NADE bias vectors per step; `_get_state`
    # splits it into b_enc (num_hidden units) and b_dec (num_dims units).
    self._fc_layer = tf_layers_core.Dense(units=num_dims + num_hidden)
    self._nade = Nade(num_dims, num_hidden)

  def _get_rnn_zero_state(self, batch_size):
    """Return a tensor or tuple of tensors for an initial rnn state."""
    return self._rnn_cell.zero_state(batch_size, tf.float32)

  class SampleNadeLayer(tf_layers_base.Layer):
    """Layer that computes samples from a NADE."""

    def __init__(self, nade, name=None, **kwargs):
      super(RnnNade.SampleNadeLayer, self).__init__(name=name, **kwargs)
      self._nade = nade
      # NOTE(review): not read anywhere in this class; kept in case other
      # code depends on it.
      self._empty_result = tf.zeros([0, nade.num_dims])

    def call(self, inputs):
      # `inputs` holds the concatenated [b_enc, b_dec] NADE biases.
      b_enc, b_dec = tf.split(
          inputs, [self._nade.num_hidden, self._nade.num_dims], axis=1)
      return self._nade.sample(b_enc, b_dec)[0]

  def _get_state(self,
                 inputs,
                 lengths=None,
                 initial_state=None):
    """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

    Args:
      inputs: A batch of sequences to compute the state from, sized
          `[batch_size, max(lengths), num_dims]`. A time axis is required
          because the sequence is unrolled with a `TrainingHelper`.
      lengths: The length of each sequence, sized `[batch_size]`, or None if
          every sequence spans the full time axis of `inputs`.
      initial_state: An RnnNadeStateTuple, the initial state of the RNN-NADE, or
          None if the zero state should be used.

    Returns:
      final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
      Its bias tensors are flattened over the time dimension.
    """
    batch_size = inputs.shape[0].value
    if batch_size is None:
      # The batch size is not statically known: use the dynamic size instead
      # of passing `None` to `tf.tile` and `zero_state`.
      batch_size = tf.shape(inputs)[0]

    lengths = (
        tf.tile(tf.shape(inputs)[1:2], [batch_size]) if lengths is None else
        lengths)
    initial_rnn_state = (
        self._get_rnn_zero_state(batch_size) if initial_state is None else
        initial_state.rnn_state)

    # The helper feeds `inputs` to the cell step by step; `_fc_layer`
    # projects every RNN output to the concatenated NADE biases.
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=inputs,
        sequence_length=lengths)

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=self._rnn_cell,
        helper=helper,
        initial_state=initial_rnn_state,
        output_layer=self._fc_layer)

    final_outputs, final_rnn_state = tf.contrib.seq2seq.dynamic_decode(
        decoder)[0:2]

    # Flatten time dimension.
    final_outputs_flat = magenta.common.flatten_maybe_padded_sequences(
        final_outputs.rnn_output, lengths)

    b_enc, b_dec = tf.split(
        final_outputs_flat, [self._nade.num_hidden, self._nade.num_dims],
        axis=1)

    return RnnNadeStateTuple(b_enc, b_dec, final_rnn_state)

  def log_prob(self, sequences, lengths=None):
    """Computes the log probability of a sequence of values.

    Flattens the time dimension.

    Args:
      sequences: A batch of sequences to compute the log probabilities of,
          sized `[batch_size, max(lengths), num_dims]`.
      lengths: The length of each sequence, sized `[batch_size]` or None if
          all are equal.

    Returns:
      log_prob: The log probability of each sequence value, sized
          `[sum(lengths), 1]`.
      cond_prob: The conditional probabilities at each non-padded value for
          every batch, sized `[sum(lengths), num_dims]`.

    Raises:
      ValueError: If the static last dimension of `sequences` is not
          `num_dims`.
    """
    # Explicit check instead of `assert`, which is stripped under `-O`.
    seq_dims = sequences.shape[2].value
    if seq_dims != self._num_dims:
      raise ValueError(
          'Expected sequences with %d dims in axis 2, got %s.' %
          (self._num_dims, seq_dims))

    # Remove last value from input sequences.
    inputs = sequences[:, 0:-1, :]

    # Add initial padding value to input sequences. With the slice above this
    # shifts the inputs right by one step, so the biases used for step t are
    # computed only from values before t.
    inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])

    state = self._get_state(inputs, lengths=lengths)

    # Flatten time dimension.
    labels_flat = magenta.common.flatten_maybe_padded_sequences(
        sequences, lengths)

    return self._nade.log_prob(labels_flat, state.b_enc, state.b_dec)

  def steps(self, inputs, state):
    """Computes the new RNN-NADE state from a batch of inputs.

    Args:
      inputs: A batch of values to advance the RNN with, sized
          `[batch_size, length, num_dims]`.
      state: An RnnNadeStateTuple containing the RNN-NADE for each value, sized
          `([batch_size, self._nade.num_hidden], [batch_size, num_dims],
            [batch_size, self._rnn_cell.state_size]`).

    Returns:
      new_state: The updated RNN-NADE state tuple given the new inputs.
    """
    return self._get_state(inputs, initial_state=state)

  def sample_single(self, state):
    """Computes a sample and its probability from each of a batch of states.

    Args:
      state: An RnnNadeStateTuple containing the state of the RNN-NADE for each
          sample, sized
          `([batch_size, self._nade.num_hidden], [batch_size, num_dims],
            [batch_size, self._rnn_cell.state_size]`).

    Returns:
      sample: A sample for each input state, sized `[batch_size, num_dims]`.
      log_prob: The log probability of each sample, sized `[batch_size, 1]`.
    """
    sample, log_prob = self._nade.sample(state.b_enc, state.b_dec)

    return sample, log_prob

  def zero_state(self, batch_size):
    """Create an RnnNadeStateTuple of zeros.

    Args:
      batch_size: batch size.

    Returns:
      An RnnNadeStateTuple of zeros.
    """
    with tf.name_scope('RnnNadeZeroState', values=[batch_size]):
      zero_state = self._get_rnn_zero_state(batch_size)
      return RnnNadeStateTuple(
          tf.zeros((batch_size, self._nade.num_hidden), name='b_enc'),
          tf.zeros((batch_size, self._num_dims), name='b_dec'),
          zero_state)