import tensorflow as tf

from . import utils  # assumed package-local helper module providing len_to_mask()


def cross_entropy(logits,
                  labels,
                  input_length=None,
                  label_length=None,
                  smoothing=0.0,
                  reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
  ''' Cross-entropy loss for classification and sequence classification.

  :param label_length: for sequence tasks, the target sequence length,
      e.g. `a b c </s>` has length 4.
  '''
  del input_length

  # If labels have one less rank than logits they are sparse class ids and
  # need one-hot encoding; otherwise they are already dense/one-hot.
  # Both branches are cast to float32 so tf.cond sees matching dtypes.
  onehot_labels = tf.cond(
      pred=tf.equal(tf.rank(logits) - tf.rank(labels), 1),
      true_fn=lambda: tf.one_hot(
          tf.cast(labels, tf.int32), tf.shape(logits)[-1], dtype=tf.float32),
      false_fn=lambda: tf.cast(labels, tf.float32))

  if label_length is not None:
    # Mask out padded positions of the target sequence.
    weights = tf.cast(utils.len_to_mask(label_length), tf.float32)
  else:
    weights = 1.0

  loss = tf.losses.softmax_cross_entropy(
      onehot_labels=onehot_labels,
      logits=logits,
      weights=weights,
      label_smoothing=smoothing,
      reduction=reduction)
  return loss
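# Usage sketch for cross_entropy (illustrative shapes and values only, not
# from the source): sparse int labels are one-hot encoded against the last
# logits dimension inside the helper.
#
#   logits = tf.random.uniform([4, 10])   # [batch, num_classes]
#   labels = tf.constant([1, 3, 5, 7])    # [batch], sparse class ids
#   loss = cross_entropy(logits, labels, smoothing=0.1)
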
def mask_sequence_loss(logits,
                       labels,
                       input_length,
                       label_length,
                       smoothing=0.0):
  ''' Softmax cross-entropy loss for sequence-to-sequence models.

  :param logits: [batch_size, max_seq_len, vocab_size]
  :param labels: [batch_size, max_seq_len]
  :param input_length: [batch_size]
  :param label_length: [batch_size]
  :return: loss, scalar
  '''
  del smoothing
  del input_length

  if label_length is not None:
    # Zero out the loss on padded time steps.
    weights = tf.cast(utils.len_to_mask(label_length), tf.float32)
  else:
    # sequence_loss expects float weights; labels are int, so set the dtype.
    weights = tf.ones_like(labels, dtype=tf.float32)

  loss = tf.contrib.seq2seq.sequence_loss(logits, labels, weights)
  return loss
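# Usage sketch for mask_sequence_loss (illustrative shapes only, not from
# the source): per-example label lengths mask the padded time steps before
# the per-step losses are averaged into a scalar.
#
#   logits = tf.random.uniform([2, 5, 100])      # [batch, max_seq_len, vocab]
#   labels = tf.zeros([2, 5], dtype=tf.int32)    # [batch, max_seq_len]
#   lengths = tf.constant([5, 3])                # [batch], true label lengths
#   loss = mask_sequence_loss(logits, labels, None, lengths)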