Code Example #1
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
    """Returns a JIT-compiled predict function (unless jit=False)."""
    model_predict = layers.Serial([model_predict, metric_fn])

    if n_devices == 1:
        return backend.jit(model_predict) if jit else model_predict

    # Multi-devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(x, params, state, rng):
        return model_predict(x, params=params, state=state, rng=rng)

    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred = mapped_predict(reshape_by_device(x, n_devices), params, state,
                              jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            if len(x.shape) > 1:
                batch_size = x.shape[0] * x.shape[1]
                return np.reshape(x, [batch_size] + list(x.shape[2:]))
            # TODO(lukaszkaiser): is returning averages for scalars the right choice?
            # If it is only scalar, return the average.
            return np.mean(x, axis=0)

        return layers.nested_map(pred, combine)

    return predict
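
The multi-device branch above (and in most of the examples below) shards the batch with a `reshape_by_device` helper that is not shown on this page. Below is a rough sketch of what such a helper presumably does, written in plain JAX and assuming the batch size divides evenly across devices and that nested inputs are handled leaf-wise; the actual trax implementation may differ.

import jax
import jax.numpy as jnp

def reshape_by_device_sketch(x, n_devices):
    """Split the leading batch axis into [n_devices, batch // n_devices, ...]."""
    def _split(t):
        batch_size = t.shape[0]
        assert batch_size % n_devices == 0, "batch must divide evenly across devices"
        return jnp.reshape(t, (n_devices, batch_size // n_devices) + t.shape[1:])
    # Apply leaf-wise so nested (list/tuple/dict) inputs are handled too.
    return jax.tree_util.tree_map(_split, x)
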
Code Example #2
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Get jit-ed update function for loss, optimizer, learning rate function."""
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, rng):
      rng, subrng = jax_random.split(rng[0])
      params, opt_slots = opt_state
      return optimizer.tree_update(i, backend.grad(loss_fn)(
          params, batch, predict_fn, rng), params, opt_slots), [subrng]
    if jit:
      return backend.jit(single_update)
    else:
      return single_update

  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_update(i, opt_state, batch, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    rng, subrng = jax_random.split(rng)
    params, opt_slots = opt_state
    grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
    grads = jax.tree_util.tree_map(
        lambda g: lax.psum(g, "batch"), grads)
    return optimizer.tree_update(i, grads, params, opt_slots), subrng

  def update(i, opt_state, batch, rng):
    return mapped_update(numpy.repeat(i, n_devices), opt_state, batch, rng)

  return update
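
The essential multi-device pattern here is: compute per-device gradients inside `pmap`, then combine them across devices with `lax.psum` on the named axis so every replica applies the identical update. Below is a self-contained sketch of that pattern in plain JAX, with a toy `loss` standing in for `loss_fn(params, batch, predict_fn, rng)` and a hard-coded learning rate; these are assumptions for illustration, not the trax API.

import functools
import jax
import jax.numpy as jnp
from jax import lax

def loss(params, batch):
    # Toy least-squares loss; stands in for loss_fn(params, batch, predict_fn, rng).
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def update(params, batch):
    grads = jax.grad(loss)(params, batch)
    # Sum per-device gradients on the "batch" axis, mirroring the
    # tree_map(lambda g: lax.psum(g, "batch"), grads) call above.
    grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
    return params - 1e-3 * grads

# Every argument carries a leading [n_devices] axis, e.g.
# params: [n_devices, d], x: [n_devices, b, d], y: [n_devices, b].
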
Code Example #3
def _jit_predict_fn(model_predict, n_devices, jit=True):
  """Use jit on model_predict if required."""

  if n_devices == 1:
    if jit:
      return backend.jit(model_predict)
    else:
      return model_predict

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)

  def predict(x, params=(), rng=None):
    """Predict function jited and parallelized as requested."""
    # Multiple devices: run the pmapped predict on per-device shards.
    pred = mapped_predict(
        reshape_by_device(x, n_devices),
        params,
        jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    if not isinstance(pred, (list, tuple)):  # Not nested.
      batch_size = pred.shape[0] * pred.shape[1]
      return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
    batch_size = pred[0].shape[0] * pred[0].shape[1]
    return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]

  return predict
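
Examples #1, #3 and #4 hand each device its own random key via `jax_random.split(rng, n_devices)` before calling the pmapped function. A minimal illustration of that pattern with plain JAX (the key below is hypothetical, not tied to any trax state):

import jax

n_devices = jax.local_device_count()
rng = jax.random.PRNGKey(0)
# Split into one independent key per device. The result carries a leading
# [n_devices] axis, so pmap shards it exactly like the reshaped data.
per_device_rngs = jax.random.split(rng, n_devices)
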
Code Example #4
def _jit_predict_fn(model_predict, n_devices, jit=True):
    """Returns a JIT-compiled predict function (unless jit=False)."""

    if n_devices == 1:
        return backend.jit(model_predict) if jit else model_predict

    # Multi-devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(x, params, state, rng):
        return model_predict(x, params, state, rng=rng)

    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred, state = mapped_predict(reshape_by_device(x, n_devices), params,
                                     state, jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            batch_size = x.shape[0] * x.shape[1]
            return np.reshape(x, [batch_size] + list(x.shape[2:]))

        return layers.nested_map(pred, combine), state

    return predict
Code Example #5
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes updates for one step."""
    if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.

        def single_update(i, opt_state, batch, state, rng):
            params, slots, opt_params = opt_state
            rng, subrng = jax_random.split(rng[0])
            grad_fn = backend.grad(loss_fn, has_aux=True)
            grads, state = grad_fn(params, batch, predict_fn, state, rng)
            return optimizer.tree_update(i, grads, params, slots,
                                         opt_params), state, [subrng]

        return backend.jit(single_update) if jit else single_update

    # Else, for n_devices > 1:
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_update(i, opt_state, batch, state, rng):
        """This is a multi-device version of the update function above."""
        # We assume all tensors have the first dimension = n_devices.
        params, slots, opt_params = opt_state
        rng, subrng = jax_random.split(rng)
        grad_fn = backend.grad(loss_fn, has_aux=True)
        grads, state = grad_fn(params, batch, predict_fn, state, rng)
        grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
        return optimizer.tree_update(i, grads, params, slots,
                                     opt_params), state, subrng

    def update(i, opt_state, batch, state, rng):
        return mapped_update(numpy.repeat(i, n_devices), opt_state, batch,
                             state, rng)

    return update
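
Examples #5 and #6 thread the model state through the loss and use `backend.grad(loss_fn, has_aux=True)` so that gradients and the updated state come back together. A minimal sketch of `has_aux=True` with plain `jax.grad`, using a toy loss and a toy state (both hypothetical):

import jax
import jax.numpy as jnp

def loss_with_state(params, batch, state):
    x, y = batch
    pred = x @ params
    new_state = state + 1  # toy "model state", e.g. a step counter
    # Only the first output is differentiated; the second is returned as-is.
    return jnp.mean((pred - y) ** 2), new_state

grad_fn = jax.grad(loss_with_state, has_aux=True)
# grads has the same structure as params; new_state passes through unchanged:
# grads, new_state = grad_fn(params, (x, y), state)
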
Code Example #6
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
    """Returns a (JIT-compiled) function that computes the loss for one step."""
    if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.

        def single_compute_loss(opt_state, batch, state, rng):
            rng, subrng = jax_random.split(rng[0])
            loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state,
                                      rng)
            return loss_val, state, [subrng]

        return backend.jit(single_compute_loss) if jit else single_compute_loss

    # Else, for n_devices > 1:
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_compute_loss(opt_state, batch, state, rng):
        """This is a multi-device version of the update function above."""
        # We assume all tensors have the first dimension = n_devices.
        rng, subrng = jax_random.split(rng)
        loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
        return loss_val, state, subrng

    def compute_loss(opt_state, batch, state, rng):
        return mapped_compute_loss(opt_state,
                                   reshape_by_device(batch, n_devices), state,
                                   rng)

    return compute_loss
Code Example #7
def _jit_update_fun(predict_fun, loss_fun, optimizer, lr_fun, num_devices):
  """Get jit-ed update function for loss, optimizer, learning rate function."""
  if num_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, rng):
      _, opt_update = optimizer(lr_fun)
      params = trax_opt.get_params(opt_state)
      return opt_update(i, backend.grad(loss_fun)(
          params, batch, predict_fun, rng), opt_state)
    return backend.jit(single_update)

  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_update(i, opt_state, batch, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = num_devices.
    _, opt_update = optimizer(lr_fun)
    params = trax_opt.get_params(opt_state)
    grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
    grads = jax.tree_util.tree_map(
        lambda g: lax.psum(g, "batch"), grads)
    return opt_update(i, grads, opt_state)

  def update(i, opt_state, batch, rng):
    # TODO(lukaszkaiser): investigate how to replicate rng and correct.
    return mapped_update(jax.replicate(i), opt_state, batch, jax.replicate(rng))

  return update
Code Example #8
    def __init__(self,
                 model,
                 batch_size,
                 observation_space,
                 action_space,
                 reward_range,
                 discrete_rewards,
                 history_stream,
                 output_dir,
                 model_predict_kwargs=None):
        """Initializes the env.

    Args:
      model: TRAX model.
      batch_size: (int) Number of simulated environments run in parallel.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      reward_range: (tuple) Pair (min_reward, max_reward).
      discrete_rewards: (bool) Whether to discretize the rewards.
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      output_dir: (str) Output dir.
      model_predict_kwargs: (dict) Additional model keyword arguments for
        inference. Useful when different config is needed for training and
        inference, e.g. train with memory efficient attention and predict with
        the regular one.
    """
        self._model = model
        if model_predict_kwargs is None:
            model_predict_kwargs = {}
        model_predict = self._model(mode="predict", **model_predict_kwargs)

        def predict_with_state(*args, **kwargs):
            output = model_predict(*args, **kwargs)
            return (output, model_predict.state)

        self._model_predict = backend.jit(predict_with_state)
        self._model_initialize = model_predict.initialize_once

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_range = reward_range
        self._output_dir = output_dir

        self._predict_fn = None
        self._rng = None
        self._model_state = None
        self._history_stream = None

        # Call the super's ctor. It will use some of the member fields, so we call
        # it in the end.
        super(SimulatedEnvProblem, self).__init__(
            batch_size=batch_size,
            discrete_rewards=discrete_rewards,
            history_stream=history_stream,
        )

        self.seed()
Code Example #9
File: trax.py  Project: SigmaQuan/tensor2tensor
    def predict(x, params=(), rng=None):
        """Predict function jited and parallelized as requested."""
        # On one device, jit and run.
        if num_devices == 1:
            return backend.jit(model_predict)(x, params, rng=rng)

        # Multi-devices, pmap and run.
        @functools.partial(backend.pmap, axis_name="batch")
        def mapped_predict(x, params, rng):
            return model_predict(x, params, rng=rng)

        pred = mapped_predict(reshape_by_device(x, num_devices), params,
                              jax_random.split(rng, num_devices))
        batch_size = x.shape[0]
        return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
Code Example #10
  def predict(x, params=(), rng=None):
    """Predict function jited and parallelized as requested."""
    # On one device, jit and run.
    if n_devices == 1:
      return backend.jit(model_predict)(x, params, rng=rng)

    pred = mapped_predict(
        reshape_by_device(x, n_devices),
        params,
        jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    if not isinstance(pred, (list, tuple)):  # Not nested.
      batch_size = pred.shape[0] * pred.shape[1]
      return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
    batch_size = pred[0].shape[0] * pred[0].shape[1]
    return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]
Code Example #11
    def __init__(self, model, history_length, trajectory_length, batch_size,
                 observation_space, action_space, reward_range,
                 discrete_rewards, initial_observation_stream, output_dir):
        """Initializes the env.

    Args:
      model: TRAX model.
      history_length: (int) Number of last observations fed into the model.
      trajectory_length: (int) Length of each trajectory unrolled from the
        model.
      batch_size: (int) Number of simulated environments run in parallel.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      reward_range: (tuple) Pair (min_reward, max_reward).
      discrete_rewards: (bool) Whether to discretize the rewards.
      initial_observation_stream: Iterator yielding batches of initial
        observations for the model.
      output_dir: (str) Output dir.
    """
        # TODO(pkozakowski): At some point we will have a "predict" mode which we
        # should use here. When this happens, change the mode.
        self._model_predict = backend.jit(model(mode="eval"))
        self._history_length = history_length
        self._trajectory_length = trajectory_length
        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_range = reward_range
        self._output_dir = output_dir

        self._model_params = None
        self._rng = None
        self._initial_observation_stream = None
        self._history = None
        self._steps = None

        # Call the super's ctor. It will use some of the member fields, so we call
        # it in the end.
        super(SimulatedEnvProblem, self).__init__(
            batch_size=batch_size,
            discrete_rewards=discrete_rewards,
            initial_observation_stream=initial_observation_stream,
        )

        self.seed()
Code Example #12
File: trax.py  Project: sheldonresearch/tensor2tensor
  def predict(params, batch, rng=None):
    """Predict function jited and parallelized as requested."""
    # If not jit'ing, just run the function.
    if not jit_eval:
      return model_predict(params, batch, rng=rng)

    # On one device, jit and run.
    if num_devices == 1:
      return backend.jit(model_predict)(params, batch, rng=rng)

    # Multi-devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(params, batch, rng):
      return model_predict(params, batch, rng=rng)
    pred = mapped_predict(
        jax.replicate(params),
        reshape_by_device(batch, num_devices),
        jax.replicate(rng))
    batch_size = batch.shape[0]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
Code Example #13
  def __init__(self, model, batch_size, observation_space, action_space,
               reward_range, discrete_rewards, history_stream, output_dir):
    """Initializes the env.

    Args:
      model: TRAX model.
      batch_size: (int) Number of simulated environments run in parallel.
      observation_space: (gym.Space) Observation space.
      action_space: (gym.Space) Action space.
      reward_range: (tuple) Pair (min_reward, max_reward).
      discrete_rewards: (bool) Whether to discretize the rewards.
      history_stream: Iterator yielding batches of initial input data for the
        model. The format is implementation-specific.
      output_dir: (str) Output dir.
    """
    # TODO(pkozakowski): At some point we will have a "predict" mode which we
    # should use here. When this happens, change the mode.
    self._model = model
    self._model_predict = backend.jit(self._model(mode="eval"))
    self._observation_space = observation_space
    self._action_space = action_space
    self._reward_range = reward_range
    self._output_dir = output_dir

    self._predict_fn = None
    self._rng = None
    self._model_state = None
    self._history_stream = None

    # Call the super's ctor. It will use some of the member fields, so we call
    # it in the end.
    super(SimulatedEnvProblem, self).__init__(
        batch_size=batch_size,
        discrete_rewards=discrete_rewards,
        history_stream=history_stream,
    )

    self.seed()
Code Example #14
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save=True,
                 has_weights=False):
        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save = should_save
        self._has_weights = has_weights
        loss_fn = functools.partial(loss_fn, has_weights=self._has_weights)
        device_count = jax.lib.xla_bridge.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count:
            raise ValueError(
                "Jax cannot work yet with n_devices != all devices: "
                "%d != %d" % (n_devices, device_count))
        self._n_devices = n_devices
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode="train")
        model_predict_eval = model(mode="eval")

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = jax_random.split(rng, n_devices)
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = layers.nested_map(
            model_input_shape, lambda x: x if x else 1)
        model_target_shape = layers.nested_map(
            model_target_shape, lambda x: x if x else 1)

        def initialize(input_shape, input_dtype, target_shape, target_dtype,
                       rng):
            """Helper to initialize the model."""
            # Combine inputs and targets on the stack.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            full_type = list(input_dtype) + list(target_dtype)
            full_shape = list(input_shape) + list(target_shape)
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            params, state = model(mode="train").initialize(
                full_shape, full_type, rng)
            (slots, opt_params) = opt.tree_init(params)
            return (OptState(params, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            initialize = backend.jit(initialize, static_argnums=(0, 1, 2, 3))
        self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype, model_target_shape,
            self._inputs.target_dtype, init_rng)

        # jit model_predict and update so they're fast
        self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                       n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        if output_dir is not None:
            self.reset(output_dir)
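
The initializer above is JIT-compiled with `static_argnums=(0, 1, 2, 3)` so the shape and dtype arguments are treated as compile-time constants (they must therefore be hashable, e.g. tuples rather than lists). A minimal sketch of the same idea with plain `jax.jit` and a hypothetical initializer, not the trax one:

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(0, 1))
def init_params(input_shape, dtype, rng):
    # input_shape and dtype are static: each distinct (shape, dtype) pair
    # triggers a fresh compilation and may be used in Python-level logic.
    fan_in = input_shape[-1]
    w = jax.random.normal(rng, (fan_in, fan_in), dtype=dtype) / jnp.sqrt(fan_in)
    b = jnp.zeros((fan_in,), dtype=dtype)
    return {"w": w, "b": b}

# params = init_params((1, 128), jnp.float32, jax.random.PRNGKey(0))
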
Code Example #15
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 mask_id=None):
        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save = should_save
        self._has_weights = has_weights
        self._mask_id = mask_id
        loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
        device_count = jax.lib.xla_bridge.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count:
            raise ValueError(
                "Jax cannot work yet with n_devices != all devices: "
                "%d != %d" % (n_devices, device_count))
        self._n_devices = n_devices
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode="train")
        model_predict_eval = model(mode="eval")

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = jax_random.split(rng, n_devices)
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = layers.nested_map(
            model_input_shape, lambda x: x if x else 1)
        model_target_shape = layers.nested_map(
            model_target_shape, lambda x: x if x else 1)

        def new_opt_state_and_model_state(input_shape, input_dtype,
                                          target_shape, target_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            full_type = list(input_dtype) + list(target_dtype)
            full_shape = list(input_shape) + list(target_shape)
            if self._has_weights:
                full_shape += list(target_shape)
                full_type += [np.float32 for _ in target_dtype]
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = layers.Serial([model(mode="train"), loss_fn])
            params, state = m.initialize_once(full_shape, full_type, rng)
            (slots, opt_params) = opt.tree_init(params)
            return (OptState(params, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = backend.jit(
                new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                model_input_shape, self._inputs.input_dtype,
                model_target_shape, self._inputs.target_dtype, init_rng))

        # jit model_predict and update so they're fast
        # TODO(lukaszkaiser): the code below creates a layer computing
        # multiple metrics from a single model output; re-factor for clarity.
        dup_layer = layers.Dup3() if self._has_weights else layers.Dup2()

        def lower(layer):
            """Apply layer below the current inputs, targets, and possibly weights."""
            if self._has_weights:
                # Apply layer below inputs, targets, and loss weights.
                return layers.Parallel([], [], [], layer)
            else:
                # Apply layer below inputs and targets.
                return layers.Parallel([], [], layer)

        metrics_layer = []
        self._metrics = list(sorted(_METRICS.keys()))
        for i, m in enumerate(reversed(self._metrics)):
            metric = _METRICS[m](has_weights=self._has_weights,
                                 mask_id=self._mask_id)
            if i != len(self._metrics) - 1:
                metrics_layer.append(dup_layer)
                metrics_layer.append(lower(metric))
            else:
                metrics_layer.append(metric)
        # TODO(lukaszkaiser): clean this up once layer API stabilizes.
        # For now, we need to initialize metric layers somehow, so here we go.
        # We assume that they do not have any parameters, so this is a dummy.
        dummy_shape = (((1, 2), (1,), (1,)) if self._has_weights
                       else ((1, 2), (1,)))
        dummy_type = [np.float32] * (3 if self._has_weights else 2)
        metrics_layer = layers.Serial(metrics_layer)
        metrics_params, metrics_state = metrics_layer.initialize_once(
            dummy_shape, tuple(dummy_type), init_rng)
        self._metrics_params = layers.nested_map(metrics_params,
                                                 self._maybe_replicate)
        self._metrics_state = layers.nested_map(metrics_state,
                                                self._maybe_replicate)
        self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_layer,
                                         n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        if output_dir is not None:
            self.reset(output_dir)
Code Example #16
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               save_steps=None, should_save=True):
    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    self._should_save = should_save
    device_count = jax.lib.xla_bridge.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count:
      raise ValueError("Jax cannot work yet with n_devices != all devices: "
                       "%d != %d" % (n_devices, device_count))
    self._n_devices = n_devices
    rng = get_random_number_generator_and_set_seed(random_seed)
    inputs = inputs(n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode="train")
    model_predict_eval = model(mode="eval")

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = jax_random.split(rng, n_devices)
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
    # Change all None to 1 in input shape.
    model_input_shape = layers.nested_map(
        model_input_shape, lambda x: x if x else 1)
    def initialize(input_shape, input_dtype, init_rng):
      params = model_train.initialize(input_shape, input_dtype, init_rng)
      (slots, opt_params) = opt.tree_init(params)
      return OptState(params, slots, opt_params)
    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      initialize = backend.jit(initialize, static_argnums=(0, 1))
    self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
        model_input_shape, self._inputs.input_dtype, init_rng)

    # jit model_predict and update so they're fast
    self._jit_model_predict_eval = _jit_predict_fn(
        model_predict_eval, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    self._lr_schedule = lr_schedule

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None

    if output_dir is not None:
      self.reset(output_dir)
Code Example #17
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None):
        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        device_count = jax.lib.xla_bridge.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count:
            raise ValueError(
                "Jax cannot work yet with n_devices != all devices: "
                "%d != %d" % (n_devices, device_count))
        self._n_devices = n_devices
        rng = get_random_number_generator_and_set_seed(random_seed)
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "train"))
        self._eval_sw = jaxboard.SummaryWriter(os.path.join(
            output_dir, "eval"))

        # Create input streams.
        inputs = inputs(n_devices)
        self._inputs = inputs
        self._train_stream = inputs.train_stream()

        # Setup optimizer and model.
        state = restore_state(output_dir)
        history = state.history
        self._lr_fn = lr_schedule(history)
        opt = optimizer(self._lr_fn)

        model_train = model(mode="train")
        model_predict_eval = model(mode="eval")

        # Setup state.
        step = state.step or 0
        rng, init_rng = jax_random.split(rng)
        self._rngs = jax_random.split(rng, n_devices)
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
        # Change all None to 1 in input shape.
        model_input_shape = layers.nested_map(
            model_input_shape, lambda x: x if x else 1)
        if state.params:
            opt_state = state.params
        else:
            # JIT parameter initialization to avoid memory fragmentation
            def initialize(input_shape, input_dtype, init_rng):
                params = model_train.initialize(input_shape, input_dtype,
                                                init_rng)
                opt_state = (params, opt.tree_init(params))
                return opt_state

            initialize = backend.jit(initialize, static_argnums=(0, 1))
            opt_state = initialize(model_input_shape, inputs.input_dtype,
                                   init_rng)
        if n_devices > 1:
            replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
            opt_state = layers.nested_map(opt_state, replicate)

        # jit model_predict and update so they're fast
        self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                       n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._step = step
        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._opt_state = opt_state
        self._history = history
        self._lr_schedule = lr_schedule
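
Example #17 replicates the freshly initialized optimizer state across devices by broadcasting each leaf to a leading `[n_devices]` axis, which is the layout `pmap` expects for its mapped arguments. The same replication written with plain JAX tree utilities, as a sketch assuming array-like leaves:

import jax
import jax.numpy as jnp

def replicate(tree, n_devices):
    """Add a leading [n_devices] axis to every leaf, as pmap expects."""
    def _rep(x):
        x = jnp.asarray(x)
        return jnp.broadcast_to(x, (n_devices,) + x.shape)
    return jax.tree_util.tree_map(_rep, tree)

# Example: replicated_opt_state = replicate(opt_state, jax.local_device_count())
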
Code Example #18
File: trax.py  Project: bochuxt/tensor2tensor
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.

  Returns:
    trax.State
  """
    num_devices = num_devices or jax.lib.xla_bridge.device_count()
    rng = get_random_number_generator_and_set_seed(random_seed)
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    inputs = inputs()

    # Setup optimizer and model
    state = restore_state(output_dir)
    history = state.history
    lr_fun = lr_schedule(history)
    opt_init, _ = optimizer(lr_fun)
    model_init, model_predict = model()

    # Setup state
    step = state.step or 0
    rng, init_key = jax_random.split(rng)
    params_initializer = \
        lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
    params = state.params or params_initializer()
    opt_state = opt_init(params)
    if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
        opt_state = jax.replicate(opt_state)

    # jit model_predict and update so they're fast
    jit_model_predict = backend.jit(model_predict)  # for evaluation
    jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer,
                                     lr_fun, num_devices)

    print()
    train_stream = inputs.train_stream()
    epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
    if eval_frequency:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    step_log(step, "Starting training using %d devices" % num_devices)

    # Non-compiled debug step helps find problems in models easier.
    if run_debug_step:
        debug_loss = loss_fun(params, next(train_stream), model_predict, rng)
        step_log(step, "Debug step loss %.8f" % debug_loss)

    for epoch, epoch_steps in epochs(train_steps, epoch_steps):
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            next_train_batch = next(train_stream)
            if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
                next_train_batch = reshape_by_device(next_train_batch,
                                                     num_devices)
            rng, subrng = jax_random.split(rng)
            opt_state = jit_update_fun(step, opt_state, next_train_batch,
                                       subrng)
            step += 1

            # LR log
            if step == 1 or step % 10 == 0:
                train_sw.scalar("training/learning rate",
                                lr_fun(step),
                                step=step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        if epoch_steps > 1:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)

        # Evaluate
        if num_devices > 1:  # TODO(lukaszkaiser): remove branch when possible.
            params = trax_opt.get_params(jax.unreplicate(opt_state))
        else:
            params = trax_opt.get_params(opt_state)
        evaluate_train_and_eval(step=step,
                                inputs=inputs,
                                predict_fun=functools.partial(
                                    jit_model_predict, params),
                                eval_steps=eval_steps,
                                rng=rng,
                                train_sw=train_sw,
                                eval_sw=eval_sw,
                                history=history)

        # Save state
        save_state(State(params=params, step=step, history=history),
                   output_dir)

        # Save Gin config
        # Gin only tracks the used parameters, so we save it after the first epoch.
        if epoch == 1:
            save_gin(output_dir, train_sw)

        # Update learning rate with new history
        old_lr_fun = lr_fun
        lr_fun = lr_schedule(history)
        if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
            jit_update_fun = _jit_update_fun(model_predict, loss_fun,
                                             optimizer, lr_fun, num_devices)

        # Flush summary writers
        train_sw.writer.flush()
        eval_sw.writer.flush()

    step_log(step, "Training done")
    return State(params=params, step=step, history=history)
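
Example #18's training loop iterates over `epochs(train_steps, epoch_steps)`, a helper that is not shown on this page. Judging only from how it is used (1-indexed epoch numbers, a total capped at `train_steps`), it presumably behaves roughly like the hypothetical generator below; the real helper may differ.

import itertools

def epochs_sketch(total_steps, steps_per_epoch):
    """Yield (epoch_number, n_steps) pairs until total_steps steps are scheduled."""
    steps_done = 0
    for epoch, n_steps in enumerate(steps_per_epoch, start=1):
        n_steps = min(n_steps, total_steps - steps_done)
        if n_steps <= 0:
            return
        yield epoch, n_steps
        steps_done += n_steps

# e.g. list(epochs_sketch(10, itertools.chain([1, 3], itertools.repeat(4))))
#      == [(1, 1), (2, 3), (3, 4), (4, 2)]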