示例#1
0
  def _generate_step(self, event_sequences, inputs, initial_states,
                     temperature):
    """Advances every event sequence by one generated step.

    The sequences are processed in chunks of the model batch size. The last
    chunk is topped up with copies of the final sequence, input and state so
    that it is full; the results for those filler entries are discarded. The
    event sequences passed in are modified in place.

    Args:
      event_sequences: A list of event sequence objects to extend.
      inputs: A Python list of model inputs, one per event sequence.
      initial_states: A list of initial RNN states, one per event sequence.
      temperature: The softmax temperature.

    Returns:
      final_states: A list holding one final RNN state per event sequence.
      loglik: A 1-D numpy array with one entry per event sequence, filled with
          the log-likelihoods reported by `_generate_step_for_batch`.
    """
    model_batch_size = self._batch_size()
    total = len(event_sequences)
    batch_count = int(np.ceil(total / float(model_batch_size)))

    final_states = []
    loglik = np.empty(total)

    # Number of filler entries needed to make the last chunk a full batch.
    filler = -total % model_batch_size
    all_sequences = event_sequences + [
        copy.deepcopy(event_sequences[-1]) for _ in range(filler)]
    all_inputs = inputs + [inputs[-1]] * filler
    all_states = initial_states + [initial_states[-1]] * filler

    for batch_index in range(batch_count):
      start = batch_index * model_batch_size
      stop = start + model_batch_size
      # Only this many leading entries of the chunk are real sequences; the
      # rest (if any) are filler and must not leak into the results.
      keep = model_batch_size - max(0, stop - total)
      chunk_state, chunk_loglik = self._generate_step_for_batch(
          all_sequences[start:stop],
          all_inputs[start:stop],
          state_util.batch(all_states[start:stop], model_batch_size),
          temperature)
      final_states += state_util.unbatch(chunk_state, model_batch_size)[:keep]
      loglik[start:start + keep] = chunk_loglik[:keep]

    return final_states, loglik
示例#2
0
  def test_Unbatch(self):
    # Unbatching the batched fixture with a batch size of 2 must reproduce the
    # unbatched fixture.
    actual = state_util.unbatch(self._batched_states, batch_size=2)
    expected = self._unbatched_states
    self._assert_sructures_equal(expected, actual)
