def _apply_gradients(grads_and_vars, global_step=None):
  # Note: `optimizer` and `_is_bias` come from the enclosing scope.
  # Average gradients across replicas when the optimizer exposes an
  # allreduce (e.g. a Horovod-wrapped optimizer).
  if hasattr(optimizer, "allreduce_gradients"):
    grads_and_vars = optimizer.allreduce_gradients(grads_and_vars)
  kwargs = {}
  # Optimizers with decoupled weight decay accept a decay_var_list;
  # exclude biases from decay, as is conventional.
  if "decay_var_list" in misc.function_args(optimizer.apply_gradients):
    kwargs["decay_var_list"] = [
        var for _, var in grads_and_vars if not _is_bias(var)]
  return optimizer.apply_gradients(grads_and_vars, global_step=global_step, **kwargs)
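# Hedged usage sketch (illustrative, not part of the original module): since
# `_apply_gradients` reads `optimizer` from the enclosing scope, it is
# presumably nested in an optimization helper. Assuming an optimizer with
# decoupled weight decay, the wiring might look like this (`loss` and
# `global_step` are placeholder names):
#
#   optimizer = tf.contrib.opt.AdamWOptimizer(weight_decay=0.01, learning_rate=1e-3)
#   grads_and_vars = optimizer.compute_gradients(loss)
#   train_op = _apply_gradients(grads_and_vars, global_step=global_step)
#
# The `decay_var_list` branch only fires for optimizers whose
# `apply_gradients` accepts that argument; plain optimizers such as
# tf.train.AdamOptimizer go through the default path unchanged.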
def dynamic_decode(symbols_to_logits_fn,
                   start_ids,
                   end_id=constants.END_OF_SENTENCE_ID,
                   initial_state=None,
                   decoding_strategy=None,
                   sampler=None,
                   maximum_iterations=None,
                   minimum_iterations=0,
                   attention_history=False):
  """Dynamic decoding.

  Args:
    symbols_to_logits_fn: A callable taking ``(symbols, step, state)`` and
      returning ``(logits, state, attention)`` (``attention`` is optional).
    start_ids: Initial input IDs of shape :math:`[B]`.
    end_id: ID of the end of sequence token.
    initial_state: Initial decoder state.
    decoding_strategy: A :class:`opennmt.utils.decoding.DecodingStrategy`
      instance that defines the decoding logic. Defaults to a greedy search.
    sampler: A :class:`opennmt.utils.decoding.Sampler` instance that samples
      predictions from the model output. Defaults to an argmax sampling.
    maximum_iterations: The maximum number of iterations to decode for.
    minimum_iterations: The minimum number of iterations to decode for.
    attention_history: Gather attention history during the decoding.

  Returns:
    ids: The predicted ids of shape :math:`[B, H, T]`.
    lengths: The produced sequence lengths of shape :math:`[B, H]`.
    log_probs: The cumulated log probabilities of shape :math:`[B, H]`.
    attention_history: The attention history of shape :math:`[B, H, T_t, T_s]`.
    state: The final decoding state.
  """
  if "maximum_iterations" not in misc.function_args(tf.while_loop):
    raise NotImplementedError(
        "Unified decoding does not support TensorFlow 1.4. "
        "Please update your TensorFlow installation or open "
        "an issue for assistance.")
  if decoding_strategy is None:
    decoding_strategy = GreedySearch()
  if sampler is None:
    sampler = BestSampler()

  def _cond(step, finished, state, inputs, outputs, attention, cum_log_probs,
            extra_vars):  # pylint: disable=unused-argument
    return tf.reduce_any(tf.logical_not(finished))

  def _body(step, finished, state, inputs, outputs, attention, cum_log_probs,
            extra_vars):
    # Get log probs from the model.
    result = symbols_to_logits_fn(inputs, step, state)
    logits, state = result[0], result[1]
    attn = result[2] if len(result) > 2 else None
    logits = tf.cast(logits, tf.float32)

    # Penalize EOS until minimum_iterations is reached, then force EOS for
    # finished hypotheses.
    batch_size, vocab_size = misc.shape_list(logits)
    eos_max_prob = tf.one_hot(
        tf.fill([batch_size], end_id),
        vocab_size,
        on_value=logits.dtype.max,
        off_value=logits.dtype.min)
    logits = tf.cond(
        step < minimum_iterations,
        true_fn=lambda: _penalize_token(logits, end_id),
        false_fn=lambda: tf.where(
            tf.tile(tf.expand_dims(finished, 1), [1, vocab_size]),
            x=eos_max_prob,
            y=logits))
    log_probs = tf.nn.log_softmax(logits)

    # Run one decoding strategy step.
    output, next_cum_log_probs, finished, state, extra_vars = decoding_strategy.step(
        step,
        sampler,
        log_probs,
        cum_log_probs,
        finished,
        state,
        extra_vars)

    # Update loop vars.
    if attention_history:
      if attn is None:
        raise ValueError("attention_history is set but the model did not return attention")
      attention = attention.write(step, tf.cast(attn, tf.float32))
    outputs = outputs.write(step, output)
    cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs)
    finished = tf.logical_or(finished, tf.equal(output, end_id))
    # The sampled ids become the next step's inputs.
    return step + 1, finished, state, output, outputs, attention, cum_log_probs, extra_vars

  batch_size = tf.shape(start_ids)[0]
  ids_dtype = start_ids.dtype
  start_ids = tf.cast(start_ids, tf.int32)
  start_ids, finished, initial_log_probs, extra_vars = decoding_strategy.initialize(
      batch_size, start_ids)
  step = tf.constant(0, dtype=tf.int32)
  outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  _, _, state, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
      _cond,
      _body,
      loop_vars=(
          step,
          finished,
          initial_state,
          start_ids,
          outputs,
          attention,
          initial_log_probs,
          extra_vars),
      shape_invariants=(
          step.shape,
          finished.shape,
          compat.nest.map_structure(_get_shape_invariants, initial_state),
          start_ids.shape,
          tf.TensorShape(None),
          tf.TensorShape(None),
          initial_log_probs.shape,
          compat.nest.map_structure(_get_shape_invariants, extra_vars)),
      parallel_iterations=1,
      back_prop=False,
      maximum_iterations=maximum_iterations)

  ids, attention, lengths = decoding_strategy.finalize(
      outputs,
      end_id,
      extra_vars,
      attention=attention if attention_history else None)
  if attention is not None:
    attention = attention[:, :, 1:]  # Ignore attention for <s>.
  log_probs = tf.reshape(log_probs, [batch_size, decoding_strategy.num_hypotheses])
  ids = tf.cast(ids, ids_dtype)
  return ids, lengths, log_probs, attention, state
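# Hedged usage sketch (illustrative, not part of the original module): one
# way to drive `dynamic_decode` with a toy step function. `embedding`
# ([vocab, depth]) and `proj` ([depth, vocab]) are assumed inputs; a real
# model would run its decoder inside `_symbols_to_logits_fn` instead.
def _example_greedy_decode(embedding, proj, batch_size):
  """Builds a greedy decode over a toy embed-and-project step."""
  def _symbols_to_logits_fn(ids, step, state):  # pylint: disable=unused-argument
    inputs = tf.nn.embedding_lookup(embedding, ids)  # [B, depth]
    logits = tf.matmul(inputs, proj)  # [B, vocab]
    return logits, state
  start_ids = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
  return dynamic_decode(
      _symbols_to_logits_fn,
      start_ids,
      initial_state=(),  # this toy step is stateless
      maximum_iterations=50)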
def dynamic_decode(symbols_to_logits_fn_trans,
                   symbols_to_logits_fn_ae,
                   trans_model_name,
                   ae_model_name,
                   start_ids,
                   end_id=constants.END_OF_SENTENCE_ID,
                   initial_state_trans=None,
                   initial_state_ae=None,
                   decoding_strategy=None,
                   sampler=None,
                   maximum_iterations=None,
                   minimum_iterations=0,
                   attention_history=False,
                   attention_size=None,
                   low_prob=None):
  """Dynamic decoding over two models whose logits are interpolated at each step.

  Args:
    symbols_to_logits_fn_trans: A callable taking ``(symbols, step, state)``
      and returning ``(logits, state, attention)`` for the translation model
      (``attention`` is optional).
    symbols_to_logits_fn_ae: The same kind of callable for the autoencoder
      model.
    trans_model_name: Variable scope name of the translation model.
    ae_model_name: Variable scope name of the autoencoder model.
    start_ids: Initial input IDs of shape :math:`[B]`.
    end_id: ID of the end of sequence token.
    initial_state_trans: Initial decoder state of the translation model.
    initial_state_ae: Initial decoder state of the autoencoder model.
    decoding_strategy: A :class:`opennmt.utils.decoding.DecodingStrategy`
      instance that defines the decoding logic. Defaults to a greedy search.
    sampler: A :class:`opennmt.utils.decoding.Sampler` instance that samples
      predictions from the model output. Defaults to an argmax sampling.
    maximum_iterations: The maximum number of iterations to decode for.
    minimum_iterations: The minimum number of iterations to decode for.
    attention_history: Gather attention history during the decoding.
    attention_size: If known, the size of the attention vectors (i.e. the
      maximum source length).
    low_prob: If set, lower each logit proportionally to its magnitude
      before sampling.

  Returns:
    ids: The predicted ids of shape :math:`[B, H, T]`.
    lengths: The produced sequence lengths of shape :math:`[B, H]`.
  """
  if "maximum_iterations" not in misc.function_args(tf.while_loop):
    raise NotImplementedError(
        "Unified decoding does not support TensorFlow 1.4. "
        "Please update your TensorFlow installation or open "
        "an issue for assistance.")
  if decoding_strategy is None:
    decoding_strategy = GreedySearch()
  if sampler is None:
    sampler = BestSampler()

  def _cond(step, finished, state_trans, state_ae, inputs, outputs, attention,
            cum_log_probs, extra_vars):  # pylint: disable=unused-argument
    return tf.reduce_any(tf.logical_not(finished))

  def _body(step, finished, state_trans, state_ae, inputs, outputs, attention,
            cum_log_probs, extra_vars):
    # Get log probs from both models, each under its own variable scope.
    with tf.variable_scope(trans_model_name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        result_trans = symbols_to_logits_fn_trans(inputs, step, state_trans)
    with tf.variable_scope(ae_model_name, reuse=tf.AUTO_REUSE):
      with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        result_ae = symbols_to_logits_fn_ae(inputs, step, state_ae)
    logits_trans, state_trans = result_trans[0], result_trans[1]
    logits_ae, state_ae = result_ae[0], result_ae[1]
    logits_trans = tf.cast(logits_trans, tf.float32)
    logits_ae = tf.cast(logits_ae, tf.float32)

    # Interpolate the two models' logits with fixed weights.
    logits = 1 * logits_trans + 0.5 * logits_ae
    if low_prob is not None:
      # Lower each logit proportionally to its magnitude.
      logits = logits - tf.abs(logits * low_prob)

    # Force EOS for finished hypotheses.
    batch_size, vocab_size = misc.shape_list(logits_trans)
    eos_max_prob = tf.one_hot(
        tf.fill([batch_size], end_id),
        vocab_size,
        on_value=logits_trans.dtype.max,
        off_value=logits_trans.dtype.min)
    logits = tf.where(finished, x=eos_max_prob, y=logits)
    log_probs = tf.nn.log_softmax(logits)
    # Alternative kept for reference: penalize EOS before minimum_iterations
    # and interpolate normalized log-probabilities instead of raw logits.
    #logits_trans = tf.cond(
    #    step < minimum_iterations,
    #    true_fn=lambda: _penalize_token(logits_trans, end_id),
    #    false_fn=lambda: tf.where(finished, x=eos_max_prob, y=logits_trans))
    #logits_ae = tf.cond(
    #    step < minimum_iterations,
    #    true_fn=lambda: _penalize_token(logits_ae, end_id),
    #    false_fn=lambda: tf.where(finished, x=eos_max_prob, y=logits_ae))
    #log_probs_trans = tf.nn.log_softmax(logits_trans)
    #log_probs_ae = tf.nn.log_softmax(logits_ae)
    #log_probs = 1 * log_probs_trans + 0.2 * log_probs_ae

    # Run one decoding strategy step.
    output, next_cum_log_probs, finished, state_trans, state_ae, extra_vars = decoding_strategy.step(
        step,
        sampler,
        log_probs,
        cum_log_probs,
        finished,
        state_trans,
        state_ae,
        extra_vars,
        attention=None)

    # Update loop vars.
    outputs = outputs.write(step, output)
    cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs)
    finished = tf.logical_or(finished, tf.equal(output, end_id))
    return step + 1, finished, state_trans, state_ae, output, outputs, attention, cum_log_probs, extra_vars

  batch_size = tf.shape(start_ids)[0]
  ids_dtype = start_ids.dtype
  start_ids = tf.cast(start_ids, tf.int32)
  start_ids, finished, initial_log_probs, extra_vars = decoding_strategy.initialize(
      batch_size, start_ids, attention_size=attention_size)
  step = tf.constant(0, dtype=tf.int32)
  outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

  _, _, state_trans, state_ae, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
      _cond,
      _body,
      loop_vars=(
          step,
          finished,
          initial_state_trans,
          initial_state_ae,
          start_ids,
          outputs,
          attention,
          initial_log_probs,
          extra_vars),
      shape_invariants=(
          step.shape,
          finished.shape,
          compat.nest.map_structure(_get_shape_invariants, initial_state_trans),
          compat.nest.map_structure(_get_shape_invariants, initial_state_ae),
          start_ids.shape,
          tf.TensorShape(None),
          tf.TensorShape(None),
          initial_log_probs.shape,
          compat.nest.map_structure(_get_shape_invariants, extra_vars)),
      parallel_iterations=1,
      back_prop=False,
      maximum_iterations=maximum_iterations)

  ids, attention, lengths = decoding_strategy.finalize(
      outputs,
      end_id,
      extra_vars,
      attention=attention if attention_history else None)
  if attention is not None:
    attention = attention[:, :, :-1]  # Ignore attention for </s>.
  log_probs = tf.reshape(log_probs, [batch_size, decoding_strategy.num_hypotheses])
  ids = tf.cast(ids, ids_dtype)
  return ids, lengths
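# Hedged usage sketch (illustrative): invoking the two-model variant. The
# scope names must match those under which the two decoders' variables were
# created; every other name below is a placeholder.
#
#   ids, lengths = dynamic_decode(
#       symbols_to_logits_fn_trans,
#       symbols_to_logits_fn_ae,
#       trans_model_name="transformer",
#       ae_model_name="autoencoder",
#       start_ids=start_ids,
#       initial_state_trans=initial_state_trans,
#       initial_state_ae=initial_state_ae,
#       maximum_iterations=200,
#       low_prob=0.1)
#
# Note that the interpolation weights (1.0 for the translation model, 0.5
# for the autoencoder) are hard-coded in `_body`; the commented-out block
# there records the alternative of mixing normalized log-probabilities.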