def pseudo_forward(self, pseudo_inputs, params, state):
  """Infers the shapes and dtypes `forward` would produce, without real data.

  `self.forward` is evaluated through `backend.eval_on_shapes`, so only
  shape/dtype descriptions flow through the computation.

  Args:
    pseudo_inputs: A ShapeDtype instance (input data minus the actual values)
        or a tuple of ShapeDtype instances, following the same conventions as
        Layer.forward's input arg.
    params: Parameters for this layer; each leaf is converted with
        `shape_dtype_for` before the abstract evaluation.
    state: Start state, passed through to `forward` as is.

  Returns:
    A tuple of (output, state). The output part is a ShapeDtype instance
    (if this layer has one output) or a tuple of ShapeDtype instances (if this
    layer has more than one output).

  Raises:
    LayerError: if any exception is raised while setting up or running the
        abstract evaluation.
  """
  def run_forward(x, params, state, rng):
    return self.forward(x, params=params, state=state, rng=rng)

  try:
    # A ShapeDtype stub stands in for the RNG on purpose: per the original
    # note, an actual RNG would cause a large number of dropout masks to be
    # computed and permanently stored in global memory.
    rng_stub = ShapeDtype((2,), onp.uint32)
    param_signature = nested_map(shape_dtype_for, params)
    abstract_forward = backend.eval_on_shapes(run_forward)
    return abstract_forward(pseudo_inputs, param_signature, state, rng_stub)
  except Exception:
    layer_name = self.__class__.__name__
    trace = _short_traceback(skip=3)
    raise LayerError(layer_name, 'pseudo_forward', self._caller,
                     pseudo_inputs, None, trace)
def _forward_abstract(self, input_signature):
  """Infers the shapes and dtypes a forward pass of this layer would produce.

  `self.forward_with_state` is evaluated through `backend.abstract_eval`,
  using the signatures of `self.weights` and the current `self.state`.

  Args:
    input_signature: A ShapeDtype instance (if this layer takes one input)
        or a list/tuple of ShapeDtype instances; signatures of inputs.

  Returns:
    A tuple of (output, state). The output part is a ShapeDtype instance
    (if this layer has one output) or a tuple of ShapeDtype instances (if this
    layer has more than one output).

  Raises:
    LayerError: if any exception is raised while setting up or running the
        abstract evaluation.
  """
  def run_forward(x, weights, state, rng):
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  try:
    # A ShapeDtype stub stands in for the RNG on purpose: per the original
    # note, an actual RNG would cause a large number of dropout masks to be
    # computed and permanently stored in global memory.
    rng_stub = ShapeDtype((2, ), onp.uint32)
    weight_signature = nested_map(signature, self.weights)
    abstract_forward = backend.abstract_eval(run_forward)
    return abstract_forward(input_signature, weight_signature, self.state,
                            rng_stub)
  except Exception:
    layer_name = self.__class__.__name__
    trace = _short_traceback(skip=3)
    raise LayerError(layer_name, '_forward_abstract', self._caller,
                     input_signature, trace)
def sizes(x):
  """Maps a nested structure of arrays to the same structure of `.size` values.

  Leaves whose `.size` cannot be read (any exception) are reported as 0.
  """
  def _size_or_zero(leaf):
    try:
      n_elements = leaf.size
    except Exception:  # pylint: disable=broad-except
      n_elements = 0
    return n_elements

  return nested_map(_size_or_zero, x)
def shapes(x):
  """Maps a nested structure of arrays to the same structure of shapes.

  Each readable shape is returned as a tuple of Python ints.
  """
  def _shape_or_empty(leaf):
    try:
      dims = tuple(map(int, leaf.shape))
    except Exception:  # pylint: disable=broad-except
      # NOTE(review): the fallback is a list while the success path yields a
      # tuple; kept as is because callers may depend on it.
      dims = []
    return dims

  return nested_map(_shape_or_empty, x)
def _combine_devices(x_tuple):
  """Merges the two leading axes of every tensor in a nested structure.

  Leaves with fewer than two axes are returned untouched (no extra batch
  dimension: the device axis already serves as the batch).
  """
  def merge_leading_axes(tensor):
    dims = list(tensor.shape)
    if len(dims) >= 2:
      merged = dims[0] * dims[1]
      return backend.numpy.reshape(tensor, [merged] + dims[2:])
    return tensor

  return backend.nested_map(merge_leading_axes, x_tuple)