示例#3
0
  def _beam_search(self, events, num_steps, temperature, beam_size,
                   branch_factor, steps_per_iteration, control_events=None,
                   modify_events_callback=None):
    """Generates an event sequence using beam search.

    Initially, the beam is filled with `beam_size` copies of the initial event
    sequence.

    Each iteration, the beam is pruned to contain only the `beam_size` event
    sequences with highest likelihood. Then `branch_factor` new event sequences
    are generated for each sequence in the beam. These new sequences are formed
    by extending each sequence in the beam by `steps_per_iteration` steps. So
    between a branching and a pruning phase, there will be `beam_size` *
    `branch_factor` active event sequences.

    Prior to the first "real" iteration, an initial branch generation will take
    place. This is for two reasons:

    1) The RNN model needs to be "primed" with the initial event sequence.
    2) The desired total number of steps `num_steps` might not be a multiple of
       `steps_per_iteration`, so the initial branching generates steps such that
       all subsequent iterations can generate `steps_per_iteration` steps.

    After the final iteration, the single event sequence in the beam with
    highest likelihood will be returned.

    Args:
      events: The initial event sequence, a Python list-like object.
      num_steps: The integer length in steps of the final event sequence, after
          generation.
      temperature: A float specifying how much to divide the logits by
         before computing the softmax. Greater than 1.0 makes events more
         random, less than 1.0 makes events less random.
      beam_size: The integer beam size to use.
      branch_factor: The integer branch factor to use.
      steps_per_iteration: The integer number of steps to take per iteration.
      control_events: A sequence of control events upon which to condition the
          generation. If not None, the encoder/decoder should be a
          ConditionalEventSequenceEncoderDecoder, and the control events will be
          used along with the target sequence to generate model inputs.
      modify_events_callback: An optional callback for modifying the event list.
          Can be used to inject events rather than having them generated. If not
          None, it is called with 3 arguments each time a new inputs batch has
          been built: the current EventSequenceEncoderDecoder, a list of current
          EventSequences, and a list of current encoded event inputs.

    Returns:
      The highest-likelihood event sequence as computed by the beam search.
    """
    event_sequences = [copy.deepcopy(events) for _ in range(beam_size)]
    graph_initial_state = self._session.graph.get_collection('initial_state')
    loglik = np.zeros(beam_size)

    # Choose the number of steps for the first iteration such that subsequent
    # iterations can all take the same number of steps.
    first_iteration_num_steps = (num_steps - 1) % steps_per_iteration + 1

    if control_events is not None:
      # We are conditioning on a control sequence.
      inputs = self._config.encoder_decoder.get_inputs_batch(
          control_events, event_sequences, full_length=True)
    else:
      inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences, full_length=True)

    if modify_events_callback:
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, inputs)

    # Every beam entry starts from the same (first) unbatched initial state.
    zero_state = state_util.unbatch(self._session.run(graph_initial_state))[0]
    initial_states = [zero_state] * beam_size
    event_sequences, final_state, loglik = self._generate_branches(
        event_sequences, loglik, branch_factor, first_iteration_num_steps,
        inputs, initial_states, temperature)

    # The remaining step count is an exact multiple of `steps_per_iteration`
    # by construction of `first_iteration_num_steps`. Floor division keeps the
    # result an int so it can be passed to `range`; true division would yield
    # a float on Python 3 and make `range` raise a TypeError.
    num_iterations = (
        num_steps - first_iteration_num_steps) // steps_per_iteration

    for _ in range(num_iterations):
      event_sequences, final_state, loglik = self._prune_branches(
          event_sequences, final_state, loglik, k=beam_size)
      if control_events is not None:
        # We are conditioning on a control sequence.
        inputs = self._config.encoder_decoder.get_inputs_batch(
            control_events, event_sequences)
      else:
        inputs = self._config.encoder_decoder.get_inputs_batch(event_sequences)

      if modify_events_callback:
        modify_events_callback(
            self._config.encoder_decoder, event_sequences, inputs)

      event_sequences, final_state, loglik = self._generate_branches(
          event_sequences, loglik, branch_factor, steps_per_iteration, inputs,
          final_state, temperature)

    # Prune to a single sequence.
    event_sequences, final_state, loglik = self._prune_branches(
        event_sequences, final_state, loglik, k=1)

    tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                    loglik[0])

    return event_sequences[0]
