import tensorflow as tf
from tensorflow.python.util import nest

# `nest_map` is assumed to be a helper defined elsewhere in this module that
# applies a function to every tensor in a (possibly nested) structure.


def concat_and_gather_tuple_states(pc_states, c_state):
    # Append the current (c, h) state pair to the recorded state history and
    # reorder the beam entries according to `parent_refs`, which is expected
    # to be available in the enclosing scope when this helper is nested
    # inside the decoder step function.
    rc_states = (
        tf.concat(axis=1, values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
        tf.concat(axis=1, values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
    )
    c_states = (
        nest_map(lambda element: tf.gather(element, parent_refs), rc_states[0]),
        nest_map(lambda element: tf.gather(element, parent_refs), rc_states[1])
    )
    return c_states
@classmethod
def _tile_along_beam(cls, beam_size, state):
    if nest.is_sequence(state):
        return nest_map(
            lambda val: cls._tile_along_beam(beam_size, val),
            state
        )
    if not isinstance(state, tf.Tensor):
        raise ValueError("State should be a sequence or tensor")
    tensor = state
    tensor_shape = tensor.get_shape().with_rank_at_least(1)
    try:
        new_first_dim = tensor_shape[0] * beam_size
    except Exception:
        new_first_dim = None
    dynamic_tensor_shape = tf.unstack(tf.shape(tensor))
    # Repeat each batch entry beam_size times along a new axis, then fold the
    # beam axis into the batch axis: [batch, ...] -> [batch * beam_size, ...].
    res = tf.expand_dims(tensor, 1)
    res = tf.tile(res, [1, beam_size] + [1] * (tensor_shape.ndims - 1))
    res = tf.reshape(res, [-1] + list(dynamic_tensor_shape[1:]))
    res.set_shape([new_first_dim] + list(tensor_shape[1:]))
    return res
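# Illustrative sketch (not part of the original module): assuming beam_size=3
# and a state tensor with rows [r0, r1], _tile_along_beam produces
# [r0, r0, r0, r1, r1, r1], i.e. each batch entry is repeated beam_size times
# consecutively so that the hypotheses of the same example stay adjacent in
# the folded [batch_size * beam_size, ...] layout used below.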
def _create_state(self, batch_size, dtype, cell_state=None):
    cand_symbols = tf.fill([batch_size, 1],
                           tf.constant(self.start_token, dtype=tf.int32))
    cand_logprobs = tf.ones((batch_size,), dtype=tf.float32) * -float('inf')

    if cell_state is None:
        cell_state = self.cell.zero_state(batch_size * self.beam_size,
                                          dtype=dtype)
    else:
        cell_state = BeamDecoder._tile_along_beam(self.beam_size, cell_state)

    full_size = batch_size * self.beam_size
    first_in_beam_mask = tf.equal(tf.range(full_size) % self.beam_size, 0)

    beam_symbols = tf.fill([full_size, 1],
                           tf.constant(self.start_token, dtype=tf.int32))
    beam_logprobs = tf.where(
        first_in_beam_mask,
        tf.fill([full_size], 0.0),
        tf.fill([full_size], -1e18),  # top_k does not play well with -inf
                                      # TODO: dtype-dependent value here
    )

    return (
        cand_symbols,
        cand_logprobs,
        beam_symbols,
        beam_logprobs,
        nest_map(lambda element: tf.expand_dims(element, 1), cell_state)
    )
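# Illustrative sketch (not part of the original module): with batch_size=2 and
# beam_size=3, first_in_beam_mask is [1, 0, 0, 1, 0, 0] and the initial
# beam_logprobs are [0, -1e18, -1e18, 0, -1e18, -1e18], so only the first
# hypothesis of each example is live at step 0 and top_k cannot select the
# duplicated start-token rows.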
def __call__(self, cell_inputs, state, scope=None):
    (
        past_beam_symbols,   # [batch_size*self.beam_size, :], right-aligned!!!
        past_beam_logprobs,  # [batch_size*self.beam_size]
        past_cell_states     # LSTM: ([batch_size*self.beam_size, :, dim],
                             #        [batch_size*self.beam_size, :, dim])
                             # GRU: [batch_size*self.beam_size, :, dim]
    ) = state
    past_cell_state = self.get_last_cell_state(past_cell_states)

    if self.use_copy and self.copy_fun == 'copynet':
        cell_output, cell_state, alignments, attns = \
            self.cell(cell_inputs, past_cell_state, scope)
    elif self.use_attention:
        cell_output, cell_state, alignments, attns = \
            self.cell(cell_inputs, past_cell_state, scope)
    else:
        cell_output, cell_state = \
            self.cell(cell_inputs, past_cell_state, scope)

    # [batch_size*beam_size, num_classes]
    if self.use_copy and self.copy_fun == 'copynet':
        logprobs = tf.log(cell_output)
    else:
        W, b = self.output_project
        if self.locally_normalized:
            logprobs = tf.nn.log_softmax(tf.matmul(cell_output, W) + b)
        else:
            logprobs = tf.matmul(cell_output, W) + b
    num_classes = logprobs.get_shape()[1].value

    # stop_mask: indicates partial sequences ending with a stop token
    # [batch_size * beam_size]
    #   x      0
    #   _STOP  1
    #   x      0
    #   x      0
    input_symbols = past_beam_symbols[:, -1]
    stop_mask = tf.expand_dims(
        tf.cast(tf.equal(input_symbols, self.stop_token), tf.float32), 1)

    # done_mask: indicates stop token in the output vocabulary
    # [1, num_classes]
    #   [-  -  _STOP  -  -  -]
    #   [0  0  1      0  0  0]
    done_mask = tf.cast(
        tf.reshape(tf.equal(tf.range(num_classes), self.stop_token),
                   [1, num_classes]), tf.float32)

    # set the next token distribution of partial sequences ending with
    # a stop token to:
    #   [-     -     _STOP  -     -     -   ]
    #   [-inf  -inf  0      -inf  -inf  -inf]
    logprobs = tf.add(logprobs, tf.multiply(
        stop_mask, -1e18 * (tf.ones_like(done_mask) - done_mask)))
    logprobs = tf.multiply(logprobs, (1 - tf.multiply(stop_mask, done_mask)))

    # length normalization
    past_logprobs_unormalized = \
        tf.multiply(past_beam_logprobs, tf.pow(self.seq_len, self.alpha))
    logprobs_unormalized = \
        tf.expand_dims(past_logprobs_unormalized, 1) + logprobs
    seq_len = tf.expand_dims(self.seq_len, 1) + (1 - stop_mask)
    logprobs_batched = tf.div(logprobs_unormalized,
                              tf.pow(seq_len, self.alpha))

    beam_logprobs, indices = tf.nn.top_k(
        tf.reshape(logprobs_batched, [-1, self.beam_size * num_classes]),
        self.beam_size
    )
    beam_logprobs = tf.reshape(beam_logprobs, [-1])

    # For continuing to the next symbols
    parent_refs_offsets = \
        (tf.range(self.full_size) // self.beam_size) * self.beam_size
    symbols = indices % num_classes  # [batch_size, self.beam_size]
    # [batch_size*self.beam_size]
    parent_refs = tf.reshape(indices // num_classes, [-1])
    parent_refs = parent_refs + parent_refs_offsets

    beam_symbols = tf.concat(
        axis=1, values=[tf.gather(past_beam_symbols, parent_refs),
                        tf.reshape(symbols, [-1, 1])])
    self.seq_len = tf.squeeze(tf.gather(seq_len, parent_refs), axis=[1])

    if self.use_attention:
        ranked_alignments = nest_map(
            lambda element: tf.gather(element, parent_refs), alignments)
        ranked_attns = nest_map(
            lambda element: tf.gather(element, parent_refs), attns)

    # update cell_states
    def concat_and_gather_tuple_states(pc_states, c_state):
        rc_states = (
            tf.concat(axis=1,
                      values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
            tf.concat(axis=1,
                      values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
        )
        c_states = (
            nest_map(lambda element: tf.gather(element, parent_refs),
                     rc_states[0]),
            nest_map(lambda element: tf.gather(element, parent_refs),
                     rc_states[1])
        )
        return c_states

    if nest.is_sequence(cell_state):
        if self.num_layers > 1:
            ranked_cell_states = [
                concat_and_gather_tuple_states(pc_states, c_state)
                for pc_states, c_state in zip(past_cell_states, cell_state)
            ]
        else:
            ranked_cell_states = concat_and_gather_tuple_states(
                past_cell_states, cell_state)
    else:
        ranked_cell_states = tf.gather(
            tf.concat(axis=1,
                      values=[past_cell_states,
                              tf.expand_dims(cell_state, 1)]),
            parent_refs)

    compound_cell_state = (
        beam_symbols,
        beam_logprobs,
        ranked_cell_states
    )

    ranked_cell_output = tf.gather(cell_output, parent_refs)

    if self.use_copy and self.copy_fun == 'copynet':
        return ranked_cell_output, compound_cell_state, ranked_alignments, \
            ranked_attns
    elif self.use_attention:
        return ranked_cell_output, compound_cell_state, ranked_alignments, \
            ranked_attns
    else:
        return ranked_cell_output, compound_cell_state
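# Illustrative sketch (not part of the original module): top_k is run over the
# flattened [batch_size, beam_size * num_classes] scores, so a flat index i
# decomposes into symbol = i % num_classes and parent beam = i // num_classes.
# parent_refs_offsets then shifts the per-example parent indices into the
# folded [batch_size * beam_size] layout; e.g. for batch_size=2, beam_size=3
# the offsets are [0, 0, 0, 3, 3, 3].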
def __call__(self, inputs, state, scope=None):
    (
        past_cand_symbols,   # [batch_size, :]
        past_cand_logprobs,  # [batch_size]
        past_beam_symbols,   # [batch_size*self.beam_size, :], right-aligned!!!
        past_beam_logprobs,  # [batch_size*self.beam_size]
        past_cell_states,    # LSTM: ([batch_size*self.beam_size, :, dim],
                             #        [batch_size*self.beam_size, :, dim])
                             # GRU: [batch_size*self.beam_size, :, dim]
    ) = state
    batch_size = past_cand_symbols.get_shape()[0].value
    full_size = batch_size * self.beam_size
    if self.parent_refs_offsets is None:
        self.parent_refs_offsets = \
            (tf.range(full_size) // self.beam_size) * self.beam_size

    input_symbols = past_beam_symbols[:, -1]
    # [batch_size * beam_size]
    #   -      0
    #   _STOP  1
    #   -      0
    #   -      0
    stop_mask = tf.cast(tf.equal(input_symbols, self.stop_token), tf.float32)

    cell_inputs = inputs

    past_cell_state = self.get_last_cell_state(past_cell_states)
    if self.use_copy and self.copy_fun != 'supervised':
        cell_output, cell_state, alignments, attns, read_copy_source = \
            self.cell(cell_inputs, past_cell_state, scope)
    elif self.use_attention:
        cell_output, cell_state, alignments, attns = \
            self.cell(cell_inputs, past_cell_state, scope)
    else:
        cell_output, cell_state = \
            self.cell(cell_inputs, past_cell_state, scope)

    # [batch_size*beam_size, num_classes]
    if self.use_copy and self.copy_fun == 'copynet':
        logprobs = tf.log(cell_output)
    else:
        W, b = self.output_project
        if self.locally_normalized:
            logprobs = tf.nn.log_softmax(tf.matmul(cell_output, W) + b)
        else:
            logprobs = tf.matmul(cell_output, W) + b

    # set the probabilities of all other symbols following the stop symbol
    # to a very small number
    stop_mask_2d = tf.expand_dims(stop_mask, 1)
    #   [-     -     _STOP  -     -     -   ]
    #   [0     0     0      0     0     0   ]
    #   [-100  -100  0      -100  -100  -100]
    #   [0     0     0      0     0     0   ]
    #   [0     0     0      0     0     0   ]
    done_only_mask = tf.multiply(stop_mask_2d, self._done_mask)
    #   [-  -  _STOP  -  -  -]
    #   [1  1  1      1  1  1]
    #   [1  1  0      1  1  1]
    #   [1  1  1      1  1  1]
    #   [1  1  1      1  1  1]
    zero_done_mask = tf.ones([full_size, self.num_classes]) - \
        tf.multiply(stop_mask_2d,
                    tf.cast(tf.equal(self._done_mask, 0), tf.float32))
    logprobs = tf.add(logprobs, done_only_mask)
    logprobs = tf.multiply(logprobs, zero_done_mask)

    # length normalization
    past_beam_acc_logprobs = \
        tf.multiply(past_beam_logprobs, tf.pow(self.seq_len, self.alpha))
    logprobs_batched = tf.expand_dims(past_beam_acc_logprobs, 1) + logprobs
    float_done_mask = tf.cast(tf.not_equal(self._done_mask, 0), tf.float32)
    seq_len = tf.expand_dims(self.seq_len, 1) + \
        tf.multiply(tf.ones([full_size, 1]) - stop_mask_2d, float_done_mask)
    logprobs_batched = tf.div(logprobs_batched, tf.pow(seq_len, self.alpha))
    logprobs_batched = \
        tf.reshape(logprobs_batched, [-1, self.beam_size * self.num_classes])

    beam_logprobs, indices = tf.nn.top_k(
        # TODO: make sure it's reasonable to remove nondone mask
        tf.reshape(logprobs_batched, [-1, self.beam_size * self.num_classes]),
        self.beam_size)
    beam_logprobs = tf.reshape(beam_logprobs, [-1])

    # For continuing to the next symbols
    symbols = indices % self.num_classes  # [batch_size, self.beam_size]
    # [batch_size*self.beam_size]
    parent_refs = tf.reshape(indices // self.num_classes, [-1])
    parent_refs = parent_refs + self.parent_refs_offsets

    beam_symbols = tf.concat(axis=1, values=[
        tf.gather(past_beam_symbols, parent_refs),
        tf.reshape(symbols, [-1, 1])
    ])
    self.seq_len = tf.gather(self.seq_len, parent_refs) + \
        tf.cast(tf.not_equal(tf.reshape(symbols, [-1]), self.stop_token),
                tf.float32)

    if self.use_copy and self.copy_fun != 'supervised':
        ranked_read_copy_source = tf.gather(read_copy_source, parent_refs)
    if self.use_attention:
        ranked_alignments = nest_map(
            lambda element: tf.gather(element, parent_refs), alignments)
        ranked_attns = nest_map(
            lambda element: tf.gather(element, parent_refs), attns)

    # update cell_states
    def concat_and_gather_tuple_states(pc_states, c_state):
        rc_states = (
            tf.concat(axis=1,
                      values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
            tf.concat(axis=1,
                      values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
        )
        c_states = (
            nest_map(lambda element: tf.gather(element, parent_refs),
                     rc_states[0]),
            nest_map(lambda element: tf.gather(element, parent_refs),
                     rc_states[1])
        )
        return c_states

    if nest.is_sequence(cell_state):
        if self.num_layers > 1:
            ranked_cell_states = [
                concat_and_gather_tuple_states(pc_states, c_state)
                for pc_states, c_state in zip(past_cell_states, cell_state)
            ]
        else:
            ranked_cell_states = concat_and_gather_tuple_states(
                past_cell_states, cell_state)
    else:
        ranked_cell_states = tf.gather(
            tf.concat(axis=1,
                      values=[past_cell_states,
                              tf.expand_dims(cell_state, 1)]),
            parent_refs)

    # Handling for getting a done token
    logprobs_batched_3D = tf.reshape(
        logprobs_batched, [-1, self.beam_size, self.num_classes])
    logprobs_done = logprobs_batched_3D[:, :, self.stop_token]
    done_parent_refs = tf.to_int32(tf.argmax(logprobs_done, 1))
    done_parent_refs_offsets = tf.range(batch_size) * self.beam_size
    done_symbols = tf.gather(past_beam_symbols[:, -1:],
                             done_parent_refs + done_parent_refs_offsets)
    logprobs_done_max = tf.reduce_max(logprobs_done, 1)
    cand_symbols = tf.where(logprobs_done_max > past_cand_logprobs,
                            done_symbols, past_cand_symbols)
    cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

    compound_cell_state = (
        cand_symbols,
        cand_logprobs,
        beam_symbols,
        beam_logprobs,
        ranked_cell_states
    )

    ranked_cell_output = tf.gather(cell_output, parent_refs)

    if self.use_copy and self.copy_fun == 'copynet':
        return ranked_cell_output, compound_cell_state, ranked_alignments, \
            ranked_attns, ranked_read_copy_source
    elif self.use_attention:
        return ranked_cell_output, compound_cell_state, ranked_alignments, \
            ranked_attns
    else:
        return ranked_cell_output, compound_cell_state
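# Illustrative sketch (not part of the original module) of the length
# normalization used in both step functions, with hypothetical numbers: stored
# beam scores are log-probability sums divided by seq_len ** alpha, so the
# running sum is recovered by multiplying back before adding new logprobs.
if __name__ == '__main__':
    alpha = 1.0
    past_score, past_len = -1.2, 3.0        # normalized score, current length
    next_token_logprob = -0.4
    past_sum = past_score * past_len ** alpha        # un-normalize
    new_len = past_len + 1.0                         # sequence grows by one
    new_score = (past_sum + next_token_logprob) / new_len ** alpha
    print(new_score)                                 # -1.0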