def predict(x, weights, state, rng):
  """Runs `model_predict` on per-device inputs and averages the results.

  The inputs are reshaped by device, the rng is split into `n_devices` keys,
  and the combined outputs are averaged over axis 0.

  Args:
    x: Inputs, reshaped with `_reshape_by_device` before prediction.
    weights: Passed through to `model_predict`.
    state: Passed through to `model_predict`.
    rng: Random key; split into `n_devices` keys that are stacked together.

  Returns:
    A pair (outputs averaged over axis 0, state produced by the prediction).
  """
  inputs_by_device = _reshape_by_device(x, n_devices)
  device_rngs = np.stack(jax_random.split(rng, n_devices))
  raw_prediction = model_predict(inputs_by_device, weights, state, device_rngs)
  merged_outputs, new_state = _combine_devices(raw_prediction)
  averaged = backend.nested_map(lambda y: np.mean(y, axis=0), merged_outputs)
  return averaged, new_state
def print_n_weights(self):
  """Logs the total count of trainable weights via `self.log_step`.

  The count is taken from `self._opt_state.weights`. When more than one
  device is in use, only index 0 of every leaf is counted (the original code
  called this step `unreplicate`), so replicated copies are not counted
  repeatedly.
  """
  weights = self._opt_state.weights
  if self.n_devices > 1:
    # Pick one replica before measuring; the previous version also measured
    # the full replicated weights first and then threw that result away.
    weights = backend.nested_map(lambda x: x[0], weights)
  total_size = _nested_reduce(sum, _sizes(weights))
  self.log_step('Total number of trainable weights: %d' % total_size)
def _for_n_devices(self, x):
  """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`.

  With more than one device, each leaf goes through `_multi_device_put` on
  the 'jax' backend, or is broadcast to shape `(n,) + leaf.shape` on any
  other backend. With a single device the leaves are returned unchanged.
  """
  n = self.n_devices

  def spread(leaf):
    if n > 1:
      if backend.get_name() == 'jax':
        return _multi_device_put(leaf)
      return np.broadcast_to(leaf, (n,) + leaf.shape)
    return leaf

  return backend.nested_map(spread, x)
def _reshape_by_device(x, n_devices):
  """Reshapes every tensor in (possibly nested) x to (n_devices, ...).

  The leading axis of size `batch_size` is split into
  `(n_devices, batch_size // n_devices)`; the remaining axes are kept.

  Args:
    x: A tensor or nested structure of tensors.
    n_devices: Number of devices; must divide each leading axis evenly.

  Returns:
    The same structure with every tensor reshaped.

  Raises:
    ValueError: if a leading axis is not divisible by `n_devices`.
  """
  def split_leading_axis(tensor):
    dims = list(tensor.shape)
    batch_size = dims[0]
    per_device, leftover = divmod(batch_size, n_devices)
    if leftover:
      raise ValueError(
          'We require that n_devices[%d] divides batch_size[%d] evenly.' %
          (n_devices, batch_size))
    return backend.numpy.reshape(tensor, [n_devices, per_device] + dims[1:])

  return backend.nested_map(split_leading_axis, x)
