def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False)."""
  model_predict = layers.Serial([model_predict, metric_fn])
  if n_devices == 1:
    return backend.jit(model_predict) if jit else model_predict

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, state, rng):
    return model_predict(x, params=params, state=state, rng=rng)

  def predict(x, params=(), state=(), rng=None):
    """Predict function jited and parallelized as requested."""
    pred = mapped_predict(reshape_by_device(x, n_devices),
                          params,
                          state,
                          jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    def combine(x):
      if len(x.shape) > 1:
        batch_size = x.shape[0] * x.shape[1]
        return np.reshape(x, [batch_size] + list(x.shape[2:]))
      # TODO(lukaszkaiser): is returning averages for scalars the right choice?
      # If it is only scalar, return the average.
      return np.mean(x, axis=0)
    return layers.nested_map(pred, combine)

  return predict

def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Get jit-ed update function for loss, optimizer, learning rate function."""
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, rng):
      rng, subrng = jax_random.split(rng[0])
      params, opt_slots = opt_state
      return optimizer.tree_update(i, backend.grad(loss_fn)(
          params, batch, predict_fn, rng), params, opt_slots), [subrng]
    if jit:
      return backend.jit(single_update)
    else:
      return single_update

  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_update(i, opt_state, batch, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    rng, subrng = jax_random.split(rng)
    params, opt_slots = opt_state
    grads = backend.grad(loss_fn)(params, batch, predict_fn, rng)
    grads = jax.tree_util.tree_map(
        lambda g: lax.psum(g, "batch"), grads)
    return optimizer.tree_update(i, grads, params, opt_slots), subrng

  def update(i, opt_state, batch, rng):
    return mapped_update(numpy.repeat(i, n_devices), opt_state, batch, rng)

  return update

def _jit_predict_fn(model_predict, n_devices, jit=True):
  """Use jit on model_predict if required."""
  if n_devices == 1:
    if jit:
      return backend.jit(model_predict)
    else:
      return model_predict

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)

  def predict(x, params=(), rng=None):
    """Predict function jited and parallelized as requested."""
    pred = mapped_predict(
        reshape_by_device(x, n_devices),
        params,
        jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    if not isinstance(pred, (list, tuple)):  # Not nested.
      batch_size = pred.shape[0] * pred.shape[1]
      return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
    batch_size = pred[0].shape[0] * pred[0].shape[1]
    return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]

  return predict

def _jit_predict_fn(model_predict, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False)."""
  if n_devices == 1:
    return backend.jit(model_predict) if jit else model_predict

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, state, rng):
    return model_predict(x, params, state, rng=rng)

  def predict(x, params=(), state=(), rng=None):
    """Predict function jited and parallelized as requested."""
    pred, state = mapped_predict(reshape_by_device(x, n_devices),
                                 params,
                                 state,
                                 jax_random.split(rng, n_devices))
    # Need to reduce the [device, per-device-batch, ...] tensors back to
    # a [batch, ...] tensor. The tensors may be nested.
    def combine(x):
      batch_size = x.shape[0] * x.shape[1]
      return np.reshape(x, [batch_size] + list(x.shape[2:]))
    return layers.nested_map(pred, combine), state

  return predict

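# A minimal usage sketch for the function above (hypothetical names; assumes
# `model` is a TRAX layer constructor, and that for n_devices > 1 the
# `params`/`state` passed in are already replicated across devices, as the
# trainer code below arranges):
#
#   predict = _jit_predict_fn(model(mode="eval"), n_devices)
#   pred, state = predict(eval_batch, params=params, state=state,
#                         rng=jax_random.PRNGKey(0))
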
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step."""
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, state, rng):
      params, slots, opt_params = opt_state
      rng, subrng = jax_random.split(rng[0])
      grad_fn = backend.grad(loss_fn, has_aux=True)
      grads, state = grad_fn(params, batch, predict_fn, state, rng)
      return optimizer.tree_update(
          i, grads, params, slots, opt_params), state, [subrng]
    return backend.jit(single_update) if jit else single_update

  # Else, for n_devices > 1:
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_update(i, opt_state, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    params, slots, opt_params = opt_state
    rng, subrng = jax_random.split(rng)
    grad_fn = backend.grad(loss_fn, has_aux=True)
    grads, state = grad_fn(params, batch, predict_fn, state, rng)
    grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
    return optimizer.tree_update(
        i, grads, params, slots, opt_params), state, subrng

  def update(i, opt_state, batch, state, rng):
    return mapped_update(numpy.repeat(i, n_devices), opt_state, batch, state,
                         rng)

  return update

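# A minimal usage sketch for the update function above (hypothetical names;
# assumes `opt_state`, `model_state` and `rngs` come from the trainer's
# initialization code below, and that for n_devices > 1 they already carry a
# leading device axis, as the mapped_update comment requires):
#
#   update = _jit_update_fn(model(mode="train"), loss_fn, opt, n_devices)
#   new_params_and_slots, model_state, rngs = update(
#       step, opt_state, train_batch, model_state, rngs)
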
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes the loss for one step."""
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_compute_loss(opt_state, batch, state, rng):
      rng, subrng = jax_random.split(rng[0])
      loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
      return loss_val, state, [subrng]
    return backend.jit(single_compute_loss) if jit else single_compute_loss

  # Else, for n_devices > 1:
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_compute_loss(opt_state, batch, state, rng):
    """This is a multi-device version of the compute-loss function above."""
    # We assume all tensors have the first dimension = n_devices.
    rng, subrng = jax_random.split(rng)
    loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
    return loss_val, state, subrng

  def compute_loss(opt_state, batch, state, rng):
    return mapped_compute_loss(opt_state,
                               reshape_by_device(batch, n_devices),
                               state, rng)

  return compute_loss

def _jit_update_fun(predict_fun, loss_fun, optimizer, lr_fun, num_devices):
  """Get jit-ed update function for loss, optimizer, learning rate function."""
  if num_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, rng):
      _, opt_update = optimizer(lr_fun)
      params = trax_opt.get_params(opt_state)
      return opt_update(i, backend.grad(loss_fun)(
          params, batch, predict_fun, rng), opt_state)
    return backend.jit(single_update)

  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_update(i, opt_state, batch, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = num_devices.
    _, opt_update = optimizer(lr_fun)
    params = trax_opt.get_params(opt_state)
    grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
    grads = jax.tree_util.tree_map(
        lambda g: lax.psum(g, "batch"), grads)
    return opt_update(i, grads, opt_state)

  def update(i, opt_state, batch, rng):
    # TODO(lukaszkaiser): investigate how to replicate rng and correct.
    return mapped_update(jax.replicate(i), opt_state, batch,
                         jax.replicate(rng))

  return update

def __init__(self, model, batch_size, observation_space, action_space,
             reward_range, discrete_rewards, history_stream, output_dir,
             model_predict_kwargs=None):
  """Initializes the env.

  Args:
    model: TRAX model.
    batch_size: (int) Number of simulated environments run in parallel.
    observation_space: (gym.Space) Observation space.
    action_space: (gym.Space) Action space.
    reward_range: (tuple) Pair (min_reward, max_reward).
    discrete_rewards: (bool) Whether to discretize the rewards.
    history_stream: Iterator yielding batches of initial input data for the
      model. The format is implementation-specific.
    output_dir: (str) Output dir.
    model_predict_kwargs: (dict) Additional model keyword arguments for
      inference. Useful when different config is needed for training and
      inference, e.g. train with memory efficient attention and predict with
      the regular one.
  """
  self._model = model
  if model_predict_kwargs is None:
    model_predict_kwargs = {}
  model_predict = self._model(mode="predict", **model_predict_kwargs)

  def predict_with_state(*args, **kwargs):
    output = model_predict(*args, **kwargs)
    return (output, model_predict.state)

  self._model_predict = backend.jit(predict_with_state)
  self._model_initialize = model_predict.initialize_once

  self._observation_space = observation_space
  self._action_space = action_space
  self._reward_range = reward_range
  self._output_dir = output_dir

  self._predict_fn = None
  self._rng = None
  self._model_state = None
  self._history_stream = None

  # Call the super's ctor. It will use some of the member fields, so we call
  # it in the end.
  super(SimulatedEnvProblem, self).__init__(
      batch_size=batch_size,
      discrete_rewards=discrete_rewards,
      history_stream=history_stream,
  )

  self.seed()

def predict(x, params=(), rng=None):
  """Predict function jited and parallelized as requested."""
  # On one device, jit and run.
  if num_devices == 1:
    return backend.jit(model_predict)(x, params, rng=rng)

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(x, params, rng):
    return model_predict(x, params, rng=rng)

  pred = mapped_predict(
      reshape_by_device(x, num_devices),
      params,
      jax_random.split(rng, num_devices))
  batch_size = x.shape[0]
  return np.reshape(pred, [batch_size] + list(pred.shape[2:]))

def predict(x, params=(), rng=None):
  """Predict function jited and parallelized as requested."""
  # On one device, jit and run.
  if n_devices == 1:
    return backend.jit(model_predict)(x, params, rng=rng)

  pred = mapped_predict(
      reshape_by_device(x, n_devices),
      params,
      jax_random.split(rng, n_devices))
  # Need to reduce the [device, per-device-batch, ...] tensors back to
  # a [batch, ...] tensor. The tensors may be nested.
  if not isinstance(pred, (list, tuple)):  # Not nested.
    batch_size = pred.shape[0] * pred.shape[1]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
  batch_size = pred[0].shape[0] * pred[0].shape[1]
  return [np.reshape(p, [batch_size] + list(p.shape[2:])) for p in pred]

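# `reshape_by_device` is not shown in this section; it is assumed to split the
# leading batch dimension across devices, roughly:
#
#   def reshape_by_device_sketch(x, n_devices):  # hypothetical helper
#     batch_size = x.shape[0]
#     return np.reshape(
#         x, [n_devices, batch_size // n_devices] + list(x.shape[1:]))
#
# which is why the predict functions above recover the full batch size as
# pred.shape[0] * pred.shape[1].
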
def __init__(self, model, history_length, trajectory_length, batch_size,
             observation_space, action_space, reward_range, discrete_rewards,
             initial_observation_stream, output_dir):
  """Initializes the env.

  Args:
    model: TRAX model.
    history_length: (int) Number of last observations fed into the model.
    trajectory_length: (int) Length of each trajectory unrolled from the
      model.
    batch_size: (int) Number of simulated environments run in parallel.
    observation_space: (gym.Space) Observation space.
    action_space: (gym.Space) Action space.
    reward_range: (tuple) Pair (min_reward, max_reward).
    discrete_rewards: (bool) Whether to discretize the rewards.
    initial_observation_stream: Iterator yielding batches of initial
      observations for the model.
    output_dir: (str) Output dir.
  """
  # TODO(pkozakowski): At some point we will have a "predict" mode which we
  # should use here. When this happens, change the mode.
  self._model_predict = backend.jit(model(mode="eval"))
  self._history_length = history_length
  self._trajectory_length = trajectory_length
  self._observation_space = observation_space
  self._action_space = action_space
  self._reward_range = reward_range
  self._output_dir = output_dir

  self._model_params = None
  self._rng = None
  self._initial_observation_stream = None
  self._history = None
  self._steps = None

  # Call the super's ctor. It will use some of the member fields, so we call
  # it in the end.
  super(SimulatedEnvProblem, self).__init__(
      batch_size=batch_size,
      discrete_rewards=discrete_rewards,
      initial_observation_stream=initial_observation_stream,
  )

  self.seed()

def predict(params, batch, rng=None):
  """Predict function jited and parallelized as requested."""
  # If not jit'ing, just run the function.
  if not jit_eval:
    return model_predict(params, batch, rng=rng)

  # On one device, jit and run.
  if num_devices == 1:
    return backend.jit(model_predict)(params, batch, rng=rng)

  # Multi-devices, pmap and run.
  @functools.partial(backend.pmap, axis_name="batch")
  def mapped_predict(params, batch, rng):
    return model_predict(params, batch, rng=rng)

  pred = mapped_predict(
      jax.replicate(params),
      reshape_by_device(batch, num_devices),
      jax.replicate(rng))
  batch_size = batch.shape[0]
  return np.reshape(pred, [batch_size] + list(pred.shape[2:]))

def __init__(self, model, batch_size, observation_space, action_space,
             reward_range, discrete_rewards, history_stream, output_dir):
  """Initializes the env.

  Args:
    model: TRAX model.
    batch_size: (int) Number of simulated environments run in parallel.
    observation_space: (gym.Space) Observation space.
    action_space: (gym.Space) Action space.
    reward_range: (tuple) Pair (min_reward, max_reward).
    discrete_rewards: (bool) Whether to discretize the rewards.
    history_stream: Iterator yielding batches of initial input data for the
      model. The format is implementation-specific.
    output_dir: (str) Output dir.
  """
  # TODO(pkozakowski): At some point we will have a "predict" mode which we
  # should use here. When this happens, change the mode.
  self._model = model
  self._model_predict = backend.jit(self._model(mode="eval"))
  self._observation_space = observation_space
  self._action_space = action_space
  self._reward_range = reward_range
  self._output_dir = output_dir

  self._predict_fn = None
  self._rng = None
  self._model_state = None
  self._history_stream = None

  # Call the super's ctor. It will use some of the member fields, so we call
  # it in the end.
  super(SimulatedEnvProblem, self).__init__(
      batch_size=batch_size,
      discrete_rewards=discrete_rewards,
      history_stream=history_stream,
  )

  self.seed()

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True, has_weights=False):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  self._has_weights = has_weights
  loss_fn = functools.partial(loss_fn, has_weights=self._has_weights)
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = layers.nested_map(model_input_shape,
                                        lambda x: x if x else 1)
  model_target_shape = layers.nested_map(model_target_shape,
                                         lambda x: x if x else 1)

  def initialize(input_shape, input_dtype, target_shape, target_dtype, rng):
    """Helper to initialize the model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    full_type = list(input_dtype) + list(target_dtype)
    full_shape = list(input_shape) + list(target_shape)
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    params, state = model(mode="train").initialize(full_shape, full_type, rng)
    (slots, opt_params) = opt.tree_init(params)
    return (OptState(params, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    initialize = backend.jit(initialize, static_argnums=(0, 1, 2, 3))
  self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
      model_input_shape, self._inputs.input_dtype, model_target_shape,
      self._inputs.target_dtype, init_rng)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                 n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)

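# `OptState` is defined elsewhere in this module; it is assumed to be a
# namedtuple-like container, roughly:
#
#   OptState = collections.namedtuple(  # hypothetical reconstruction
#       "OptState", ["params", "slots", "opt_params"])
#
# which makes the `params, slots, opt_params = opt_state` unpacking in the
# update functions above consistent with `OptState(params, slots, opt_params)`
# used here.
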
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  self._has_weights = has_weights
  self._mask_id = mask_id
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = layers.nested_map(model_input_shape,
                                        lambda x: x if x else 1)
  model_target_shape = layers.nested_map(model_target_shape,
                                         lambda x: x if x else 1)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    full_type = list(input_dtype) + list(target_dtype)
    full_shape = list(input_shape) + list(target_shape)
    if self._has_weights:
      full_shape += list(target_shape)
      full_type += [np.float32 for _ in target_dtype]
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = layers.Serial([model(mode="train"), loss_fn])
    params, state = m.initialize_once(full_shape, full_type, rng)
    (slots, opt_params) = opt.tree_init(params)
    return (OptState(params, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # jit model_predict and update so they're fast
  # TODO(lukaszkaiser): the code below creates a layer computing
  # multiple metrics from a single model output; re-factor for clarity.
  dup_layer = layers.Dup3() if self._has_weights else layers.Dup2()

  def lower(layer):
    """Apply layer below the current inputs, targets, and possibly weights."""
    if self._has_weights:
      # Apply layer below inputs, targets, and loss weights.
      return layers.Parallel([], [], [], layer)
    else:
      # Apply layer below inputs and targets.
      return layers.Parallel([], [], layer)

  metrics_layer = []
  self._metrics = list(sorted(_METRICS.keys()))
  for i, m in enumerate(reversed(self._metrics)):
    metric = _METRICS[m](has_weights=self._has_weights,
                         mask_id=self._mask_id)
    if i != len(self._metrics) - 1:
      metrics_layer.append(dup_layer)
      metrics_layer.append(lower(metric))
    else:
      metrics_layer.append(metric)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shape = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
  dummy_type = [np.float32] * (3 if self._has_weights else 2)
  metrics_layer = layers.Serial(metrics_layer)
  metrics_params, metrics_state = metrics_layer.initialize_once(
      dummy_shape, tuple(dummy_type), init_rng)
  self._metrics_params = layers.nested_map(metrics_params,
                                           self._maybe_replicate)
  self._metrics_state = layers.nested_map(metrics_state,
                                          self._maybe_replicate)
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_layer,
                                   n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(model_input_shape,
                                        lambda x: x if x else 1)

  def initialize(input_shape, input_dtype, init_rng):
    params = model_train.initialize(input_shape, input_dtype, init_rng)
    (slots, opt_params) = opt.tree_init(params)
    return OptState(params, slots, opt_params)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    initialize = backend.jit(initialize, static_argnums=(0, 1))
  self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
      model_input_shape, self._inputs.input_dtype, init_rng)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(
      model_predict_eval, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None

  if output_dir is not None:
    self.reset(output_dir)

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
             random_seed=None, n_devices=None, save_steps=None):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Create input streams.
  inputs = inputs(n_devices)
  self._inputs = inputs
  self._train_stream = inputs.train_stream()

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  self._lr_fn = lr_schedule(history)
  opt = optimizer(self._lr_fn)
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(model_input_shape,
                                        lambda x: x if x else 1)
  if state.params:
    opt_state = state.params
  else:
    # JIT parameter initialization to avoid memory fragmentation
    def initialize(input_shape, input_dtype, init_rng):
      params = model_train.initialize(input_shape, input_dtype, init_rng)
      opt_state = (params, opt.tree_init(params))
      return opt_state
    initialize = backend.jit(initialize, static_argnums=(0, 1))
    opt_state = initialize(model_input_shape, inputs.input_dtype, init_rng)
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                 n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._step = step
  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._optimizer = optimizer
  self._opt_state = opt_state
  self._history = history
  self._lr_schedule = lr_schedule

def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    num_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit
      for one step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs()

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict = model()

  # Setup state
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast
  jit_model_predict = backend.jit(model_predict)  # for evaluation
  jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer,
                                   lr_fun, num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    if num_devices > 1:  # TODO(lukaszkaiser): remove branch when possible.
      params = trax_opt.get_params(jax.unreplicate(opt_state))
    else:
      params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer,
                                       lr_fun, num_devices)

    # Flush summary writers
    train_sw.writer.flush()
    eval_sw.writer.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)