示例#4
0
    def _generate_step(self,
                       event_sequences,
                       model_states,
                       logliks,
                       temperature,
                       extend_control_events_callback=None,
                       modify_events_callback=None):
        """Advances every event sequence by one step and rebuilds model states.

        The sequences are processed in chunks of the model batch size; the last
        chunk is topped up with filler copies whose results are discarded. The
        event sequences are modified in place and also returned, together with
        the model states and log-likelihoods for the next step.

        Args:
          event_sequences: A list of event sequence objects to extend.
          model_states: A list of model states, one per event sequence. Each
              provides `inputs`, `rnn_state`, `control_events` and
              `control_state`.
          logliks: The current log-likelihood of each event sequence.
          temperature: The softmax temperature.
          extend_control_events_callback: Optional function taking a control
              event sequence, the matching generated event sequence and the
              control state. It should extend the control sequence to be one
              longer than the generated sequence (or do nothing if it is
              already at least that long) and return the resulting control
              state. When None, inputs are built without control sequences.
          modify_events_callback: Optional callback invoked with the
              encoder/decoder, the list of event sequences and the freshly
              built next-step inputs. Can be used to inject events rather than
              having them generated.

        Returns:
          event_sequences: The extended event sequences (same objects that
              were passed in).
          model_states: A list of new `ModelState` objects holding the inputs
              for the next step along with the resulting RNN state and control
              data for each event sequence.
          logliks: A float32 numpy array of updated log-likelihoods, one per
              event sequence.
        """
        model_batch_size = self._batch_size()
        total = len(event_sequences)
        batch_count = int(np.ceil(total / float(model_batch_size)))

        # Pull the per-sequence fields out of the incoming model states.
        inputs = [state.inputs for state in model_states]
        initial_states = [state.rnn_state for state in model_states]
        control_sequences = [state.control_events for state in model_states]
        control_states = [state.control_state for state in model_states]

        final_states = []
        logliks = np.array(logliks, dtype=np.float32)

        # Number of filler entries needed to make the last chunk a full batch.
        filler = -total % model_batch_size
        all_sequences = event_sequences + [
            copy.deepcopy(event_sequences[-1]) for _ in range(filler)
        ]
        all_inputs = inputs + [inputs[-1]] * filler
        all_states = initial_states + [initial_states[-1]] * filler

        for batch_index in range(batch_count):
            start = batch_index * model_batch_size
            stop = start + model_batch_size
            # Only this many leading entries of the chunk are real sequences.
            keep = model_batch_size - max(0, stop - total)
            chunk_state, chunk_loglik = self._generate_step_for_batch(
                all_sequences[start:stop], all_inputs[start:stop],
                state_util.batch(all_states[start:stop], model_batch_size),
                temperature)
            final_states += state_util.unbatch(chunk_state,
                                               model_batch_size)[:keep]
            logliks[start:start + keep] += chunk_loglik[:keep]

        # Build the inputs that the next step will consume.
        if extend_control_events_callback is None:
            next_inputs = self._config.encoder_decoder.get_inputs_batch(
                event_sequences)
        else:
            # Conditioning on control sequences: make sure each one is longer
            # than its event sequence before encoding.
            for idx, control_sequence in enumerate(control_sequences):
                control_states[idx] = extend_control_events_callback(
                    control_sequence, event_sequences[idx],
                    control_states[idx])
            next_inputs = self._config.encoder_decoder.get_inputs_batch(
                control_sequences, event_sequences)

        if modify_events_callback:
            # Let the caller adjust the sequences and next-step inputs.
            modify_events_callback(self._config.encoder_decoder,
                                   event_sequences, next_inputs)

        next_model_states = []
        for step_inputs, rnn_state, control_events, control_state in zip(
                next_inputs, final_states, control_sequences, control_states):
            next_model_states.append(
                ModelState(inputs=step_inputs,
                           rnn_state=rnn_state,
                           control_events=control_events,
                           control_state=control_state))

        return event_sequences, next_model_states, logliks
