Example #1
import itertools

from trax import layers as tl
from trax.optimizers import adam
from trax.supervised import training


def _mnist_tasks(head=None):
    """Creates MNIST training and evaluation tasks.

    Args:
      head: Adaptor layer to put before loss and accuracy layers in the tasks.

    Returns:
      A pair (train_task, eval_task) consisting of the MNIST training task and
      the MNIST evaluation task, using cross-entropy as loss and accuracy as
      metric.
    """
    loss = tl.WeightedCategoryCrossEntropy()
    accuracy = tl.WeightedCategoryAccuracy()
    if head is not None:
        loss = tl.Serial(head, loss)
        accuracy = tl.Serial(head, accuracy)
    task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        loss,
        adam.Adam(0.001),
    )
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        [loss, accuracy],
        n_eval_batches=10,
        metric_names=['CrossEntropy', 'WeightedCategoryAccuracy'],
    )
    return (task, eval_task)
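The returned pair plugs straight into a Trax training Loop. A minimal sketch, where `model` and the output path are hypothetical placeholders for any Trax classifier compatible with these tasks:

train_task, eval_task = _mnist_tasks()
loop = training.Loop(
    model,  # hypothetical: any Trax classifier producing class activations
    train_task,
    eval_tasks=[eval_task],
    output_dir='/tmp/mnist_model',  # hypothetical path
)
loop.run(n_steps=2000)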
Example #2
    def test_accelerated_weighted_category_accuracy(self):
        """Test multi-device aggregation of weights."""
        layer = tl.Accelerate(tl.WeightedCategoryAccuracy())
        weights = np.array([1., 1., 1., 0.])
        targets = np.array([0, 1, 2, 3])

        model_outputs = np.array([[.2, .1, .7, 0.], [.2, .1, .7, 0.],
                                  [.2, .1, .7, 0.], [.2, .1, .7, 0.]])
        accuracy = layer([model_outputs, targets, weights])
        # Every row's argmax is class 2, so only the third example (weight 1)
        # is correct out of total weight 3: weighted accuracy == 1/3.
        self.assertEqual(np.mean(accuracy), 1 / 3)
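For comparison, the un-accelerated layer returns a single scalar, so no mean over per-device results is needed. A minimal sketch reusing the arrays from the test above:

plain_layer = tl.WeightedCategoryAccuracy()  # single-device metric layer
accuracy = plain_layer([model_outputs, targets, weights])
# Returns a single scalar, 1/3, matching the mean of the accelerated output.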
Example #3
    def test_weighted_category_accuracy_uneven_weights(self):
        layer = tl.WeightedCategoryAccuracy()
        weights = np.array([1., 5., 2.])
        targets = np.array([0, 1, 2])

        model_outputs = np.array([[.7, .2, .1, 0.], [.2, .7, .1, 0.],
                                  [.2, .1, .7, 0.]])
        accuracy = layer([model_outputs, targets, weights])
        # All three argmax predictions match their targets: (1 + 5 + 2) / 8 == 1.
        self.assertEqual(accuracy, 1.0)

        model_outputs = np.array([[.2, .7, .1, 0.], [.2, .7, .1, 0.],
                                  [.2, .7, .1, 0.]])
        accuracy = layer([model_outputs, targets, weights])
        # Only the class-1 prediction (weight 5) is correct: 5 / 8 == .625.
        self.assertEqual(accuracy, .625)
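The expected values follow from the metric's definition: the weight-sum over correct argmax predictions divided by the total weight. A standalone NumPy sketch of that formula (a reference check, not Trax's actual implementation):

import numpy as np

def weighted_category_accuracy(model_outputs, targets, weights):
    # Weight-sum over correct argmax predictions, normalized by total weight.
    correct = np.argmax(model_outputs, axis=-1) == targets
    return np.sum(correct * weights) / np.sum(weights)

# First case above: all rows correct -> (1 + 5 + 2) / 8 == 1.0.
# Second case: only the weight-5 row is correct -> 5 / 8 == 0.625.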
Example #4
TrainerState = collections.namedtuple(
    '_TrainerState',
    [
        'step',  # Current training step.
        'opt_state',  # OptState.
        'history',  # trax.history.History.
        'model_state',  # Auxiliary state of the model.
    ])

OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss': tl.WeightedCategoryCrossEntropy(),
    'accuracy': tl.WeightedCategoryAccuracy(),
    'sequence_accuracy': tl.MaskedSequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(),
                                    tl.Negate()),
    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}
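# Each metric above consumes the (model_outputs, targets, weights) triple from
# the layer stack. 'neg_log_perplexity' post-processes the cross-entropy with
# tl.Negate(); 'weights_per_batch_per_core' pops model_outputs and targets via
# two tl.Drop() layers, then reduces the remaining weights with tl.Sum().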


class Trainer:
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
Example #5
File: train.py  Project: JEF1056/R5
    print(f"(device count, tokens per device) = {test.shape}\n")
del teststream, test

# Training task.
train_task = training.TrainTask(
    labeled_data=stream(trax.fastmath.device_count(), "train"),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    lr_schedule=trax.lr.multifactor(),
    optimizer=trax.optimizers.Adam(),
    n_steps_per_checkpoint=1000,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=stream(trax.fastmath.device_count(), "validation"),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=10  # For less variance in eval numbers.
)

output_dir = os.path.expanduser(args.dir)

print("~~Begin Training~~")
# Train the model with a training Loop.
training_loop = training.Loop(
    trax.models.ReformerLM(mode="train"),
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir)

# Run 1,000,000 steps (batches).
training_loop.run(1000000)
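With n_steps_per_checkpoint=1000 above, state is saved every 1,000 steps, and a Loop created over an output_dir that already holds a checkpoint resumes from it. A minimal sketch of such a resume, assuming the same setup as above:

# Re-building the Loop restores weights, optimizer slots, and the step
# counter from the most recent checkpoint in output_dir before training.
resumed_loop = training.Loop(
    trax.models.ReformerLM(mode="train"),
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir)
resumed_loop.run(1000)  # continues from the restored step count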