def train_step(self, batch):
  """Run one training step and update self._opt_state.

  Pushes the current nontrainable parameters into the optimizer parameters,
  applies the jitted update function to `batch`, stores the resulting weights,
  slots, model state and rngs, logs the mean of every returned statistic when
  `self._should_log_now()` is true, and finally increments `self._step`.

  Args:
    batch: Training batch, forwarded unchanged to `self._jit_update_fn`.
  """
  # Calculate the current optimizer parameters.
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  replicated_params = self._for_n_devices(
      math.nested_map(np.array, self.nontrainable_params))
  current = self._opt_state
  # Mutates the opt_params held by self._opt_state in place.
  current.opt_params.update(replicated_params)

  # Run the update.
  cur_weights, cur_slots, opt_params = current
  (new_weights, new_slots), stat, self._model_state, self._rngs = (
      self._jit_update_fn((cur_weights, cur_slots), self._step, opt_params,
                          batch, self._model_state, self._rngs))
  # NOTE(review): this replaces the model state just returned by the update
  # with the result of `_map_to_state_dicts`; presumably that helper applies
  # `self._state_dicts_update` to the current state -- confirm.
  self._model_state = self._map_to_state_dicts(self._state_dicts_update)
  self._opt_state = current._replace(weights=new_weights, slots=new_slots)

  if self._should_log_now():
    for stat_name, stat_value in stat.items():
      # On multiple devices, take the mean.
      self._train_sw.scalar('training/' + stat_name, np.mean(stat_value),
                            step=self._step)
  self._step += 1
def model_state(self):
  """Returns the model's state, taking the first replica on multi-device runs.

  Only element [0] of `self._model_state` is returned: the loss-layer state is
  ignored because it is empty.
  """
  model_only_state = self._model_state[0]
  if self.n_devices <= 1:
    return model_only_state
  # Replicated leaves carry a leading device axis; keep replica 0.
  return math.nested_map(lambda replica: replica[0], model_only_state)