示例#5
0
  def _generate_events(self, num_steps, primer_events, temperature=1.0,
                       beam_size=1, branch_factor=1, steps_per_iteration=1,
                       control_events=None, control_state=None,
                       extend_control_events_callback=(
                           _extend_control_events_default),
                       modify_events_callback=None):
    """Generate an event sequence from a primer sequence.

    Args:
      num_steps: The integer length in steps of the final event sequence, after
          generation. Includes the primer.
      primer_events: The primer event sequence, a Python list-like object.
      temperature: A float specifying how much to divide the logits by
         before computing the softmax. Greater than 1.0 makes events more
         random, less than 1.0 makes events less random.
      beam_size: An integer, beam size to use when generating event sequences
          via beam search.
      branch_factor: An integer, beam search branch factor to use.
      steps_per_iteration: An integer, number of steps to take per beam search
          iteration.
      control_events: A sequence of control events upon which to condition the
          generation. If not None, the encoder/decoder should be a
          ConditionalEventSequenceEncoderDecoder, and the control events will be
          used along with the target sequence to generate model inputs. In some
          cases, the control event sequence cannot be fully-determined as later
          control events depend on earlier generated events; use the
          `extend_control_events_callback` argument to provide a function that
          extends the control event sequence.
      control_state: Initial state used by `extend_control_events_callback`.
      extend_control_events_callback: A function that takes three arguments: a
          current control event sequence, a current generated event sequence,
          and the control state. The function should a) extend the control event
          sequence to be one longer than the generated event sequence (or do
          nothing if it is already at least this long), and b) return the
          resulting control state.
      modify_events_callback: An optional callback for modifying the event list.
          Can be used to inject events rather than having them generated. If not
          None, will be called with 3 arguments after every event: the current
          EventSequenceEncoderDecoder, a list of current EventSequences, and a
          list of current encoded event inputs.

    Returns:
      A tuple `(events, states, composer_softmax)` where `events` is the
      generated event sequence (which begins with the provided primer),
      `states` is the fourth value returned by `beam_search`, and
      `composer_softmax` is the fifth value returned by `beam_search` when the
      graph has a non-empty 'composer_softmax' collection, otherwise None.

    Raises:
      EventSequenceRnnModelError: If control events are given but the
          encoder/decoder is not a ConditionalEventSequenceEncoderDecoder, if
          control events are given without an extension callback, or if the
          primer sequence has zero length or is not shorter than num_steps.
    """
    if (control_events is not None and
        not isinstance(self._config.encoder_decoder,
                       mm.ConditionalEventSequenceEncoderDecoder)):
      raise EventSequenceRnnModelError(
          'control sequence provided but encoder/decoder is not a '
          'ConditionalEventSequenceEncoderDecoder')
    if control_events is not None and extend_control_events_callback is None:
      raise EventSequenceRnnModelError(
          'must provide callback for extending control sequence (or use '
          'default)')

    if not primer_events:
      raise EventSequenceRnnModelError(
          'primer sequence must have non-zero length')
    if len(primer_events) >= num_steps:
      raise EventSequenceRnnModelError(
          'primer sequence must be shorter than `num_steps`')

    event_sequences = [copy.deepcopy(primer_events)]

    # Construct inputs for first step after primer.
    if control_events is not None:
      # We are conditioning on a control sequence. Make sure it is longer than
      # the primer sequence.
      control_state = extend_control_events_callback(
          control_events, primer_events, control_state)
      inputs = self._config.encoder_decoder.get_inputs_batch(
          [control_events], event_sequences, full_length=True)
    else:
      inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences, full_length=True)

    if modify_events_callback:
      # Modify event sequences and inputs for first step after primer.
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, inputs)

    graph_initial_state = self._session.graph.get_collection('initial_state')
    initial_states = state_util.unbatch(self._session.run(graph_initial_state))

    # Beam search will maintain a state for each sequence consisting of the next
    # inputs to feed the model, and the current RNN state. We start out with the
    # initial full inputs batch and the zero state.
    initial_state = ModelState(
        inputs=inputs[0], rnn_state=initial_states[0],
        control_events=control_events, control_state=control_state)

    # The control-extension callback is only forwarded when we are actually
    # conditioning on a control sequence.
    generate_step_fn = functools.partial(
        self._generate_step,
        temperature=temperature,
        extend_control_events_callback=(
            extend_control_events_callback
            if control_events is not None else None),
        modify_events_callback=modify_events_callback)

    # Arguments shared by both beam search variants below.
    search_kwargs = dict(
        initial_sequence=event_sequences[0],
        initial_state=initial_state,
        generate_step_fn=generate_step_fn,
        num_steps=num_steps - len(primer_events),
        beam_size=beam_size,
        branch_factor=branch_factor,
        steps_per_iteration=steps_per_iteration)

    if self._session.graph.get_collection('composer_softmax'):
      # With `composer=True` the search returns one extra value.
      events, _, loglik, states, composer_softmax = beam_search(
          composer=True, **search_kwargs)
    else:
      events, _, loglik, states = beam_search(**search_kwargs)
      composer_softmax = None

    tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                    loglik)
    return events, states, composer_softmax
示例#6
0
    def test_Unbatch(self):
        # Unbatching the batched fixture with a batch size of 2 must reproduce
        # the unbatched fixture.
        actual = state_util.unbatch(self._batched_states, batch_size=2)
        expected = self._unbatched_states
        self._assert_sructures_equal(expected, actual)