def train_step(self, batch): """Run one training step and update self._opt_state.""" # Calculate the current optimizer parameters. # TODO(pkozakowski): Optimizer parameters get polluted with model state, # which doesn't break anything but is weird. Filter it out. opt_param_updates = self._for_n_devices( backend.nested_map(np.array, self.nontrainable_params)) opt_state = self._opt_state opt_state.opt_params.update(opt_param_updates) # Run the update. (weights, slots), self._model_state, self._rngs = self._jit_update_fn( self._step, opt_state, batch, self._model_state, self._rngs) self._model_state = self._map_to_state_dicts(self._state_dicts_update) self._opt_state = opt_state._replace(weights=weights, slots=slots) self._step += 1
def save_state(self, keep):
  """Pickles the trainer state into `self._output_dir`.

  The saved tuple is `(tuple(opt_state), step, history, model_state)`. It is
  always written to 'model.pkl'; a second copy named after the current step
  is written as well when `keep` is true.

  Args:
    keep: If true, additionally write 'model_<step>.pkl'.
  """
  opt_state = self._opt_state
  if self.n_devices > 1:
    # Possibly replicated state: keep index 0 of every leaf only.
    opt_state = OptState(*backend.nested_map(lambda x: x[0], opt_state))
  # This step, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == 'jax':
    opt_state = jax.device_get(opt_state)

  step = self._step
  history = self._history
  model_state = self._model_state
  output_dir = self._output_dir
  pkl_module = utils.get_pickle_module()
  payload = (tuple(opt_state), step, history, model_state)

  filenames = ['model.pkl']
  if keep:
    filenames.append('model_{}.pkl'.format(step))
  for filename in filenames:
    weights_file = os.path.join(output_dir, filename)
    with tf.io.gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump(payload, f)
  # Reports the last file written.
  log('Model saved to %s' % weights_file, stdout=False)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):
  """Builds the model, loss, optimizer, metrics and jitted step functions.

  Args:
    model: Callable invoked as `model(mode='train')` and `model(mode='eval')`
        to build the training and evaluation models.
    loss_fn: Callable invoked as `loss_fn(has_weights=..., mask_id=...)`; the
        result is used as the loss layer.
    optimizer: Callable invoked as `optimizer(learning_rate=0.0)`.
    lr_schedule: Stored as `self._lr_schedule`; not otherwise used here.
    inputs: Callable invoked as `inputs(n_devices)`; the result must expose
        `input_shape`, `target_shape`, `input_dtype` and `target_dtype`.
    output_dir: If not None, `self.reset(output_dir)` is called at the end.
    random_seed: Passed to `get_random_number_generator_and_set_seed`. If None
        and there is more than one host, a seed is derived from the current
        time and the host id.
    n_devices: Number of devices; falls back to `backend.device_count()` when
        None or 0.
    save_steps: Stored as `self._save_steps`; defaults to an empty list.
    should_save_checkpoints: Stored as `self._should_save_checkpoints`.
    should_write_summaries: Stored as `self._should_write_summaries`.
    has_weights: Forwarded to the loss and metric constructors; when true, one
        extra float32 signature entry per target is added for initialization.
    nontrainable_param_map: Stored as `self._nontrainable_param_map`; defaults
        to an empty dict.
    mask_id: Forwarded to the loss and metric constructors.
    metrics: Mapping from metric name to metric constructor; defaults to
        `_METRICS`.

  Raises:
    ValueError: on the 'jax' backend, if `n_devices` differs from
        `backend.device_count()`.
  """
  # Host bookkeeping: only the 'jax' backend is queried for host id/count.
  if backend.get_name() == 'jax':
    self._host_id = jax.host_id()
    self._host_count = jax.host_count()
  else:
    self._host_id = 0
    self._host_count = 1
  self._is_chief = (self._host_id == 0)

  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save_checkpoints = should_save_checkpoints
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = _METRICS if metrics is None else metrics
  # From here on `loss_fn` is the constructed loss, not the factory.
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = backend.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and backend.get_name() == 'jax':
    raise ValueError(
        'JAX cannot work yet with n_devices != all devices: '
        '%d != %d' % (n_devices, device_count))
  self._n_devices = n_devices

  # Simple differential seeding of RNG across hosts by host_id and time.
  if random_seed is None and self._host_count > 1:
    _, random_seed = divmod(
        int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
  rng = get_random_number_generator_and_set_seed(random_seed)
  # From here on `inputs` is the constructed inputs object, not the factory.
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  # Note that `x or 1` replaces every falsy entry, so a 0 also becomes 1.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    # With weights enabled, append one float32 entry per target, using the
    # target shapes.
    if self._has_weights:
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.init` puts cached parameter values in `m` and hence the
    # next call of `m.init` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)
  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  # Zero-argument closure; nothing is initialized until it is called.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [
      self._metrics_dict[m](has_weights=self._has_weights,
                            mask_id=self._mask_id)
      for m in self._metrics
  ]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1, ), (1, )) if self._has_weights else ((1, 2), (1, ))
  dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  m_weights, m_state = metrics_in_parallel.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)