def save_state(self, keep, prefix='model'):
  """Save trainer state given a possibly replicated opt_state.

  On multi-device runs only the first replica of the optimizer state is saved.
  The state is always written to `<output_dir>/<prefix>.pkl.gz`. When `keep`
  is true, an additional step-numbered copy is written as well.

  Args:
    keep: Whether to also write `<prefix>_<step>.pkl.gz`.
    prefix: Basename prefix of the written files.
  """
  opt_state = self._opt_state
  if self.n_devices > 1:
    opt_state = OptState(*math.nested_map(lambda x: x[0], opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if math.backend_name() == 'jax':
    opt_state = jax.device_get(opt_state)

  step = self._step
  # This dict will be stored as the model.
  trainer_state_dict = make_trainer_state_dict(step, opt_state, self._history,
                                               self._model_state,
                                               self._input_signature)
  targets = [os.path.join(self._output_dir, prefix + '.pkl.gz')]
  if keep:
    targets.append(
        os.path.join(self._output_dir, '{}_{}.pkl.gz'.format(prefix, step)))
  for weights_file in targets:
    self._save_state_dict(trainer_state_dict, weights_file)
def model_weights(self):
  """Returns the model's weights, taking the first replica on multi-device runs.

  Only element [0] of the optimizer weights is returned: the loss-layer weights
  are ignored because they are empty.
  """
  model_only_weights = self._opt_state.weights[0]
  if self.n_devices <= 1:
    return model_only_weights
  # Replicated leaves carry a leading device axis; keep replica 0.
  return math.nested_map(lambda replica: replica[0], model_only_weights)
def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: A ShapeDtype instance (if this layer takes one input) or a
      list/tuple of ShapeDtype instances; signatures of inputs.

  Returns:
    A tuple of (output, state). The output part is a ShapeDtype instance
    describing the output (for a single-output layer) or a tuple of ShapeDtype
    instances (for a multi-output layer).

  Raises:
    LayerError: Wraps any exception raised during abstract evaluation.
  """
  try:
    # Beware: using an actual RNG (as opposed to this ShapeDtype stub) would
    # cause a large number of dropout masks to be computed and permanently
    # stored in global memory.
    rng_stub = ShapeDtype((2, ), onp.uint32)

    def call_on_input(inputs, layer_weights, layer_state, layer_rng):
      return self.forward_with_state(inputs, weights=layer_weights,
                                     state=layer_state, rng=layer_rng)

    weight_signature = nested_map(signature, self.weights)
    output_and_state = math.abstract_eval(call_on_input)(
        input_signature, weight_signature, self.state, rng_stub)
    return output_and_state
  except Exception:
    layer_name = self.__class__.__name__
    trace = _short_traceback(skip=3)
    raise LayerError(layer_name, '_forward_abstract', self._caller,
                     input_signature, trace)
def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: `ShapeDtype` instance (if this layer takes one input) or
      list/tuple of `ShapeDtype` instances.

  Returns:
    Tuple of (output, state). The output part is a `ShapeDtype` instance (for a
    single-output layer) or a tuple of `ShapeDtype` instances (for a
    multi-output layer).

  Raises:
    LayerError: Wraps any exception from abstract evaluation; the original
      exception context is suppressed (`from None`).
  """
  try:
    # Note: By using rng_signature in place of an rng, we avoid computing and
    # permanently storing in global memory a large number of dropout masks.
    # TODO(jonni): Check if using an rng still carries this cost.
    placeholder_rng = math.random.get_prng(0)
    rng_signature = ShapeDtype(placeholder_rng.shape, placeholder_rng.dtype)
    weight_signature = nested_map(signature, self.weights)
    shape_inference_fn = math.abstract_eval(self.pure_fn)
    return shape_inference_fn(input_signature, weight_signature, self.state,
                              rng_signature)
  except Exception:
    # Skipping 13 lines which are all JAX abstract'ifying wrappers.
    layer_name = self._name
    trace = _short_traceback(skip=13)
    raise LayerError(layer_name, '_forward_abstract', self._caller,
                     input_signature, trace) from None
def _forward_abstract(self, input_signature):
  """Computes shapes and dtypes this layer would produce in a forward pass.

  Args:
    input_signature: ShapeDtype instance (if this layer takes one input) or
      list/tuple of ShapeDtype instances.

  Returns:
    Tuple of (output, state). The output part is a ShapeDtype instance (for a
    single-output layer) or a tuple of ShapeDtype instances (for a
    multi-output layer).

  Raises:
    LayerError: Wraps any exception from abstract evaluation, chained to it.
  """
  try:
    # Note: By using rng_signature in place of an rng, we avoid computing and
    # permanently storing in global memory a large number of dropout masks.
    # TODO(jonni): Check if using an rng still carries this cost.
    rng_signature = ShapeDtype((2, ), np.uint32)
    weight_signature = nested_map(signature, self.weights)
    shape_inference_fn = math.abstract_eval(self.forward_with_state)
    return shape_inference_fn(input_signature, weight_signature, self.state,
                              rng_signature)
  except Exception as err:
    layer_name = self._name
    trace = _short_traceback(skip=3)
    raise LayerError(layer_name, '_forward_abstract', self._caller,
                     input_signature, trace) from err
def _sizes(x):
  """Returns a structure mirroring `x` with each leaf replaced by its size.

  A leaf whose `.size` lookup raises any exception (e.g. a non-array leaf)
  counts as 0.
  """
  def leaf_size(leaf):
    try:
      return leaf.size
    except Exception:  # pylint: disable=broad-except
      return 0

  return math.nested_map(leaf_size, x)
def _test_train_eval_predict(self, backend_name):
  """Trains, evaluates, then predicts with a small model on `backend_name`.

  The model includes Dropout and BatchNorm so that state handling is covered.
  The test is skipped for the 'tf' backend when more than one device is
  present.
  """
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with math.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    steps = 2
    eval_steps = 2

    # Adds Dropout and BatchNorm to test state handling.
    def model_fn(mode='train'):
      return layers.Serial(
          layers.Dropout(mode=mode, rate=0.1),
          layers.BatchNorm(mode=mode),
          models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode))

    inputs = test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(output_dir, model=model_fn, inputs=inputs,
                              steps=steps, eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final weights
    inputs = inputs.train_stream(1)
    model = model_fn()
    # Element [0] keeps the model part and drops the (empty) loss-layer part.
    weights = state.opt_state.weights[0]
    # Separate name so the training `state` returned above is not shadowed.
    model_state = state.model_state[0]
    if xla_bridge.device_count() > 1:
      unreplicate = lambda x: x[0]
      weights = math.nested_map(unreplicate, weights)
      model_state = math.nested_map(unreplicate, model_state)
    model(next(inputs)[0], weights=weights, state=model_state)
def _shapes(x):
  """Returns a tuple mirroring `x` with each leaf replaced by its shape.

  Shapes are tuples of Python ints. A leaf whose shape cannot be read or
  converted maps to the empty tuple `()`.
  """
  def leaf_shape(leaf):
    try:
      return tuple(int(dim) for dim in leaf.shape)
    except Exception:  # pylint: disable=broad-except
      return ()

  return tuple(nested_map(leaf_shape, x))
def _combine_devices(x_tuple):
  """Merges the leading (device, per-device batch) axes of each leaf into one.

  A leaf of shape (D, B, ...) becomes (D * B, ...). Leaves with rank < 2 are
  returned unchanged.
  """
  def merge_leading_axes(x):
    if len(x.shape) < 2:
      # No extra batch dimension: use devices as batch, so return.
      return x
    n_dev, per_device, *trailing = x.shape
    return math.numpy.reshape(x, [n_dev * per_device] + list(trailing))

  return math.nested_map(merge_leading_axes, x_tuple)
def predict(x, weights, state, rng):
  """Prediction function, parallelized across `n_devices` as requested.

  Splits `x` along its batch axis into `n_devices` shards, runs
  `model_predict` with one rng per device, merges the device axis back, and
  returns the mean of each output leaf over its leading axis together with the
  new state.
  """
  sharded_x = reshape_by_device(x, n_devices)
  per_device_rngs = np.stack(math.random.split(rng, n_devices))
  res, state = _combine_devices(
      model_predict(sharded_x, weights, state, per_device_rngs))
  averaged = math.nested_map(lambda y: np.mean(y, axis=0), res)
  return averaged, state
def _default_timestep_to_np(self, ts):
  """Default way to convert a timestep to numpy.

  Returns:
    Tuple (observation, action, dist_inputs, reward, discounted_return) with
    every leaf converted by `np.array`.
  """
  selected_fields = (
      ts.observation,
      ts.action,
      ts.dist_inputs,
      ts.reward,
      ts.discounted_return,
  )
  return math.nested_map(np.array, selected_fields)
def predict(x, weights, state, rng):
  """Predict function, JIT-compiled and parallelized as requested.

  Splits `x` along its batch axis into `n_devices` shards, runs
  `model_predict` with one rng per device, and merges the device axis back.
  When `do_mean` is truthy, each output leaf is additionally averaged over its
  leading axis.

  Returns:
    Tuple (outputs, new_state).
  """
  sharded_x = reshape_by_device(x, n_devices)
  per_device_rngs = jnp.stack(math.random.split(rng, n_devices))
  res, state = _combine_devices(
      model_predict(sharded_x, weights, state, per_device_rngs))
  if not do_mean:
    return res, state
  return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
def for_n_devices(x, n_devices):
  """Replicates/broadcasts `x` for `n_devices`.

  With one device (or fewer) each leaf is returned as is. Otherwise, on the
  'jax' backend each leaf goes through `_multi_device_put`; on other backends
  it is broadcast to shape `(n_devices,) + leaf.shape`.
  """
  def replicate_leaf(leaf):
    if n_devices <= 1:
      return leaf
    if math.backend_name() == 'jax':
      return _multi_device_put(leaf)
    return jnp.broadcast_to(leaf, (n_devices,) + leaf.shape)

  return math.nested_map(replicate_leaf, x)
def print_n_weights(self):
  """Prints the total count of trainable weights.

  On multi-device runs the weights are replicated, so only the first replica
  is counted.
  """
  weights = self._opt_state.weights
  if self.n_devices > 1:
    # Un-replicate before sizing. The original sized the replicated weights
    # first and then discarded that result.
    weights = math.nested_map(lambda x: x[0], weights)
  total_size = _nested_reduce(sum, _sizes(weights))
  self.log_step('Total number of trainable weights: %d' % total_size)
def dummy_inputs(rng, input_sig):
  """Samples placeholder inputs matching the (possibly nested) `input_sig`.

  For each signature leaf, a leading `None` dimension is replaced by 2. Values
  come from `rng.uniform`, with `minval=None` for integer dtypes and
  `minval=0` otherwise.
  """
  def sample_for(sig):
    shape = sig.shape
    if shape and shape[0] is None:
      shape = (2, ) + tuple(shape[1:])
    is_integer = onp.issubdtype(sig.dtype, onp.integer)
    minval = None if is_integer else 0
    return rng.uniform(shape=shape, dtype=sig.dtype, minval=minval)

  return math_lib.nested_map(sample_for, input_sig)
def reshape_by_device(x, n_devices):
  """Reshapes possibly nested `x` into a shape `(n_devices, ...)`.

  Each leaf's leading axis of size B is split into
  `(n_devices, B // n_devices)`; the remaining axes are kept.

  Raises:
    ValueError: If `n_devices` does not evenly divide a leaf's leading axis.
  """
  def split_leading_axis(leaf):
    dims = list(leaf.shape)
    batch_size = dims[0]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(f'Number of devices ({n_devices}) does not evenly '
                       f'divide batch size ({batch_size}).')
    return math.numpy.reshape(leaf, [n_devices, per_device] + dims[1:])

  return math.nested_map(split_leading_axis, x)
def reshape_by_device(x, n_devices):
  """Reshapes possibly nested x into a shape (n_devices, ...).

  Each leaf's leading axis of size B is split into (n_devices, B // n_devices);
  the remaining axes are kept.

  Raises:
    ValueError: If n_devices does not evenly divide a leaf's leading axis.
  """
  def split_leading_axis(leaf):
    dims = list(leaf.shape)
    batch_size = dims[0]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    return math.numpy.reshape(leaf, [n_devices, per_device] + dims[1:])

  return math.nested_map(split_leading_axis, x)
def _for_n_devices(self, x):
  """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.

  On the 'jax' backend each leaf goes through `_multi_device_put`; on other
  backends it is broadcast to shape `(n,) + leaf.shape`. With a single device
  `x` is returned leaf-for-leaf unchanged.
  """
  n = self.n_devices

  def replicate_leaf(leaf):
    if n <= 1:
      return leaf
    if math.backend_name() == 'jax':
      return _multi_device_put(leaf)
    return np.broadcast_to(leaf, (n, ) + leaf.shape)

  return math.nested_map(replicate_leaf, x)
def build(self, input_shape):
  """Creates the TF variables that back the wrapped Trax layer.

  Runs under the "tf" math backend. If the Trax layer has no weights yet
  (`weights is base.EMPTY_WEIGHTS`), it is initialized from `input_shape` and
  `self._initializer_rng`. Otherwise its existing weights and state are reused.
  Weights are wrapped in trainable `tf.Variable`s, state in non-trainable
  `tf.Variable`s, and a non-trainable rng variable is created from
  `self._forward_rng_init`.

  Args:
    input_shape: Keras input shape structure; presumably TensorShape-like
      leaves accepted by `tensor_shapes_to_shape_dtypes` -- TODO confirm.
  """
  with math_lib.use_backend("tf"):
    # Using `is` instead of `==` following Trax's practice
    if self._trax_layer.weights is base.EMPTY_WEIGHTS:
      # Each shape goes through `_replace_none_batch` with
      # `batch_size=self._batch_size`; presumably this fills in an unknown
      # (None) batch dimension -- confirm against its definition.
      sanitized_input_shape = math_lib.nested_map(
          functools.partial(_replace_none_batch, batch_size=self._batch_size),
          input_shape)
      weights, state = self._trax_layer.init(
          tensor_shapes_to_shape_dtypes(sanitized_input_shape, self.dtype),
          rng=self._initializer_rng)
    else:
      weights = self._trax_layer.weights
      state = self._trax_layer.state
    # Note: `weights` may contain `EMPTY_WEIGHTS`
    self._weights = math_lib.nested_map(
        functools.partial(tf.Variable, trainable=True), weights)
    self._state = math_lib.nested_map(
        functools.partial(tf.Variable, trainable=False), state)
    self._rng = tf.Variable(self._forward_rng_init, trainable=False)
  super(TraxKerasLayer, self).build(input_shape)
def _default_timestep_to_np(self, ts):
  """Default way to convert a timestep to numpy.

  Returns:
    A `TimeStepNp` built from the timestep's fields (with `discounted_return`
    stored as `return_`), every leaf converted by `np.array`.
  """
  packed = TimeStepNp(
      observation=ts.observation,
      action=ts.action,
      dist_inputs=ts.dist_inputs,
      reward=ts.reward,
      done=ts.done,
      return_=ts.discounted_return,
      mask=ts.mask,
  )
  return math.nested_map(np.array, packed)
def policy(self, trajectory):
  """Chooses an action to play after a trajectory.

  Loads the policy trainer's current weights into the collect model, runs it
  on the last `self._max_slice_length` steps of `trajectory` as a batch of
  one, and samples from the policy distribution at the last timestep.

  Returns:
    Tuple (sample, pred) of the sampled action and the model output it was
    drawn from. On the 'jax' backend both are copied before returning.
  """
  model = self._policy_collect_model
  model.weights = self._policy_trainer.model_weights
  recent = trajectory[-self._max_slice_length:]
  recent_np = recent.to_np(timestep_to_np=self.task.timestep_to_np)
  # Add batch dimension to the observations and run the model.
  batched_observations = recent_np.observations[None, ...]
  model_output = model(batched_observations, n_accelerators=1)
  # Pick element 0 from the batch (the only one), last (current) timestep.
  pred = model_output[0, -1, :]
  sample = self._policy_dist.sample(pred)
  result = (sample, pred)
  if math.backend_name() == 'jax':
    result = math.nested_map(lambda arr: arr.copy(), result)
  return result
def train_step(self, batch):
  """Run one training step and update self._opt_state.

  Pushes the current nontrainable parameters into the optimizer parameters,
  runs the jitted update on `batch`, stores the new weights, slots, model
  state and rngs, and increments `self._step`.

  Args:
    batch: Training batch, forwarded unchanged to `self._jit_update_fn`.
  """
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  replicated_params = self._for_n_devices(
      math.nested_map(np.array, self.nontrainable_params))
  current = self._opt_state
  # Mutates the opt_params held by self._opt_state in place.
  current.opt_params.update(replicated_params)

  (new_weights, new_slots), self._model_state, self._rngs = (
      self._jit_update_fn(self._step, current, batch, self._model_state,
                          self._rngs))
  # NOTE(review): overwrites the model state returned by the update with the
  # result of `_map_to_state_dicts(self._state_dicts_update)` -- confirm intent.
  self._model_state = self._map_to_state_dicts(self._state_dicts_update)
  self._opt_state = current._replace(weights=new_weights, slots=new_slots)
  self._step += 1
def __call__(self, x, **kwargs):
  """Makes Layer instances callable; for use in tests or interactive settings.

  This convenience method helps library users play with, test, or otherwise
  probe the behavior of layers outside of a full training environment. It
  presents the layer as callable function from inputs to outputs, with the
  option of manually specifying weights and non-parameter state per individual
  call. For convenience, weights and non-parameter state are cached per layer
  instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
  and acquiring non-empty values either by initialization or from values
  explicitly provided via the weights and state keyword arguments.

  Args:
    x: 0 or more input tensors, formatted the same as the inputs to
        Layer.forward.
    **kwargs: Additional keyword arguments if needed/desired for this layer.
        Three possible keyword arguments are especially relevant:
          - weights=... will override any cached weights values
          - state=... will override any cached state values
          - rng=... will supply a PRNG key for use by the layer (falls back to
            `self._rng`, then to `math.random.get_prng(0)` if that is None)
        Experimental keyword arguments:
          - n_accelerators=... (default 0): if nonzero, the forward function
            is wrapped with `jit_forward`; if > 1 (and `replicate`), weights
            and state are replicated with `for_n_devices` before the call and
            the returned state is un-replicated afterwards.
          - replicate=... (default True): whether to do that replication.

  Returns:
    0 or more output tensors, formatted the same as the outputs from
        Layer.forward.
  """
  weights = kwargs.pop('weights', self.weights)
  state = kwargs.pop('state', self.state)
  rng = kwargs.pop('rng', self._rng)
  rng = math.random.get_prng(0) if rng is None else rng
  forward = self._forward_internal

  # TODO(lukaszkaiser): the following arguments are experimental, decide which
  # are really useful after a number of experiments and finalize the API.
  n_accelerators = kwargs.pop('n_accelerators', 0)
  replicate = kwargs.pop('replicate', True)
  if n_accelerators > 1 and replicate:
    weights = for_n_devices(weights, n_accelerators)
    state = for_n_devices(state, n_accelerators)
  if n_accelerators:
    forward = jit_forward(forward, n_accelerators)

  outputs, new_state = forward(x, weights, state, rng)
  if n_accelerators > 1 and replicate:
    # Unreplicate state if needed. `math.nested_map` takes the function first
    # and the structure second. The previous call had them swapped.
    new_state = math.nested_map(lambda s: s[0], new_state)
  self.state = new_state
  self.weights = weights
  return outputs
def call(self, inputs):
  """Runs the wrapped Trax layer's `pure_fn` on `inputs` under the "tf" backend.

  Reads the current weight/state/rng variables, runs the layer, assigns the
  returned state back into the state variables, and advances the rng variable
  via `self._rng_updater`.

  Returns:
    The layer outputs converted by `to_tensors`.
  """
  with math_lib.use_backend("tf"):
    fill_batch = functools.partial(_replace_none_batch,
                                   batch_size=self._batch_size)
    inputs = math_lib.nested_map(fill_batch, inputs)
    weights, state, rng = read_values(
        [self._weights, self._state, self._rng])
    inputs, weights, state, rng = to_arrays([inputs, weights, state, rng])
    outputs, new_state = self._trax_layer.pure_fn(
        inputs, weights=weights, state=state, rng=rng)
    tf.nest.map_structure(lambda var, value: var.assign(value), self._state,
                          new_state)
    self._rng.assign(self._rng_updater(rng))
    return to_tensors(outputs)
def save_state(self, keep):
  """Save trainer state given a possibly replicated opt_state.

  On multi-device runs only the first replica of the optimizer state is saved.
  The tuple (opt_state, step, history, model_state) is pickled to
  `<output_dir>/model.pkl`, and also to `<output_dir>/model_<step>.pkl` when
  `keep` is true. The path of the last file written is logged.

  Args:
    keep: Whether to also write the step-numbered copy.
  """
  opt_state = self._opt_state
  if self.n_devices > 1:
    opt_state = OptState(*math.nested_map(lambda x: x[0], opt_state))
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if math.backend_name() == 'jax':
    opt_state = jax.device_get(opt_state)

  step = self._step
  history = self._history
  model_state = self._model_state
  pkl_module = utils.get_pickle_module()

  def dump_to(path):
    with tf.io.gfile.GFile(path, 'wb') as f:
      pkl_module.dump((tuple(opt_state), step, history, model_state), f)

  weights_file = os.path.join(self._output_dir, 'model.pkl')
  dump_to(weights_file)
  if keep:
    weights_file = os.path.join(self._output_dir, 'model_{}.pkl'.format(step))
    dump_to(weights_file)
  log('Model saved to %s' % weights_file, stdout=False)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None,
             metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
  """Sets up the model, loss, optimizer, metrics and jitted step functions.

  Args:
    model: Callable called with `mode='train'` / `mode='eval'` that returns a
      model layer.
    loss_fn: Callable called with `has_weights=` and `id_to_mask=` that returns
      the loss layer.
    optimizer: Callable called with `learning_rate=` that returns an optimizer
      exposing `tree_init`.
    lr_schedule: Stored as `self._lr_schedule`. Per the TODO below, it can
      control all optimizer parameters and model state.
    inputs: An Inputs instance, or a callable returning one.
    output_dir: Directory for checkpoints/summaries. When falsy, checkpoint
      saving and summary writing are disabled. Passed to `self.reset`.
    random_seed: Forwarded to `self._init_host_and_devices`.
    n_devices: Forwarded to `self._init_host_and_devices`.
    checkpoints_at: Stored; `None` becomes `[]`.
    should_save_checkpoints: Only takes effect on the chief host.
    should_write_summaries: Whether to write summaries.
    has_weights: Forwarded to the loss and metric constructors.
    nontrainable_param_map: Stored; `None` becomes `{}`.
    id_to_mask: Forwarded to the loss and metric constructors.
    metrics: Dict from metric name to metric-layer constructor; defaults to
      `_DEFAULT_METRICS`.
    checkpoint_highest: Stored as-is.
    checkpoint_lowest: Stored as-is.
  """
  self._is_chief, self._n_devices, rng = (
      self._init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._has_weights = has_weights
  self._id_to_mask = id_to_mask
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  # One rng per device.
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                             (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(self._inputs.input_shape))
  # Same for targets.
  if self._inputs.target_shape and isinstance(
      self._inputs.target_shape[0], (list, tuple)):
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.target_shape)
  else:
    model_target_shape = tuple([None] + list(self._inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = math.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = math.nested_map(lambda x: x or 1, model_target_shape)
  # NOTE(review): model_input_shape / model_target_shape are not read anywhere
  # later in this method; they look like dead code -- confirm and remove.

  def new_opt_state_and_model_state(shape_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    shapes, dtypes = shape_dtype
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = math.jit(
        new_opt_state_and_model_state, static_argnums=(0, ))
  # Deferred: the actual initialization is presumably triggered from reset().
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          self._inputs.example_shape_dtype, init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [
      self._metrics_dict[m](has_weights=self._has_weights,
                            id_to_mask=self._id_to_mask)
      for m in self._metrics
  ]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  model_predict_eval.init(example_signature)
  output_signature = model_predict_eval.output_signature(example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   self._n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                       self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def unreplicate(self, unreplicate_state=False):
  """Unreplicate weights and optionally state. Experimental.

  Replaces every leaf of `self.weights` with its element [0] (the first
  replica). Does the same for `self.state` when `unreplicate_state` is true.

  Args:
    unreplicate_state: Whether to also un-replicate `self.state`.
  """
  # `math.nested_map` takes the function first and the structure second; the
  # previous code passed them in the opposite order.
  self.weights = math.nested_map(lambda x: x[0], self.weights)
  if unreplicate_state:
    self.state = math.nested_map(lambda x: x[0], self.state)
def trajectory_batch_stream(self,
                            batch_size,
                            epochs=None,
                            max_slice_length=None,
                            min_slice_length=None,
                            margin=0,
                            include_final_state=False,
                            sample_trajectories_uniformly=False):
  """Return a stream of trajectory batches from the specified epochs.

  This function returns a stream of tuples of numpy arrays (tensors).
  If tensors have different lengths, they will be padded by 0.
  When padding is needed, every tensor in a field is padded along axis 0 to
  the next power of two >= the longest tensor's length. Slices are collected
  until exactly `batch_size` are available, so a trailing partial batch is
  never yielded.

  Args:
    batch_size: the size of the batches to return
    epochs: a list of epochs to use; we use all epochs if None
    max_slice_length: maximum length of the slices of trajectories to return
    min_slice_length: minimum length of the slices of trajectories to return
    margin: number of extra steps after "done" that should be included in
      slices, so that networks see the terminal states in the training data
    include_final_state: whether to include slices with the final state of
      the trajectory which may have no action and reward
    sample_trajectories_uniformly: whether to sample trajectories uniformly,
     or proportionally to the number of slices in each trajectory (default)

  Yields:
    batches of trajectory slices sampled uniformly from all slices of length
    at least min_slice_length and up to max_slice_length in all specified
    epochs
  """
  def pad(tensor_list):
    # Replace Nones with valid tensors.
    not_none_tensors = [t for t in tensor_list if t is not None]
    assert not_none_tensors, 'All tensors to pad are None.'
    prototype = np.zeros_like(not_none_tensors[0])
    tensor_list = [t if t is not None else prototype for t in tensor_list]

    max_len = max([t.shape[0] for t in tensor_list])
    min_len = min([t.shape[0] for t in tensor_list])
    if max_len == min_len:  # No padding needed.
      return np.array(tensor_list)

    # Pad to the next power of two so the number of distinct batch shapes
    # stays small (presumably to limit recompilation -- confirm).
    pad_len = 2**int(np.ceil(np.log2(max_len)))
    return np.array([
        _zero_pad(t, (0, pad_len - t.shape[0]), axis=0) for t in tensor_list
    ])

  cur_batch = []
  for t in self.trajectory_stream(epochs, max_slice_length,
                                  include_final_state,
                                  sample_trajectories_uniformly,
                                  margin=margin):
    # TODO(pkozakowski): Instead sample the trajectories out of those with
    # the minimum length.
    if min_slice_length is not None and len(t) < min_slice_length:
      continue

    cur_batch.append(t)
    if len(cur_batch) == batch_size:
      # TODO(pkozakowski): Unpack based on name instead of position in the
      # tuple (how?).
      # NOTE(review): this positional unpack assumes `to_np` returns fields in
      # the order (obs, act, dist_inputs, rew, ret, done, mask); verify it
      # against the definition of the returned namedtuple.
      obs, act, dinp, rew, ret, done, mask = zip(
          *[t.to_np(self._timestep_to_np) for t in cur_batch])
      # Where act, rew and ret will usually have the following shape:
      # [batch_size, trajectory_length-1], which we call [B, L-1].
      # Observations are more complex and will usually be [B, L] + S where S
      # is the shape of the observation space (self.observation_space.shape).
      # We stop the recursion at level 1, so we pass lists of arrays into
      # pad().
      yield math.nested_map(pad, TrajectoryNp(
          observations=obs,
          actions=act,
          dist_inputs=dinp,
          rewards=rew,
          dones=done,
          returns=ret,
          mask=mask,
      ), level=1)
      cur_batch = []