示例#7
0
    def _beam_search(self,
                     events,
                     num_steps,
                     temperature,
                     beam_size,
                     branch_factor,
                     steps_per_iteration,
                     control_events=None,
                     modify_events_callback=None):
        """Generates an event sequence using beam search.

        Initially, the beam is filled with `beam_size` copies of the initial
        event sequence.

        Each iteration, the beam is pruned to contain only the `beam_size`
        event sequences with highest likelihood. Then `branch_factor` new event
        sequences are generated for each sequence in the beam. These new
        sequences are formed by extending each sequence in the beam by
        `steps_per_iteration` steps. So between a branching and a pruning
        phase, there will be `beam_size` * `branch_factor` active event
        sequences.

        Prior to the first "real" iteration, an initial branch generation will
        take place. This is for two reasons:

        1) The RNN model needs to be "primed" with the initial event sequence.
        2) The desired total number of steps `num_steps` might not be a
           multiple of `steps_per_iteration`, so the initial branching
           generates steps such that all subsequent iterations can generate
           `steps_per_iteration` steps.

        After the final iteration, the single event sequence in the beam with
        highest likelihood will be returned.

        Args:
          events: The initial event sequence, a Python list-like object.
          num_steps: The integer length in steps of the final event sequence,
              after generation.
          temperature: A float specifying how much to divide the logits by
              before computing the softmax. Greater than 1.0 makes events more
              random, less than 1.0 makes events less random.
          beam_size: The integer beam size to use.
          branch_factor: The integer branch factor to use.
          steps_per_iteration: The integer number of steps to take per
              iteration.
          control_events: A sequence of control events upon which to condition
              the generation. If not None, the encoder/decoder should be a
              ConditionalEventSequenceEncoderDecoder, and the control events
              will be used along with the target sequence to generate model
              inputs.
          modify_events_callback: An optional callback for modifying the event
              list. Can be used to inject events rather than having them
              generated. If not None, it is called with 3 arguments each time
              a new inputs batch has been built: the current
              EventSequenceEncoderDecoder, a list of current EventSequences,
              and a list of current encoded event inputs.

        Returns:
          The highest-likelihood event sequence as computed by the beam search.
        """
        event_sequences = [copy.deepcopy(events) for _ in range(beam_size)]
        graph_initial_state = self._session.graph.get_collection(
            'initial_state')
        loglik = np.zeros(beam_size)

        # Choose the number of steps for the first iteration such that subsequent
        # iterations can all take the same number of steps.
        first_iteration_num_steps = (num_steps - 1) % steps_per_iteration + 1

        if control_events is not None:
            # We are conditioning on a control sequence.
            inputs = self._config.encoder_decoder.get_inputs_batch(
                control_events, event_sequences, full_length=True)
        else:
            inputs = self._config.encoder_decoder.get_inputs_batch(
                event_sequences, full_length=True)

        if modify_events_callback:
            modify_events_callback(self._config.encoder_decoder,
                                   event_sequences, inputs)

        # Every beam entry starts from the same (first) unbatched initial state.
        zero_state = state_util.unbatch(
            self._session.run(graph_initial_state))[0]
        initial_states = [zero_state] * beam_size
        event_sequences, final_state, loglik = self._generate_branches(
            event_sequences, loglik, branch_factor, first_iteration_num_steps,
            inputs, initial_states, temperature)

        # The remaining step count is an exact multiple of
        # `steps_per_iteration` by construction of `first_iteration_num_steps`.
        # Floor division keeps the result an int so it can be passed to
        # `range`; true division would yield a float on Python 3 and make
        # `range` raise a TypeError.
        num_iterations = (
            num_steps - first_iteration_num_steps) // steps_per_iteration

        for _ in range(num_iterations):
            event_sequences, final_state, loglik = self._prune_branches(
                event_sequences, final_state, loglik, k=beam_size)
            if control_events is not None:
                # We are conditioning on a control sequence.
                inputs = self._config.encoder_decoder.get_inputs_batch(
                    control_events, event_sequences)
            else:
                inputs = self._config.encoder_decoder.get_inputs_batch(
                    event_sequences)

            if modify_events_callback:
                modify_events_callback(self._config.encoder_decoder,
                                       event_sequences, inputs)

            event_sequences, final_state, loglik = self._generate_branches(
                event_sequences, loglik, branch_factor, steps_per_iteration,
                inputs, final_state, temperature)

        # Prune to a single sequence.
        event_sequences, final_state, loglik = self._prune_branches(
            event_sequences, final_state, loglik, k=1)

        tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                        loglik[0])

        return event_sequences[0]
