Example no. 1
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        if self._should_write_summaries:
            self._train_sw = jaxboard.SummaryWriter(
                os.path.join(output_dir, 'train'), enable=self.is_chief)
            self._eval_sw = jaxboard.SummaryWriter(
                os.path.join(output_dir, 'eval'), enable=self.is_chief)

        # Reset the train and eval streams.
        self._train_stream = self._inputs.train_stream()
        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
        #   set by adding padding and stopping the stream when it gets too large.
        self._eval_stream = _repeat_stream(self._inputs.eval_stream)
        self._train_eval_stream = _repeat_stream(
            self._inputs.train_eval_stream)

        # Restore the training state.
        state = load_trainer_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._new_opt_state_and_model_state()
            model_state = layers.nested_map(self._maybe_replicate, model_state)
        self._opt_state = OptState(
            *layers.nested_map(self._maybe_replicate, opt_state))
        self._model_state = model_state
        if not state.opt_state and self.is_chief:
            self._maybe_save_state(keep=False)

        self.update_nontrainable_params()
  def predict(x, weights, state, rng):
    """Predict function jitted and parallelized as requested."""
    res, state = backend.combine_devices(model_predict(
        backend.reshape_by_device(x, n_devices),
        weights,
        state,
        np.stack(jax_random.split(rng, n_devices))))
    return layers.nested_map(lambda y: np.mean(y, axis=0), res), state
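A minimal numpy-only sketch of the sharding scheme used by `predict` above; the helper names below are illustrative stand-ins, not the actual `backend` API.

import numpy as onp

def _reshape_by_device_sketch(x, n_devices):
  # Illustrative stand-in: (batch, ...) -> (n_devices, batch // n_devices, ...).
  return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])

def _combine_devices_sketch(x):
  # Illustrative stand-in: drop the device axis again.
  return x.reshape((-1,) + x.shape[2:])

batch = onp.arange(12.0).reshape(6, 2)          # 6 examples, 2 features each
sharded = _reshape_by_device_sketch(batch, 2)   # shape (2, 3, 2)
assert (_combine_devices_sketch(sharded) == batch).all()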
def _sizes(x):
  """Get a structure of sizes for a structure of nested arrays."""
  def size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  return layers.nested_map(size, x)
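A self-contained sketch of what `_sizes` computes, using a plain dict of numpy arrays in place of the layer library's nested structures (the dict layout is purely illustrative).

import numpy as onp

def _sizes_sketch(tree):
  # Mirror of `_sizes` for dict trees only, for illustration.
  if isinstance(tree, dict):
    return {k: _sizes_sketch(v) for k, v in tree.items()}
  return getattr(tree, 'size', 0)

params = {'embed': onp.zeros((3, 4)),
          'dense': {'w': onp.zeros((4, 2)), 'b': onp.zeros(2)}}
print(_sizes_sketch(params))  # {'embed': 12, 'dense': {'w': 8, 'b': 2}}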
def _print_n_weights(opt_state, n_devices, step):
  """Print out the number of parameters."""
  sizes = _sizes(opt_state.weights)
  if n_devices > 1:
    unreplicate = lambda x: x[0]
    single_weights = layers.nested_map(unreplicate, opt_state.weights)
    sizes = _sizes(single_weights)
  total_size = _nested_reduce(sum, sizes)
  step_log(step, 'Total number of trainable weights: %d' % total_size)
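`_nested_reduce` is not shown in this listing; a minimal sketch of such a helper, assuming the size structure consists of dicts, lists, tuples and integer leaves, could look like this.

def _nested_reduce_sketch(f, tree):
  # Hypothetical helper: collapse a nested structure of ints with `f`.
  if isinstance(tree, dict):
    return f(_nested_reduce_sketch(f, v) for v in tree.values())
  if isinstance(tree, (list, tuple)):
    return f(_nested_reduce_sketch(f, v) for v in tree)
  return tree

assert _nested_reduce_sketch(sum, {'w': 8, 'b': 2, 'head': [4, 4]}) == 18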
Example no. 5
def _print_n_params(opt_state, n_devices, step):
    """Print out the number of parameters."""
    sizes = layers.sizes(opt_state.params)
    if n_devices > 1:
        unreplicate = lambda x: x[0]
        single_params = layers.nested_map(unreplicate, opt_state.params)
        sizes = layers.sizes(single_params)
    total_size = layers.nested_reduce(sizes, sum)
    step_log(step, 'Total trainable parameters size: %d' % total_size)
Example no. 6
  def print_n_weights(self):
    """Prints the total count of trainable weights."""
    opt_state = self._opt_state
    sizes = _sizes(opt_state.weights)
    if self._n_devices > 1:
      unreplicate = lambda x: x[0]
      single_weights = layers.nested_map(unreplicate, opt_state.weights)
      sizes = _sizes(single_weights)
    total_size = _nested_reduce(sum, sizes)
    self.log_step('Total number of trainable weights: %d' % total_size)
def _save_replicated(opt_state, step, history, model_state, n_devices,
                     output_dir, keep):
  """Saves trainer state but given a possibly replicated opt_state."""
  if n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = OptState(*layers.nested_map(first_replica, opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device to
  # the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)
  save_trainer_state(
      TrainerState(opt_state=opt_state, step=step, history=history,
                   model_state=model_state), output_dir, keep=keep)
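A small numpy sketch of the "first replica" step above: replicated values carry a leading device axis and stay identical across replicas during data-parallel training, so indexing device 0 recovers the single copy that is written to disk (shapes below are illustrative).

import numpy as onp

n_devices = 8
replicated = onp.stack([onp.arange(4.0)] * n_devices)  # shape (8, 4)
single_copy = replicated[0]                            # shape (4,), what gets saved
assert single_copy.shape == (4,)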
  def _train_step(self, next_train_batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = layers.nested_map(
        lambda x: self._maybe_replicate(np.array(x)),
        self.nontrainable_params)
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, next_train_batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    self._step += 1
Example no. 9
    def save_state(self, keep):
        """Save trainer state given a possibly replicated opt_state."""
        opt_state = self._opt_state
        if self._n_devices > 1:
            first_replica = lambda x: x[0]
            opt_state = OptState(*layers.nested_map(first_replica, opt_state))
        # This line, while optional, allows JAX to transfer arrays from the device
        # to the host in parallel, which is particularly important for cloud TPU.
        if backend.get_name() == 'jax':
            opt_state = jax.device_get(opt_state)
        step, history, model_state = self._step, self._history, self._model_state
        output_dir = self._output_dir

        pkl_module = utils.get_pickle_module()
        weights_file = os.path.join(output_dir, 'model.pkl')
        with gfile.GFile(weights_file, 'wb') as f:
            pkl_module.dump((tuple(opt_state), step, history, model_state), f)
        if keep:
            weights_file = os.path.join(output_dir,
                                        'model_{}.pkl'.format(step))
            with gfile.GFile(weights_file, 'wb') as f:
                pkl_module.dump((tuple(opt_state), step, history, model_state),
                                f)
        log('Model saved to %s' % weights_file, stdout=False)
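A hedged counterpart sketch for reading such a checkpoint back; it assumes the same file name and tuple layout as `save_state` above and uses plain pickle rather than the library's actual loading code.

import os
import pickle

def load_state_sketch(output_dir):
  # Hypothetical loader mirroring the (opt_state, step, history, model_state)
  # tuple pickled by save_state above.
  weights_file = os.path.join(output_dir, 'model.pkl')
  if not os.path.exists(weights_file):
    return None  # no checkpoint yet; callers fall back to fresh initialization
  with open(weights_file, 'rb') as f:
    opt_state, step, history, model_state = pickle.load(f)
  return opt_state, step, history, model_state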
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               save_steps=None, should_save_checkpoints=True,
               should_write_summaries=True, has_weights=False,
               nontrainable_param_map=None, mask_id=None, metrics=None):
    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    self._should_save_checkpoints = should_save_checkpoints
    self._should_write_summaries = should_write_summaries
    self._has_weights = has_weights
    self._mask_id = mask_id
    self._metrics_dict = _METRICS if metrics is None else metrics
    loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
    device_count = backend.device_count()
    n_devices = n_devices or device_count
    if backend.get_name() == 'jax':
      self._host_id = jax.host_id()
      self._host_count = jax.host_count()
    else:
      self._host_id = 0
      self._host_count = 1
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and backend.get_name() == 'jax':
      raise ValueError('Jax cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))
    self._n_devices = n_devices

    # Simple differential seeding of RNG across hosts by host_id and time.
    if random_seed is None and self._host_count > 1:
      _, random_seed = divmod(int(time.time() * 1e6) +
                              int(self._host_id * 1e6), 2**32)
    rng = get_random_number_generator_and_set_seed(random_seed)
    inputs = inputs(n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, n_devices))
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
      model_target_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.target_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
      model_target_shape = tuple([None] + list(inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = layers.nested_map(lambda x: x if x else 1,
                                          model_input_shape)
    model_target_shape = layers.nested_map(lambda x: x if x else 1,
                                           model_target_shape)
    def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                      target_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
        shapes += list(target_shape)
        dtypes += [np.float32 for _ in target_dtype]
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize_once` caches parameter values in `m`, so a second
      # initialization of the same instance would give wrong results.
      m = layers.Serial(model(mode='train'), loss_fn)
      m._set_rng(rng)  # pylint: disable=protected-access
      weights, state = m.initialize_once(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)
    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state,
                                                  static_argnums=(0, 1, 2, 3))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype,
            model_target_shape, self._inputs.target_dtype, init_rng))

    # jit model_predict and update so they're fast
    # TODO(lukaszkaiser): the code below creates a layer computing
    # multiple metrics from a single model output; re-factor for clarity.
    dup_layer = layers.Dup3() if self._has_weights else layers.Dup2()
    def lower(layer):
      """Apply layer below the current inputs, targets, and possibly weights."""
      if self._has_weights:
        # Apply layer below inputs, targets, and loss weights.
        return layers.Parallel([], [], [], layer)
      else:
        # Apply layer below inputs and targets.
        return layers.Parallel([], [], layer)
    metrics_layer = []
    self._metrics = list(sorted(self._metrics_dict.keys()))
    for i, m in enumerate(reversed(self._metrics)):
      metric = self._metrics_dict[m](has_weights=self._has_weights,
                                     mask_id=self._mask_id)
      if i != len(self._metrics) - 1:
        metrics_layer.append(dup_layer)
        metrics_layer.append(lower(metric))
      else:
        metrics_layer.append(metric)
    # TODO(lukaszkaiser): clean this up once layer API stabilizes.
    # For now, we need to initialize metric layers somehow, so here we go.
    # We assume that they do not have any parameters, so this is a dummy.
    dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
    dummy_dtypes = [np.float32] * (3 if self._has_weights else 2)
    dummy_signature = tuple(ShapeDtype(s, d)
                            for s, d in zip(dummy_shapes, dummy_dtypes))
    metrics_layer = layers.Serial(metrics_layer)
    metrics_layer._set_rng(init_rng)  # pylint: disable=protected-access
    metrics_weights, metrics_state = (
        metrics_layer.initialize_once(dummy_signature))
    self._metrics_weights = layers.nested_map(self._maybe_replicate,
                                              metrics_weights)
    self._metrics_state = layers.nested_map(self._maybe_replicate,
                                            metrics_state)
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_layer, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to control
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    if output_dir is not None:
      self.reset(output_dir)
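A self-contained sketch of the per-host seeding done in the constructor above: when no seed is given and more than one host is present, each host derives its own 32-bit seed from the current time offset by its host id (the function name is illustrative).

import time

def _derive_host_seed_sketch(host_id):
  # Same arithmetic as in __init__: keep the low 32 bits of time + host offset.
  _, seed = divmod(int(time.time() * 1e6) + int(host_id * 1e6), 2**32)
  return seed

print(_derive_host_seed_sketch(0), _derive_host_seed_sketch(1))  # differ per host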