Example 1
def cross_entropy(logits,
                  labels,
                  input_length=None,
                  label_length=None,
                  smoothing=0.0,
                  reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
  '''
  Cross entropy loss for classification and sequence classification.
  :param label_length: for sequence tasks, the length of each target sequence,
      e.g. "a b c </s>" has length 4; used to mask out padded positions
  '''
  del input_length  # input_length is not used by this loss

  # If labels are sparse class ids (one rank lower than logits), expand them
  # to one-hot; otherwise they are assumed to be dense one-hot labels already.
  onehot_labels = tf.cond(
      pred=tf.equal(tf.rank(logits) - tf.rank(labels), 1),
      true_fn=lambda: tf.one_hot(labels, tf.shape(logits)[-1], dtype=tf.int32),
      false_fn=lambda: labels)

  if label_length is not None:
    # Weight each position by a mask built from the target lengths,
    # so padded positions contribute nothing to the loss.
    weights = utils.len_to_mask(label_length)
  else:
    weights = 1.0

  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels,
      logits=logits,
      weights=weights,
      label_smoothing=smoothing,
      reduction=reduction)

  return loss
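
A minimal usage sketch under TF 1.x graph mode follows; the batch size, class count, and random inputs are illustrative assumptions, and label_length is left as None so the project's utils.len_to_mask helper is not needed.

import tensorflow as tf

logits = tf.random_normal([8, 10])                          # [batch_size, num_classes]
labels = tf.random_uniform([8], maxval=10, dtype=tf.int32)  # sparse class ids

# Sparse labels are one rank below logits, so cross_entropy one-hot encodes them.
loss = cross_entropy(logits, labels, smoothing=0.1)

with tf.Session() as sess:
  print(sess.run(loss))  # scalar cross entropy with label smoothing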
Example 2
def mask_sequence_loss(logits,
                       labels,
                       input_length,
                       label_length,
                       smoothing=0.0):
  '''
  Softmax cross entropy loss for sequence-to-sequence models.
  :param logits: [batch_size, max_seq_len, vocab_size]
  :param labels: [batch_size, max_seq_len]
  :param input_length: [batch_size], unused
  :param label_length: [batch_size], target sequence lengths used to mask padding
  :return: loss, scalar
  '''
  # smoothing and input_length are not used by this loss
  del smoothing
  del input_length

  if label_length is not None:
    # Zero out the loss at padded target positions.
    weights = tf.cast(utils.len_to_mask(label_length), tf.float32)
  else:
    # sequence_loss expects float weights, so cast the all-ones default.
    weights = tf.ones_like(labels, dtype=tf.float32)
  loss = tf.contrib.seq2seq.sequence_loss(logits, labels, weights)
  return loss
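
As above, a minimal sketch assuming TF 1.x with tf.contrib available; the shapes and random inputs are illustrative, and passing label_length=None keeps every position weighted equally instead of masking with utils.len_to_mask.

import tensorflow as tf

logits = tf.random_normal([4, 6, 100])                          # [batch_size, max_seq_len, vocab_size]
labels = tf.random_uniform([4, 6], maxval=100, dtype=tf.int32)  # target token ids

loss = mask_sequence_loss(logits, labels, input_length=None, label_length=None)

with tf.Session() as sess:
  print(sess.run(loss))  # scalar loss averaged over timesteps and batch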