Code Example #1
File: metrics_test.py  Project: yliu45/trax
  def test_sequence_accuracy_last_position_zero_weight(self):
    layer = tl.SequenceAccuracy()
    targets = np.array([[0, 1, 0, 0],
                        [1, 0, 1, 0]])
    weights = np.array([[1., 1., 1., 0.],
                        [1., 1., 1., 0.]])

    # Model gets both sequences right; output in final position would give
    # wrong category but is ignored.
    model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.35, .65]],
                              [[.3, .7], [.8, .2], [.1, .9], [.35, .65]]])
    accuracy = layer([model_outputs, targets, weights])
    self.assertEqual(accuracy, 1.)

    # Model gets the first element of the first sequence barely wrong.
    model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.6, .4]],
                              [[.3, .7], [.8, .2], [.1, .9], [.6, .4]]])
    accuracy = layer([model_outputs, targets, weights])
    self.assertEqual(accuracy, .5)

    # Model gets second-to-last element of each sequence barely wrong.
    model_outputs = np.array([[[.9, .1], [.2, .8], [.48, .52], [.6, .4]],
                              [[.3, .7], [.8, .2], [.51, .49], [.6, .4]]])
    accuracy = layer([model_outputs, targets, weights])
    self.assertEqual(accuracy, 0.)
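
To spell out the behavior these assertions rely on, here is a minimal NumPy sketch (an illustration, not the trax implementation) of weighted sequence accuracy: take the argmax category at each position, treat zero-weight positions as automatically correct, require every remaining position in a sequence to match its target, and average over sequences.

import numpy as np

def sequence_accuracy_sketch(model_outputs, targets, weights):
  """Illustrative only: fraction of sequences whose every weighted position
  is predicted correctly; positions with weight 0. are ignored."""
  predictions = np.argmax(model_outputs, axis=-1)          # (batch, length)
  position_ok = (predictions == targets) | (weights == 0.)
  sequence_ok = np.all(position_ok, axis=-1)               # (batch,)
  return np.mean(sequence_ok)

Applied to the three cases above, this sketch reproduces the asserted values 1., .5, and 0.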
Code Example #2
File: metrics_test.py  Project: yangliuy/trax
  def test_names(self):
    layer = tl.L2Loss()
    self.assertEqual('L2Loss_in3', str(layer))
    layer = tl.Accuracy()
    self.assertEqual('Accuracy_in3', str(layer))
    layer = tl.SequenceAccuracy()
    self.assertEqual('SequenceAccuracy_in3', str(layer))
    layer = tl.CrossEntropyLoss()
    self.assertEqual('CrossEntropyLoss_in3', str(layer))
    layer = tl.CrossEntropySum()
    self.assertEqual('CrossEntropySum_in3', str(layer))
Code Example #3
  def test_names(self):
    layer = tl.L2Loss()
    self.assertEqual('L2Loss_in3', str(layer))
    layer = tl.BinaryClassifier()
    self.assertEqual('BinaryClassifier', str(layer))
    layer = tl.MulticlassClassifier()
    self.assertEqual('MulticlassClassifier', str(layer))
    layer = tl.Accuracy()
    self.assertEqual('Accuracy_in3', str(layer))
    layer = tl.SequenceAccuracy()
    self.assertEqual('SequenceAccuracy_in3', str(layer))
    layer = tl.BinaryCrossEntropyLoss()
    self.assertEqual('BinaryCrossEntropyLoss_in3', str(layer))
    layer = tl.CrossEntropyLoss()
    self.assertEqual('CrossEntropyLoss_in3', str(layer))
    layer = tl.BinaryCrossEntropySum()
    self.assertEqual('BinaryCrossEntropySum_in3', str(layer))
    layer = tl.CrossEntropySum()
    self.assertEqual('CrossEntropySum_in3', str(layer))
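
The '_in3' suffix these tests expect comes from the layer's input arity: each metric layer consumes three tensors (model outputs, targets, weights), while single-input layers such as tl.BinaryClassifier carry no suffix. A minimal check of that reading, assuming the n_in property that trax layers expose:

from trax import layers as tl

layer = tl.SequenceAccuracy()
assert layer.n_in == 3              # model outputs, targets, weights
assert str(layer).endswith('_in3')  # the arity is encoded in the layer name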
Code Example #4
    def test_sequence_accuracy_weights_all_ones(self):
        layer = tl.SequenceAccuracy()
        targets = np.array([[0, 1, 0, 1], [1, 0, 1, 1]])
        weights = np.ones_like(targets)

        # Model gets both sequences right; for each position in each sequence, the
        # category (integer ID) selected by argmax matches the target category.
        model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.4, .6]],
                                  [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]])
        accuracy = layer([model_outputs, targets, weights])
        self.assertEqual(accuracy, 1.)

        # Model gets the first element of the first sequence barely wrong.
        model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.4, .6]],
                                  [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]])
        accuracy = layer([model_outputs, targets, weights])
        self.assertEqual(accuracy, .5)

        # Model gets the last element of each sequence barely wrong.
        model_outputs = np.array([[[.9, .1], [.2, .8], [.7, .3], [.55, .45]],
                                  [[.3, .7], [.8, .2], [.1, .9], [.52, .48]]])
        accuracy = layer([model_outputs, targets, weights])
        self.assertEqual(accuracy, 0.)
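
For contrast, a sketch (not part of the test file) of how the second case above scores under per-token accuracy with tl.Accuracy: only one of the eight weighted positions is wrong, so the token-level value should come out to 7/8 = 0.875, while the sequence-level value drops to 0.5 because the whole first sequence counts as wrong.

import numpy as np
from trax import layers as tl

targets = np.array([[0, 1, 0, 1], [1, 0, 1, 1]])
weights = np.ones_like(targets)
# Second case from the test above: only the first position of sequence 0 is wrong.
model_outputs = np.array([[[.45, .55], [.2, .8], [.7, .3], [.4, .6]],
                          [[.3, .7], [.8, .2], [.1, .9], [.4, .6]]])
token_accuracy = tl.Accuracy()([model_outputs, targets, weights])
# Expected: 7/8 = 0.875 per token, versus 0.5 per sequence for the same inputs.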
Code Example #5
OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss':
        tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()),
    'accuracy':
        tl.Accuracy(),
    'sequence_accuracy':
        tl.SequenceAccuracy(),
    'neg_log_perplexity':
        tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss(), tl.Negate()),
    'weights_per_batch_per_core':
        tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
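
A hedged usage sketch, not taken from the source: each value in the _DEFAULT_METRICS dict above is itself a layer that consumes the (model output, targets, weights) triple, so a single entry can be evaluated on its own. The shapes below are invented for illustration; note that whether CrossEntropyLoss expects logits or log probabilities has varied across trax versions, so the LogSoftmax prefix simply follows the composition used in this snippet.

import numpy as np
from trax import layers as tl

loss_metric = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss())
model_outputs = np.random.uniform(size=(2, 4, 8))    # (batch, length, vocab) scores
targets = np.random.randint(0, 8, size=(2, 4))       # integer category IDs
weights = np.ones(targets.shape, dtype=np.float32)   # weight every position equally
loss_value = loss_metric([model_outputs, targets, weights])  # scalar weighted loss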
Code Example #6
        'history',  # trax.history.History.
        'model_state',  # Auxiliary state of the model.
    ])

OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss': tl.CrossEntropyLoss(),
    'accuracy': tl.Accuracy(),
    'sequence_accuracy': tl.SequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.CrossEntropyLoss(), tl.Negate()),
    'weights_per_batch_per_core': tl.SumOfWeights(),
}


class Trainer(object):
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
Code Example #7
File: latent.py  Project: tonywu95/trax
Latent_METRICS = {
    'next_state_loss':
        tl.Serial(tl.Select([0, 1, 9]),
                  tl.WeightedCategoryCrossEntropy()),  # DropLast()),
    'recon_state_loss':
        tl.Serial(tl.Select([2, 3, 10]), tl.WeightedCategoryCrossEntropy()),
    'recon_action_loss':
        tl.Serial(tl.Select([4, 5, 11]), tl.WeightedCategoryCrossEntropy()),
    'next_state_accuracy':
        tl.Serial(tl.Select([0, 1, 9]), tl.Accuracy()),  # DropLast()),
    'recon_state_accuracy':
        tl.Serial(tl.Select([2, 3, 10]), tl.Accuracy()),
    'recon_action_accuracy':
        tl.Serial(tl.Select([4, 5, 11]), tl.Accuracy()),
    'next_state_sequence_accuracy':
        tl.Serial(tl.Select([0, 1, 9]), tl.SequenceAccuracy()),  # DropLast()),
    'recon_state_sequence_accuracy':
        tl.Serial(tl.Select([2, 3, 10]), tl.SequenceAccuracy()),
    'recon_action_sequence_accuracy':
        tl.Serial(tl.Select([4, 5, 11]), tl.SequenceAccuracy()),
    # 'neg_log_perplexity': Serial(WeightedCategoryCrossEntropy(),
    #                              Negate()),
    # 'weights_per_batch_per_core': Serial(tl.Drop(), Drop(), Sum()),
}


@gin.configurable
def latent_fn():
    return Latent_METRICS
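
The tl.Select prefixes above pull specific positions out of a larger input stack (for example positions 0, 1, and 9: a prediction, its target, and a shared weights tensor) so that exactly three tensors reach each metric layer. A toy sketch of that stack rewiring, with made-up inputs unrelated to the latent model:

import numpy as np
from trax import layers as tl

# Select([0, 2]) reads the top three stack items and keeps only items 0 and 2.
select = tl.Select([0, 2])
a, b, c = np.array(1.), np.array(2.), np.array(3.)
picked = select([a, b, c])   # -> (array(1.), array(3.)); b is dropped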