Beispiel #1
0
  def pseudo_forward(self, pseudo_inputs, params, state):
    """Infers the shapes and types this layer would output for given inputs.

    No real forward computation is performed; `self.forward` is evaluated on
    shape/dtype placeholders only.

    Args:
      pseudo_inputs: A ShapeDtype instance (the input data stripped of its
          actual values) or a tuple of ShapeDtype instances, laid out the same
          way as the input arg of Layer.forward.
      params: Parameters for this layer.
      state: Start state.

    Returns:
      A tuple of (output, state).

      The output element is a ShapeDtype instance describing the shape and
      type of the output (for a layer with one output) or a tuple of
      ShapeDtype instances (for a layer with more than one output).

    Raises:
      LayerError: If any exception is raised while evaluating on shapes.
    """
    try:
      def call_on_input(inputs, layer_params, layer_state, layer_rng):
        return self.forward(inputs, params=layer_params, state=layer_state,
                            rng=layer_rng)
      # A ShapeDtype stub deliberately stands in for the RNG: with an actual
      # RNG, a large number of dropout masks would be computed and permanently
      # stored in global memory.
      rng_stub = ShapeDtype((2,), onp.uint32)
      param_signatures = nested_map(shape_dtype_for, params)
      forward_on_shapes = backend.eval_on_shapes(call_on_input)
      return forward_on_shapes(pseudo_inputs, param_signatures, state, rng_stub)
    except Exception:
      layer_name = self.__class__.__name__
      short_trace = _short_traceback(skip=3)
      raise LayerError(layer_name, 'pseudo_forward', self._caller,
                       pseudo_inputs, None, short_trace)
Beispiel #2
0
    def _forward_abstract(self, input_signature):
        """Infers the shapes and dtypes a forward pass of this layer would yield.

        The layer's own `self.weights` and `self.state` are used; weights are
        first converted to signatures, so no real forward computation runs.

        Args:
          input_signature: A ShapeDtype instance (if this layer takes one
              input) or a list/tuple of ShapeDtype instances; signatures of
              inputs.

        Returns:
          A tuple of (output, state).

          The output element is a ShapeDtype instance describing the shape and
          type of the output (for a layer with one output) or a tuple of
          ShapeDtype instances (for a layer with more than one output).

        Raises:
          LayerError: If any exception is raised during abstract evaluation.
        """
        try:
            def call_on_input(inputs, layer_weights, layer_state, layer_rng):
                return self.forward_with_state(
                    inputs, weights=layer_weights, state=layer_state,
                    rng=layer_rng)

            # A ShapeDtype stub deliberately stands in for the RNG: with an
            # actual RNG, a large number of dropout masks would be computed
            # and permanently stored in global memory.
            rng_stub = ShapeDtype((2, ), onp.uint32)
            weight_signature = nested_map(signature, self.weights)
            abstract_forward = backend.abstract_eval(call_on_input)
            return abstract_forward(input_signature, weight_signature,
                                    self.state, rng_stub)
        except Exception:
            layer_name = self.__class__.__name__
            short_trace = _short_traceback(skip=3)
            raise LayerError(layer_name, '_forward_abstract', self._caller,
                             input_signature, short_trace)
Beispiel #3
0
def sizes(x):
  """Maps `nested_map` over `x`, replacing each leaf by its `size`.

  A leaf whose `size` attribute cannot be read contributes 0 instead of
  raising.
  """
  def _leaf_size(leaf):
    try:
      n_elements = leaf.size
    except Exception:  # pylint: disable=broad-except
      n_elements = 0
    return n_elements
  return nested_map(_leaf_size, x)
Beispiel #4
0
def shapes(x):
  """Maps `nested_map` over `x`, replacing each leaf by its shape.

  Each leaf becomes a tuple of Python ints built from its `shape`. A leaf whose
  shape cannot be read or converted becomes an empty list instead of raising.
  """
  def _leaf_shape(leaf):
    try:
      dims = tuple(int(d) for d in leaf.shape)
    except Exception:  # pylint: disable=broad-except
      dims = []
    return dims
  return nested_map(_leaf_shape, x)
Beispiel #5
0
def _combine_devices(x_tuple):
  """Merges the two leading axes of every tensor in a nested structure.

  Tensors with fewer than two axes are passed through untouched; all others
  are reshaped so that axis 0 and axis 1 become one axis of their product.
  """
  def merge_leading_axes(tensor):
    dims = list(tensor.shape)
    if len(dims) < 2:
      # No extra batch dimension: the device axis already is the batch.
      return tensor
    merged = dims[0] * dims[1]
    return backend.numpy.reshape(tensor, [merged] + dims[2:])
  return backend.nested_map(merge_leading_axes, x_tuple)
