Code example #1
  def testBranchingSingleBeamEntry(self):
    sequence, state, score = beam_search(
        initial_sequence=[], initial_state=1,
        generate_step_fn=self._generate_step_fn, num_steps=5, beam_size=1,
        branch_factor=32, steps_per_iteration=1)

    # Here the beam search should greedily choose ones.
    self.assertEqual(sequence, [1, 1, 1, 1, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 5)
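
The tests in this listing call beam_search with a self._generate_step_fn helper that is not shown. The following is a hypothetical reconstruction, consistent with the asserted values but not the verbatim Magenta test helper, under two assumptions of mine: the search presents candidates to the step function as a parallel batch, and the batch index of each candidate is what diversifies the beam. Appending a 0 doubles the state and scores nothing; appending a 1 "cashes in" the current state and resets it to 1, so the greedy choice is always 1 while the best length-5 sequence is [0, 0, 0, 0, 1].

  def _generate_step_fn(self, sequences, states, scores):
    # Hypothetical helper: candidate i appends bit (i >> t) & 1, where t is
    # the number of events generated so far. With stable batch indices, a
    # batch of 32 candidates enumerates all length-5 binary sequences.
    for i in range(len(sequences)):
      event = (i >> len(sequences[i])) & 1
      sequences[i].append(event)
      if event == 0:
        # Zeros accumulate value exponentially in the state.
        states[i] *= 2
      else:
        # Ones cash in the accumulated state and reset it.
        scores[i] += states[i]
        states[i] = 1
    return sequences, states, scores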
Code example #2
  def testNoBranchingMultipleBeamEntries(self):
    sequence, state, score = beam_search(
        initial_sequence=[], initial_state=1,
        generate_step_fn=self._generate_step_fn, num_steps=5, beam_size=32,
        branch_factor=1, steps_per_iteration=1)

    # Here the beam has enough capacity to find the optimal solution without
    # branching.
    self.assertEqual(sequence, [0, 0, 0, 0, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 16)
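
Under the scoring assumed in the sketch above, a quick brute force confirms that [0, 0, 0, 0, 1] is the highest-scoring length-5 binary sequence, matching this test's expected score of 16 (four doublings of the initial state 1, then one cash-in):

import itertools

def score(seq):
  # Assumed scoring (see the sketch above): zeros double the state,
  # ones add the current state to the score and reset it to 1.
  state, total = 1, 0
  for event in seq:
    if event == 0:
      state *= 2
    else:
      total += state
      state = 1
  return total

best = max(itertools.product([0, 1], repeat=5), key=score)
print(best, score(best))  # (0, 0, 0, 0, 1) 16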
Code example #3
  def testNoBranchingMultipleStepsPerIteration(self):
    sequence, state, score = beam_search(
        initial_sequence=[], initial_state=1,
        generate_step_fn=self._generate_step_fn, num_steps=5, beam_size=1,
        branch_factor=1, steps_per_iteration=2)

    # Only a single sequence is ever considered, so the counter should never
    # reach one.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
Code example #4
  def testNoBranchingSingleStepPerIteration(self):
    sequence, state, score = beam_search(
        initial_sequence=[], initial_state=1,
        generate_step_fn=self._generate_step_fn, num_steps=5, beam_size=1,
        branch_factor=1, steps_per_iteration=1)

    # The generator should emit all zeros: only a single sequence is ever
    # considered, so the counter never reaches one.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
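
Together these tests exercise how beam_size, branch_factor, and steps_per_iteration interact. The following is a simplified model of the control flow, not Magenta's actual implementation; it uses the same assumptions as the step-function sketch above, plus one more of mine: the beam is seeded with beam_size copies of the initial candidate, so an index-dependent step function can diversify the beam even when branch_factor is 1.

import copy
import heapq

def beam_search(initial_sequence, initial_state, generate_step_fn, num_steps,
                beam_size, branch_factor, steps_per_iteration):
  # Each candidate is a (sequence, state, score) triple. Seeding with
  # beam_size copies is an assumption made for this sketch.
  candidates = [(copy.deepcopy(initial_sequence), copy.deepcopy(initial_state),
                 0.0) for _ in range(beam_size)]
  steps_taken = 0
  while steps_taken < num_steps:
    # Branch: duplicate every candidate branch_factor times.
    branched = [copy.deepcopy(c) for c in candidates
                for _ in range(branch_factor)]
    sequences = [c[0] for c in branched]
    states = [c[1] for c in branched]
    scores = [c[2] for c in branched]
    # Advance every branch by up to steps_per_iteration steps.
    for _ in range(min(steps_per_iteration, num_steps - steps_taken)):
      sequences, states, scores = generate_step_fn(sequences, states, scores)
      steps_taken += 1
    candidates = list(zip(sequences, states, scores))
    # Prune back down to the beam_size highest-scoring candidates.
    if len(candidates) > beam_size:
      candidates = heapq.nlargest(beam_size, candidates, key=lambda c: c[2])
  # Return the best (sequence, state, score) triple.
  return max(candidates, key=lambda c: c[2])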
Code example #5
File: events_rnn_model.py  Project: Alice-ren/magenta
  def _generate_events(self, num_steps, primer_events, temperature=1.0,
                       beam_size=1, branch_factor=1, steps_per_iteration=1,
                       control_events=None, control_state=None,
                       extend_control_events_callback=(
                           _extend_control_events_default),
                       modify_events_callback=None):
    """Generate an event sequence from a primer sequence.

    Args:
      num_steps: The integer length in steps of the final event sequence, after
          generation. Includes the primer.
      primer_events: The primer event sequence, a Python list-like object.
      temperature: A float specifying how much to divide the logits by
         before computing the softmax. Greater than 1.0 makes events more
         random, less than 1.0 makes events less random.
      beam_size: An integer, beam size to use when generating event sequences
          via beam search.
      branch_factor: An integer, beam search branch factor to use.
      steps_per_iteration: An integer, number of steps to take per beam search
          iteration.
      control_events: A sequence of control events upon which to condition the
          generation. If not None, the encoder/decoder should be a
          ConditionalEventSequenceEncoderDecoder, and the control events will be
          used along with the target sequence to generate model inputs. In some
          cases, the control event sequence cannot be fully-determined as later
          control events depend on earlier generated events; use the
          `extend_control_events_callback` argument to provide a function that
          extends the control event sequence.
      control_state: Initial state used by `extend_control_events_callback`.
      extend_control_events_callback: A function that takes three arguments: a
          current control event sequence, a current generated event sequence,
          and the control state. The function should a) extend the control event
          sequence to be one longer than the generated event sequence (or do
          nothing if it is already at least this long), and b) return the
          resulting control state.
      modify_events_callback: An optional callback for modifying the event list.
          Can be used to inject events rather than having them generated. If not
          None, will be called with 3 arguments after every event: the current
          EventSequenceEncoderDecoder, a list of current EventSequences, and a
          list of current encoded event inputs.

    Returns:
      The generated event sequence (which begins with the provided primer).

    Raises:
      EventSequenceRnnModelException: If the primer sequence has zero length or
          is not shorter than num_steps.
    """
    if (control_events is not None and
        not isinstance(self._config.encoder_decoder,
                       mm.ConditionalEventSequenceEncoderDecoder)):
      raise EventSequenceRnnModelException(
          'control sequence provided but encoder/decoder is not a '
          'ConditionalEventSequenceEncoderDecoder')
    if control_events is not None and extend_control_events_callback is None:
      raise EventSequenceRnnModelException(
          'must provide callback for extending control sequence (or use '
          'default)')

    if not primer_events:
      raise EventSequenceRnnModelException(
          'primer sequence must have non-zero length')
    if len(primer_events) >= num_steps:
      raise EventSequenceRnnModelException(
          'primer sequence must be shorter than `num_steps`')

    event_sequences = [copy.deepcopy(primer_events)]

    # Construct inputs for first step after primer.
    if control_events is not None:
      # We are conditioning on a control sequence. Make sure it is longer than
      # the primer sequence.
      control_state = extend_control_events_callback(
          control_events, primer_events, control_state)
      inputs = self._config.encoder_decoder.get_inputs_batch(
          [control_events], event_sequences, full_length=True)
    else:
      inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences, full_length=True)

    if modify_events_callback:
      # Modify event sequences and inputs for first step after primer.
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, inputs)

    graph_initial_state = self._session.graph.get_collection('initial_state')
    initial_states = state_util.unbatch(self._session.run(graph_initial_state))

    # Beam search will maintain a state for each sequence consisting of the next
    # inputs to feed the model, and the current RNN state. We start out with the
    # initial full inputs batch and the zero state.
    initial_state = ModelState(
        inputs=inputs[0], rnn_state=initial_states[0],
        control_events=control_events, control_state=control_state)

    events, _, loglik = beam_search(
        initial_sequence=event_sequences[0],
        initial_state=initial_state,
        generate_step_fn=functools.partial(
            self._generate_step,
            temperature=temperature,
            extend_control_events_callback=(
                extend_control_events_callback
                if control_events is not None
                else None),
            modify_events_callback=modify_events_callback),
        num_steps=num_steps - len(primer_events),
        beam_size=beam_size,
        branch_factor=branch_factor,
        steps_per_iteration=steps_per_iteration)

    tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                    loglik)

    return events
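
The signature above defaults extend_control_events_callback to _extend_control_events_default, which this listing does not include. A sketch of its likely behavior, repeating the final control event until the control sequence is one event longer than the generated sequence (an assumption based on the Magenta codebase, not the verbatim source):

def _extend_control_events_default(control_events, events, state):
  # Extend the control sequence by duplicating its final event until it is
  # one event longer than the generated sequence; pass the control state
  # through unchanged.
  while len(control_events) <= len(events):
    control_events.append(control_events[-1])
  return state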