def neg_log_perplexity(batch, model_predictions):
  """Calculate negative log perplexity.

  Args:
    batch: tuple (inputs, targets); only targets are used here.
    model_predictions: log-probability tensor(s) of shape [..., vocab_size]
      (a single tensor or a list; _make_list normalizes both to lists).

  Returns:
    Masked mean of per-token target log-probabilities (mask semantics are
    defined by masked_mean — presumably padding targets are excluded;
    TODO confirm).
  """
  _, targets = batch
  model_predictions, targets = _make_list(model_predictions, targets)
  # Multiplying by the one-hot target and summing over the vocab axis
  # selects the log-probability assigned to the target class.
  xent = [np.sum(prediction * layers.one_hot(target, prediction.shape[-1]),
                 axis=-1)
          for (prediction, target) in zip(model_predictions, targets)]
  return masked_mean(xent, targets)
def loss(params, batch, model_predict, rng):
  """Calculate loss.

  Args:
    params: model parameters passed through to model_predict.
    batch: tuple (inputs, targets).
    model_predict: callable (inputs, params, rng=...) -> predictions, where
      predictions are log-probabilities of shape [..., vocab_size]
      (a single tensor or a list; _make_list normalizes both to lists).
    rng: PRNG key forwarded to the model call.

  Returns:
    Negative masked mean of per-token target log-probabilities.
  """
  inputs, targets = batch
  predictions = model_predict(inputs, params, rng=rng)
  predictions, targets = _make_list(predictions, targets)
  # One-hot select the log-probability of the target class per token.
  xent = [np.sum(pred * layers.one_hot(target, pred.shape[-1]), axis=-1)
          for (pred, target) in zip(predictions, targets)]
  return -masked_mean(xent, targets)
def loss(params, batch, model_predict, state, rng, has_weights):
  """Calculate loss.

  Args:
    params: model parameters passed through to model_predict.
    batch: tuple (inputs, targets) or (inputs, targets, weights), depending
      on has_weights.
    model_predict: callable (stacked_input, params, state, rng=...) ->
      (output_stack, new_state).
    state: model state threaded through the prediction call.
    rng: PRNG key forwarded to the model call.
    has_weights: bool; whether batch carries explicit per-token weights.

  Returns:
    Pair (negative weighted masked mean log-likelihood, new model state).
  """
  inputs, targets, weights = unpack_batch(batch, has_weights)
  model_input, get_preds = _stack_inputs_targets_and_get_predictions(
      [inputs, targets])
  # Call model; predictions will be the returned stack, usually consisting
  # of the prediction tensor and the targets, so extract just the
  # predictions before computing the loss.
  predictions, state = model_predict(model_input, params, state, rng=rng)
  predictions = get_preds(predictions)
  predictions, targets, weights = _make_list(predictions, targets, weights)
  # One-hot select the log-probability of the target class per token.
  xent = [np.sum(pred * layers.one_hot(target, pred.shape[-1]), axis=-1)
          for (pred, target) in zip(predictions, targets)]
  return -masked_mean(xent, targets, weights), state
def loss(params, batch, model_predict, rng):
  """Compute the masked negative log-likelihood loss for one batch.

  Runs the model on the batch inputs and scores its log-probability
  predictions against the targets.
  """
  inputs, targets = batch
  predictions = model_predict(inputs, params, rng=rng)
  vocab_size = predictions.shape[-1]
  hot_targets = layers.one_hot(targets, vocab_size)
  # Per-token log-likelihood of the target class.
  log_likelihoods = np.sum(predictions * hot_targets, axis=-1)
  return -masked_mean(log_likelihoods, targets)
def neg_log_perplexity(batch, model_predictions):
  """Compute negative log perplexity of predictions against batch targets."""
  _, targets = batch
  vocab_size = model_predictions.shape[-1]
  one_hot_targets = layers.one_hot(targets, vocab_size)
  # Per-token log-probability assigned to the target class.
  log_likelihoods = np.sum(model_predictions * one_hot_targets, axis=-1)
  return masked_mean(log_likelihoods, targets)