Beispiel #6
0
 def predict(x, weights, state, rng):
   """Runs the per-device model on `x` and averages the result over axis 0.

   `x` is reshaped by device, `rng` is split into `n_devices` keys, and the
   model output is passed through `_combine_devices` before averaging.
   Returns a pair of (averaged outputs, new state).
   """
   inputs_by_device = _reshape_by_device(x, n_devices)
   rngs_by_device = np.stack(jax_random.split(rng, n_devices))
   model_out = model_predict(inputs_by_device, weights, state, rngs_by_device)
   res, state = _combine_devices(model_out)
   averaged = backend.nested_map(lambda y: np.mean(y, axis=0), res)
   return averaged, state
Beispiel #7
0
 def print_n_weights(self):
     """Prints the total count of trainable weights.

     Reads the weights from `self._opt_state`. When `self.n_devices > 1`, only
     index 0 of every weight leaf is counted (the leading axis is presumably
     the per-device replica axis), so replicas are not counted repeatedly.
     The result is reported through `self.log_step`.
     """
     weights = self._opt_state.weights
     if self.n_devices > 1:
         # Select a single replica *before* measuring sizes; measuring the
         # replicated weights first would be wasted work since that result
         # would be discarded.
         weights = backend.nested_map(lambda x: x[0], weights)
     total_size = _nested_reduce(sum, _sizes(weights))
     self.log_step('Total number of trainable weights: %d' % total_size)
Beispiel #8
0
 def _for_n_devices(self, x):
   """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.

   With a single device every leaf is returned as is. Otherwise each leaf is
   passed to `_multi_device_put` on the 'jax' backend, or broadcast to shape
   `(n_devices,) + leaf.shape` on any other backend.
   """
   n_devices = self.n_devices
   def replicate(leaf):
     if n_devices <= 1:
       return leaf
     if backend.get_name() == 'jax':
       return _multi_device_put(leaf)
     return np.broadcast_to(leaf, (n_devices,) + leaf.shape)
   return backend.nested_map(replicate, x)
Beispiel #9
0
def _reshape_by_device(x, n_devices):
  """Splits the leading axis of possibly nested `x` into (n_devices, ...).

  Every leaf of shape (batch, ...) is reshaped to
  (n_devices, batch // n_devices, ...).

  Raises:
    ValueError: If `n_devices` does not divide a leaf's leading axis evenly.
  """
  def split_leading_axis(leaf):
    dims = list(leaf.shape)
    batch_size = dims[0]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    return backend.numpy.reshape(leaf, [n_devices, per_device] + dims[1:])
  return backend.nested_map(split_leading_axis, x)
Beispiel #10
0
    def train_step(self, batch):
        """Run one training step and update self._opt_state.

        Besides `self._opt_state`, this also rebinds `self._model_state` and
        `self._rngs` and increments `self._step` by one.

        Args:
          batch: The training batch; handed unchanged to the jitted update
              function.
        """
        # Calculate the current optimizer parameters.
        # TODO(pkozakowski): Optimizer parameters get polluted with model state,
        # which doesn't break anything but is weird. Filter it out.
        # Each nontrainable param is converted with `np.array` and then
        # replicated/broadcast for the configured number of devices.
        opt_param_updates = self._for_n_devices(
            backend.nested_map(np.array, self.nontrainable_params))
        opt_state = self._opt_state
        # In-place update: mutates the `opt_params` object held by the current
        # optimizer state.
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
            self._step, opt_state, batch, self._model_state, self._rngs)
        # NOTE(review): this rebinds the model state assigned on the previous
        # statement; presumably `_state_dicts_update` reads that fresh state,
        # so the order of these two statements matters - confirm.
        self._model_state = self._map_to_state_dicts(self._state_dicts_update)
        # Keep every other field of the optimizer state, swapping in the new
        # weights and slots.
        self._opt_state = opt_state._replace(weights=weights, slots=slots)
        self._step += 1
Beispiel #11
0
  def save_state(self, keep):
    """Pickles the trainer state, given a possibly replicated opt_state.

    The tuple (opt_state, step, history, model_state) is always written to
    'model.pkl' inside the output directory.

    Args:
      keep: If truthy, the same tuple is additionally written to
          'model_<step>.pkl' in the same directory.
    """
    opt_state = self._opt_state
    if self.n_devices > 1:
      # Keep only index 0 (the first replica) of every leaf.
      opt_state = OptState(*backend.nested_map(lambda x: x[0], opt_state))
    if backend.get_name() == 'jax':
      # Optional, but it lets JAX transfer arrays from the device to the host
      # in parallel, which is particularly important for cloud TPU.
      opt_state = jax.device_get(opt_state)
    step = self._step
    history = self._history
    model_state = self._model_state
    output_dir = self._output_dir

    file_names = ['model.pkl']
    if keep:
      file_names.append('model_{}.pkl'.format(step))

    pkl_module = utils.get_pickle_module()
    checkpoint = (tuple(opt_state), step, history, model_state)
    for file_name in file_names:
      weights_file = os.path.join(output_dir, file_name)
      with tf.io.gfile.GFile(weights_file, 'wb') as f:
        pkl_module.dump(checkpoint, f)
    # Only the most recently written path is reported.
    log('Model saved to %s' % weights_file, stdout=False)
