Example #1
    def evaluate(self, n_eval_steps):
        """Evaluate the model and log metrics."""
        _, rng = jax_random.split(self._rngs[0])
        # TODO(lukaszkaiser): both model state and parameters by default include
        # the loss layer. Currently, we access the pure-model parameters by just
        # indexing, [0] here. But we should make it more explicit in a better API.
        weights = (self._opt_state.weights[0], self._metrics_weights)
        state = (self._model_state[0], self._metrics_state)
        self.log_step('Evaluation')
        train_eval_slice = itertools.islice(self._train_eval_stream,
                                            n_eval_steps)
        train_metrics, _ = self.evaluation_round(train_eval_slice, weights,
                                                 state, rng)
        self.log_metrics(train_metrics, self._train_sw, 'train')
        eval_slice = itertools.islice(self._eval_stream, n_eval_steps)
        eval_metrics, _ = self.evaluation_round(eval_slice, weights, state,
                                                rng)
        self.log_metrics(eval_metrics, self._eval_sw, 'eval')
        self.log_step('Finished evaluation')

        # Save the learning rate in history.
        self._history.append('train', 'training/learning_rate', self._step,
                             self.learning_rate)
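
The loop above pulls a bounded number of batches from an endless input stream via itertools.islice. A self-contained sketch of that pattern (the generator below is a stand-in for self._eval_stream, not part of the original):

    import itertools

    def batch_stream():
        """Stands in for self._eval_stream: an endless generator of batches."""
        i = 0
        while True:
            yield 'batch_%d' % i
            i += 1

    eval_slice = itertools.islice(batch_stream(), 3)  # at most n_eval_steps items
    print(list(eval_slice))  # ['batch_0', 'batch_1', 'batch_2']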
Example #2
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Set up the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Set up state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = sorted(self._metrics_dict.keys())
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # JIT-compile model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # These fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
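
A hedged usage sketch for this constructor, assuming the evaluate method from Example #1 lives on the same class; only the keyword names come from the signature above, and every argument value is a placeholder, not part of the excerpt:

    # Hypothetical setup; the my_* objects are not defined in this excerpt.
    trainer = Trainer(
        model=my_model,              # callable accepting mode='train'/'eval'
        loss_fn=my_loss_layer,
        optimizer=my_optimizer,      # callable accepting learning_rate=...
        lr_schedule=my_lr_schedule,
        inputs=my_inputs,            # an Inputs instance, or a function returning one
        output_dir='/tmp/train_dir')
    trainer.evaluate(n_eval_steps=10)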
Example #3
 def mapped_compute_loss(opt_state, batch, state, rng):
     """This is a multi-device version of the update function above."""
     # We assume all tensors have the first dimension = n_devices.
     rng, subrng = jax_random.split(rng)
     loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
     return loss_val, state, subrng
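
A minimal sketch in plain JAX of how a per-device function with this shape is typically driven through jax.pmap; the quadratic loss below is a toy stand-in for the loss_fn/predict_fn closure, not the library's original:

    import jax
    import jax.numpy as jnp
    from jax import random as jax_random

    n_devices = jax.local_device_count()

    def per_device_loss(weights, batch, rng):
        """Toy loss; pmap strips the leading device axis from each argument."""
        rng, subrng = jax_random.split(rng)
        noise = jax_random.normal(rng, batch.shape) * 0.01  # rng consumed here
        loss_val = jnp.mean(((batch + noise) @ weights) ** 2)
        return loss_val, subrng

    mapped_loss = jax.pmap(per_device_loss)
    weights = jnp.stack([jnp.ones((4, 1))] * n_devices)   # leading dim = n_devices
    batch = jnp.stack([jnp.ones((8, 4))] * n_devices)
    rngs = jax_random.split(jax_random.PRNGKey(0), n_devices)
    losses, new_rngs = mapped_loss(weights, batch, rngs)  # one loss per device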
Example #4
 def single_compute_loss(opt_state, batch, state, rng):
     """Single-device version of the loss computation; rng arrives as a 1-element list."""
     rng, subrng = jax_random.split(rng[0])
     loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state,
                               rng)
     return loss_val, state, [subrng]
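
The rng[0] unwrap and [subrng] return give the single-device path the same list-of-keys convention as the pmapped one in Example #3. The underlying discipline, split a key before use and carry the fresh half forward, shown in isolation:

    from jax import random as jax_random

    rngs = [jax_random.PRNGKey(0)]           # single-device convention: 1-element list
    for step in range(3):
        key, subrng = jax_random.split(rngs[0])
        sample = jax_random.normal(key, ())  # fresh randomness each step
        rngs = [subrng]                      # never reuse a consumed key
        print(step, float(sample))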