示例#8
0
  def _generate_events(self, num_steps, primer_events, temperature=1.0,
                       beam_size=1, branch_factor=1, steps_per_iteration=1,
                       control_events=None, control_state=None,
                       extend_control_events_callback=(
                           _extend_control_events_default),
                       modify_events_callback=None):
    """Generate an event sequence from a primer sequence.

    Args:
      num_steps: The integer length in steps of the final event sequence, after
          generation. Includes the primer.
      primer_events: The primer event sequence, a Python list-like object.
      temperature: A float specifying how much to divide the logits by
         before computing the softmax. Greater than 1.0 makes events more
         random, less than 1.0 makes events less random.
      beam_size: An integer, beam size to use when generating event sequences
          via beam search.
      branch_factor: An integer, beam search branch factor to use.
      steps_per_iteration: An integer, number of steps to take per beam search
          iteration.
      control_events: A sequence of control events upon which to condition the
          generation. If not None, the encoder/decoder should be a
          ConditionalEventSequenceEncoderDecoder, and the control events will be
          used along with the target sequence to generate model inputs. In some
          cases, the control event sequence cannot be fully-determined as later
          control events depend on earlier generated events; use the
          `extend_control_events_callback` argument to provide a function that
          extends the control event sequence.
      control_state: Initial state used by `extend_control_events_callback`.
      extend_control_events_callback: A function that takes three arguments: a
          current control event sequence, a current generated event sequence,
          and the control state. The function should a) extend the control event
          sequence to be one longer than the generated event sequence (or do
          nothing if it is already at least this long), and b) return the
          resulting control state.
      modify_events_callback: An optional callback for modifying the event list.
          Can be used to inject events rather than having them generated. If not
          None, will be called with 3 arguments after every event: the current
          EventSequenceEncoderDecoder, a list of current EventSequences, and a
          list of current encoded event inputs.

    Returns:
      The generated event sequence (which begins with the provided primer).

    Raises:
      EventSequenceRnnModelException: If control events are given but the
          encoder/decoder is not a ConditionalEventSequenceEncoderDecoder, if
          control events are given without an extension callback, or if the
          primer sequence has zero length or is not shorter than num_steps.
    """
    if (control_events is not None and
        not isinstance(self._config.encoder_decoder,
                       mm.ConditionalEventSequenceEncoderDecoder)):
      raise EventSequenceRnnModelException(
          'control sequence provided but encoder/decoder is not a '
          'ConditionalEventSequenceEncoderDecoder')
    if control_events is not None and extend_control_events_callback is None:
      raise EventSequenceRnnModelException(
          'must provide callback for extending control sequence (or use '
          'default)')

    if not primer_events:
      raise EventSequenceRnnModelException(
          'primer sequence must have non-zero length')
    if len(primer_events) >= num_steps:
      raise EventSequenceRnnModelException(
          'primer sequence must be shorter than `num_steps`')

    event_sequences = [copy.deepcopy(primer_events)]

    # Construct inputs for first step after primer.
    if control_events is not None:
      # We are conditioning on a control sequence. Make sure it is longer than
      # the primer sequence.
      control_state = extend_control_events_callback(
          control_events, primer_events, control_state)
      inputs = self._config.encoder_decoder.get_inputs_batch(
          [control_events], event_sequences, full_length=True)
    else:
      inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences, full_length=True)

    if modify_events_callback:
      # Modify event sequences and inputs for first step after primer.
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, inputs)

    graph_initial_state = self._session.graph.get_collection('initial_state')
    initial_states = state_util.unbatch(self._session.run(graph_initial_state))

    # Beam search will maintain a state for each sequence consisting of the next
    # inputs to feed the model, and the current RNN state. We start out with the
    # initial full inputs batch and the zero state.
    initial_state = ModelState(
        inputs=inputs[0], rnn_state=initial_states[0],
        control_events=control_events, control_state=control_state)

    # The control-extension callback is only forwarded when we are actually
    # conditioning on a control sequence.
    generate_step_fn = functools.partial(
        self._generate_step,
        temperature=temperature,
        extend_control_events_callback=(
            extend_control_events_callback
            if control_events is not None
            else None),
        modify_events_callback=modify_events_callback)

    events, _, loglik = beam_search(
        initial_sequence=event_sequences[0],
        initial_state=initial_state,
        generate_step_fn=generate_step_fn,
        num_steps=num_steps - len(primer_events),
        beam_size=beam_size,
        branch_factor=branch_factor,
        steps_per_iteration=steps_per_iteration)

    tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                    loglik)

    return events