Beispiel #12
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 mask_id=None,
                 metrics=None):
        """Sets up model, loss, optimizer, metrics and the jitted step functions.

        Args:
          model: Callable invoked as `model(mode='train')` and
              `model(mode='eval')` to build the train and eval models.
          loss_fn: Callable invoked as
              `loss_fn(has_weights=..., mask_id=...)` to build the loss layer.
          optimizer: Callable invoked as `optimizer(learning_rate=0.0)`; the
              learning rate passed here is only a placeholder.
          lr_schedule: Stored as-is in `self._lr_schedule`; not used in this
              method.
          inputs: Callable invoked as `inputs(n_devices)`. The result must
              expose `input_shape`, `target_shape`, `input_dtype` and
              `target_dtype`.
          output_dir: If not None, `self.reset(output_dir)` is called at the
              end of construction.
          random_seed: Seed handed to
              `get_random_number_generator_and_set_seed`. When None and there
              is more than one host, a seed is derived from the current time
              and the host id.
          n_devices: Number of devices to use; falls back to
              `backend.device_count()` when None (or 0).
          save_steps: Stored in `self._save_steps`; defaults to an empty list.
          should_save_checkpoints: Stored in `self._should_save_checkpoints`.
          should_write_summaries: Stored in `self._should_write_summaries`.
          has_weights: If true, the initialization signature gets one extra
              float32 entry per target (shaped like that target), and the
              dummy metrics signature gets a third entry.
          nontrainable_param_map: Stored in `self._nontrainable_param_map`;
              defaults to an empty dict.
          mask_id: Forwarded to the loss and to every metric constructor.
          metrics: Dict from metric name to a callable accepting `has_weights`
              and `mask_id`; defaults to the module-level `_METRICS`.

        Raises:
          ValueError: On the 'jax' backend, if `n_devices` differs from the
              number of available devices.
        """
        # Host topology is only queried on the JAX backend; every other
        # backend is treated as a single host.
        if backend.get_name() == 'jax':
            self._host_id = jax.host_id()
            self._host_count = jax.host_count()
        else:
            self._host_id = 0
            self._host_count = 1
        self._is_chief = (self._host_id == 0)

        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save_checkpoints = should_save_checkpoints
        self._should_write_summaries = should_write_summaries
        self._has_weights = has_weights
        self._mask_id = mask_id
        self._metrics_dict = _METRICS if metrics is None else metrics
        # From here on `loss_fn` is the constructed loss, not the factory.
        loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
        device_count = backend.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and backend.get_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))
        self._n_devices = n_devices

        # Simple differential seeding of RNG across hosts by host_id and time.
        if random_seed is None and self._host_count > 1:
            _, random_seed = divmod(
                int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
        rng = get_random_number_generator_and_set_seed(random_seed)
        # From here on `inputs` is the constructed inputs object, not the
        # factory.
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        # `init_rng` is reserved for initialization; the remaining `rng` is
        # split into one key per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, n_devices))
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        # NOTE(review): `x or 1` also maps a dimension of 0 to 1 - confirm
        # that zero-sized dimensions never occur here.
        model_input_shape = backend.nested_map(lambda x: x or 1,
                                               model_input_shape)
        model_target_shape = backend.nested_map(lambda x: x or 1,
                                                model_target_shape)

        def new_opt_state_and_model_state(input_shape, input_dtype,
                                          target_shape, target_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            # Single (non-list) dtypes are wrapped so both cases are handled
            # uniformly below.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            dtypes = list(input_dtype) + list(target_dtype)
            shapes = list(input_shape) + list(target_shape)
            # With weights enabled, append one float32 entry per target, using
            # the target shapes.
            if self._has_weights:
                shapes += list(target_shape)
                dtypes += [np.float32 for _ in target_dtype]
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            # The four shape/dtype arguments are marked static; only `rng`
            # (argument 4) is left non-static.
            new_opt_state_and_model_state = backend.jit(
                new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
        # The lambda looks up `new_opt_state_and_model_state` when it is
        # called, so it uses the jitted version if the rebinding above
        # happened.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                model_input_shape, self._inputs.input_dtype,
                model_target_shape, self._inputs.target_dtype, init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are sorted so the layer order is deterministic.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  mask_id=self._mask_id) for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        # TODO(lukaszkaiser): clean this up once layer API stabilizes.
        # For now, we need to initialize metric layers somehow, so here we go.
        # We assume that they do not have any parameters, so this is a dummy.
        dummy_shapes = ((1, 2), (1, ),
                        (1, )) if self._has_weights else ((1, 2), (1, ))
        dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        m_weights, m_state = metrics_in_parallel.init(dummy_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        # Without an output directory the fields above stay None until
        # `reset` is called.
        if output_dir is not None:
            self.reset(output_dir)