  def testBranchingSingleBeamEntry(self):
    sequence, state, score = beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=32,
        steps_per_iteration=1)

    # Here the beam search should greedily choose ones.
    self.assertEqual(sequence, [1, 1, 1, 1, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 5)
  def testNoBranchingMultipleBeamEntries(self):
    sequence, state, score = beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=32,
        branch_factor=1,
        steps_per_iteration=1)

    # Here the beam has enough capacity to find the optimal solution without
    # branching.
    self.assertEqual(sequence, [0, 0, 0, 0, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 16)
  def testNoBranchingMultipleStepsPerIteration(self):
    sequence, state, score = beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=1,
        steps_per_iteration=2)

    # Only a single sequence is ever considered, so the counter never reaches
    # one and the generator emits all zeros.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
  def testNoBranchingSingleStepPerIteration(self):
    sequence, state, score = beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=1,
        steps_per_iteration=1)

    # The generator should emit all zeros, as only a single sequence is ever
    # considered so the counter doesn't reach one.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
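  # The tests above rely on a `_generate_step_fn` fixture defined elsewhere
  # on this test class. As a rough, hypothetical sketch of the semantics the
  # test comments describe (not the actual fixture): appending a 0 "banks"
  # value by doubling the state, appending a 1 "cashes in" the state as score
  # and resets it, and events alternate across the batch so that branching
  # can explore both continuations. The single-sequence cases above are
  # consistent with this: a lone sequence always receives 0s, ending after
  # five steps with state 32 and score 0.
  def _sketch_generate_step_fn(self, sequences, states, scores):
    """Hypothetical binary-counter step function, for illustration only."""
    for i in range(len(sequences)):
      event = i % 2  # alternate 0 and 1 across the branched copies
      sequences[i].append(event)
      if event == 0:
        states[i] *= 2           # bank value exponentially in the state
      else:
        scores[i] += states[i]   # cash in the banked value as score
        states[i] = 1
    return sequences, states, scores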
  def _generate_events(self, num_steps, primer_events, temperature=1.0,
                       beam_size=1, branch_factor=1, steps_per_iteration=1,
                       control_events=None, control_state=None,
                       extend_control_events_callback=(
                           _extend_control_events_default),
                       modify_events_callback=None):
    """Generate an event sequence from a primer sequence.

    Args:
      num_steps: The integer length in steps of the final event sequence,
          after generation. Includes the primer.
      primer_events: The primer event sequence, a Python list-like object.
      temperature: A float specifying how much to divide the logits by before
          computing the softmax. Greater than 1.0 makes events more random,
          less than 1.0 makes events less random.
      beam_size: An integer, beam size to use when generating event sequences
          via beam search.
      branch_factor: An integer, beam search branch factor to use.
      steps_per_iteration: An integer, number of steps to take per beam search
          iteration.
      control_events: A sequence of control events upon which to condition the
          generation. If not None, the encoder/decoder should be a
          ConditionalEventSequenceEncoderDecoder, and the control events will
          be used along with the target sequence to generate model inputs. In
          some cases, the control event sequence cannot be fully determined as
          later control events depend on earlier generated events; use the
          `extend_control_events_callback` argument to provide a function that
          extends the control event sequence.
      control_state: Initial state used by `extend_control_events_callback`.
      extend_control_events_callback: A function that takes three arguments: a
          current control event sequence, a current generated event sequence,
          and the control state. The function should a) extend the control
          event sequence to be one longer than the generated event sequence
          (or do nothing if it is already at least this long), and b) return
          the resulting control state.
      modify_events_callback: An optional callback for modifying the event
          list. Can be used to inject events rather than having them
          generated. If not None, will be called with 3 arguments after every
          event: the current EventSequenceEncoderDecoder, a list of current
          EventSequences, and a list of current encoded event inputs.

    Returns:
      The generated event sequence (which begins with the provided primer).

    Raises:
      EventSequenceRnnModelError: If the primer sequence has zero length or
          is not shorter than num_steps.
    """
    if (control_events is not None and
        not isinstance(self._config.encoder_decoder,
                       mm.ConditionalEventSequenceEncoderDecoder)):
      raise EventSequenceRnnModelError(
          'control sequence provided but encoder/decoder is not a '
          'ConditionalEventSequenceEncoderDecoder')
    if control_events is not None and extend_control_events_callback is None:
      raise EventSequenceRnnModelError(
          'must provide callback for extending control sequence (or use '
          'default)')

    if not primer_events:
      raise EventSequenceRnnModelError(
          'primer sequence must have non-zero length')
    if len(primer_events) >= num_steps:
      raise EventSequenceRnnModelError(
          'primer sequence must be shorter than `num_steps`')

    event_sequences = [copy.deepcopy(primer_events)]

    # Construct inputs for first step after primer.
    if control_events is not None:
      # We are conditioning on a control sequence. Make sure it is longer than
      # the primer sequence.
      control_state = extend_control_events_callback(
          control_events, primer_events, control_state)
      inputs = self._config.encoder_decoder.get_inputs_batch(
          [control_events], event_sequences, full_length=True)
    else:
      inputs = self._config.encoder_decoder.get_inputs_batch(
          event_sequences, full_length=True)

    if modify_events_callback:
      # Modify event sequences and inputs for first step after primer.
      modify_events_callback(
          self._config.encoder_decoder, event_sequences, inputs)

    graph_initial_state = self._session.graph.get_collection('initial_state')
    initial_states = state_util.unbatch(
        self._session.run(graph_initial_state))

    # Beam search will maintain a state for each sequence consisting of the
    # next inputs to feed the model, and the current RNN state. We start out
    # with the initial full inputs batch and the zero state.
    initial_state = ModelState(
        inputs=inputs[0], rnn_state=initial_states[0],
        control_events=control_events, control_state=control_state)

    generate_step_fn = functools.partial(
        self._generate_step,
        temperature=temperature,
        extend_control_events_callback=(
            extend_control_events_callback
            if control_events is not None else None),
        modify_events_callback=modify_events_callback)

    events, _, loglik = beam_search(
        initial_sequence=event_sequences[0],
        initial_state=initial_state,
        generate_step_fn=generate_step_fn,
        num_steps=num_steps - len(primer_events),
        beam_size=beam_size,
        branch_factor=branch_factor,
        steps_per_iteration=steps_per_iteration)

    tf.logging.info('Beam search yields sequence with log-likelihood: %f ',
                    loglik)

    return events
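  # For illustration, a callback satisfying the
  # `extend_control_events_callback` contract documented above could simply
  # repeat the final control event until the control sequence is one longer
  # than the generated sequence, passing the control state through unchanged.
  # This is a hedged sketch of the contract with a hypothetical name, not
  # necessarily how `_extend_control_events_default` is implemented.
  @staticmethod
  def _example_extend_control_events(control_events, events, control_state):
    while len(control_events) <= len(events):
      control_events.append(control_events[-1])
    return control_state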