예제 #1
0
파일: metrics_test.py 프로젝트: yliu45/trax
  def test_masked_sequence_accuracy(self):
    # Exercises MaskedSequenceAccuracy on three batches of two sequences.
    # The expected values below pin down that a sequence only counts as
    # correct when every position with nonzero weight is predicted correctly,
    # and that positions with weight 0 are ignored.
    metric = tl.MaskedSequenceAccuracy()
    # Two length-4 target sequences over a 2-way vocabulary.
    targets = np.array([[0, 1, 0, 0],
                        [1, 0, 1, 0]])
    # The last position of each sequence carries weight 0 (masked out).
    mask = np.array([[1., 1., 1., 0.],
                     [1., 1., 1., 0.]])

    # Each case: (model outputs per position, expected sequence accuracy).
    cases = [
        # Both sequences right; the masked final position would pick the
        # wrong category (argmax 1 vs. target 0) but must be ignored.
        (np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]],
                   [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]]), 1.),
        # The first element of the first sequence is barely wrong.
        (np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]],
                   [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]]), .5),
        # The second-to-last element of each sequence is barely wrong.
        (np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]],
                   [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]]), 0.),
    ]
    for outputs, expected in cases:
      self.assertEqual(metric([outputs, targets, mask]), expected)
예제 #2
0
        'history',  # trax.history.History.
        'model_state',  # Auxilliary state of the model.
    ])

# Optimizer-related state: model weights, per-parameter optimizer slots, and
# optimizer hyperparameters.
#
# The namedtuple typename must equal the module-level name it is bound to:
# pickle finds a class by looking up its __qualname__ in its module, so a
# mismatched typename (e.g. '_OptState') with no matching module attribute
# makes every instance unpicklable.
OptState = collections.namedtuple(
    'OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

# Default metric layers, keyed by the name each metric is reported under.
# Key order matches insertion order below. The metric layers appear to be
# called on [model_outputs, targets, weights] (as in the metrics test) —
# NOTE(review): confirm this input order for every entry.
_DEFAULT_METRICS = dict(
    loss=tl.WeightedCategoryCrossEntropy(),
    accuracy=tl.WeightedCategoryAccuracy(),
    sequence_accuracy=tl.MaskedSequenceAccuracy(),
    # Weighted cross-entropy with its sign flipped.
    neg_log_perplexity=tl.Serial(
        tl.WeightedCategoryCrossEntropy(),
        tl.Negate(),
    ),
    # Two Drop layers then Sum; presumably discards model outputs and
    # targets so only the weights are summed — confirm against tl.Drop.
    weights_per_batch_per_core=tl.Serial(
        tl.Drop(),
        tl.Drop(),
        tl.Sum(),
    ),
)


class Trainer:
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,