示例#9
0
  def _generate_step(self, event_sequences, model_states, logliks, temperature,
                     extend_control_events_callback=None,
                     modify_events_callback=None):
    """Advances every event sequence by one step and rebuilds model states.

    The sequences are processed in chunks of the model batch size; the last
    chunk is topped up with filler copies whose results are discarded. The
    event sequences are modified in place and also returned, together with the
    model states and log-likelihoods for the next step.

    Args:
      event_sequences: A list of event sequence objects to extend.
      model_states: A list of model states, one per event sequence. Each
          provides `inputs`, `rnn_state`, `control_events` and `control_state`.
      logliks: The current log-likelihood of each event sequence.
      temperature: The softmax temperature.
      extend_control_events_callback: Optional function taking a control event
          sequence, the matching generated event sequence and the control
          state. It should extend the control sequence to be one longer than
          the generated sequence (or do nothing if it is already at least that
          long) and return the resulting control state. When None, inputs are
          built without control sequences.
      modify_events_callback: Optional callback invoked with the
          encoder/decoder, the list of event sequences and the freshly built
          next-step inputs. Can be used to inject events rather than having
          them generated.

    Returns:
      event_sequences: The extended event sequences (same objects that were
          passed in).
      model_states: A list of new `ModelState` objects holding the inputs for
          the next step along with the resulting RNN state and control data
          for each event sequence.
      logliks: A float32 numpy array of updated log-likelihoods, one per event
          sequence.
    """
    model_batch_size = self._batch_size()
    total = len(event_sequences)
    batch_count = int(np.ceil(total / float(model_batch_size)))

    # Pull the per-sequence fields out of the incoming model states.
    inputs = [state.inputs for state in model_states]
    initial_states = [state.rnn_state for state in model_states]
    control_sequences = [state.control_events for state in model_states]
    control_states = [state.control_state for state in model_states]

    final_states = []
    logliks = np.array(logliks, dtype=np.float32)

    # Number of filler entries needed to make the last chunk a full batch.
    filler = -total % model_batch_size
    all_sequences = event_sequences + [
        copy.deepcopy(event_sequences[-1]) for _ in range(filler)]
    all_inputs = inputs + [inputs[-1]] * filler
    all_states = initial_states + [initial_states[-1]] * filler

    for batch_index in range(batch_count):
      start = batch_index * model_batch_size
      stop = start + model_batch_size
      # Only this many leading entries of the chunk are real sequences.
      keep = model_batch_size - max(0, stop - total)
      chunk_state, chunk_loglik = self._generate_step_for_batch(
          all_sequences[start:stop],
          all_inputs[start:stop],
          state_util.batch(all_states[start:stop], model_batch_size),
          temperature)
      final_states += state_util.unbatch(chunk_state, model_batch_size)[:keep]
      logliks[start:start + keep] += chunk_loglik[:keep]

    # Build the inputs that the next step will consume.
    if extend_control_events_callback is None:
      next_inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences)
    else:
      # Conditioning on control sequences: make sure each one is longer than
      # its event sequence before encoding.
      for idx, control_sequence in enumerate(control_sequences):
        control_states[idx] = extend_control_events_callback(
            control_sequence, event_sequences[idx], control_states[idx])
      next_inputs = self._config.encoder_decoder.get_inputs_batch(
          control_sequences, event_sequences)

    if modify_events_callback:
      # Let the caller adjust the sequences and next-step inputs.
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, next_inputs)

    next_model_states = []
    for step_inputs, rnn_state, control_events, control_state in zip(
        next_inputs, final_states, control_sequences, control_states):
      next_model_states.append(
          ModelState(inputs=step_inputs, rnn_state=rnn_state,
                     control_events=control_events,
                     control_state=control_state))

    return event_sequences, next_model_states, logliks