def _save_replicated(opt_state, step, history, n_devices, output_dir, keep):
  """Save state but given a possibly replicated opt_state."""
  if n_devices > 1:
    unreplicate = lambda x: x.mean(0)
    opt_state = layers.nested_map(opt_state, unreplicate)
  save_state(State(params=opt_state, step=step, history=history),
             output_dir, keep=keep)
def reset(self, output_dir):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
  """
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Reset the train and eval streams.
  self._train_stream = self._inputs.train_stream()
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  # set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(self._inputs.eval_stream)
  self._train_eval_stream = _repeat_stream(self._inputs.train_eval_stream)

  # Restore the training state.
  state = restore_state(output_dir)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._initialize()
    model_state = layers.nested_map(model_state, self._maybe_replicate)
  self._opt_state = OptState(
      *layers.nested_map(opt_state, self._maybe_replicate))
  self._model_state = model_state
  if not state.opt_state:
    self._maybe_save_state(keep=False)

  self.update_optimizer_params()
def _print_n_params(opt_state, n_devices, step):
  """Print out the number of parameters."""
  sizes = layers.sizes(opt_state.params)
  if n_devices > 1:
    unreplicate = lambda x: x[0]
    single_params = layers.nested_map(opt_state.params, unreplicate)
    sizes = layers.sizes(single_params)
  total_size = layers.nested_reduce(sizes, sum)
  step_log(step, "Total trainable parameters size: %d" % total_size)
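
# Illustrative sketch (not part of the library): the reduction above is
# assumed to count the elements of each parameter array and sum them over the
# whole parameter tree. Plain Python/numpy stands in for layers.sizes and
# layers.nested_reduce, and the parameter tree below is hypothetical.
import numpy as onp


def _total_size_sketch(params):
  """Sum the number of elements over a nested dict/list/tuple of arrays."""
  if isinstance(params, dict):
    return sum(_total_size_sketch(v) for v in params.values())
  if isinstance(params, (list, tuple)):
    return sum(_total_size_sketch(v) for v in params)
  return onp.asarray(params).size


# Example: a 3x4 kernel plus a bias of 4 gives 16 trainable parameters.
fake_params = {"dense": {"w": onp.zeros((3, 4)), "b": onp.zeros(4)}}
assert _total_size_sketch(fake_params) == 16  # 3*4 weights + 4 biases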
def _save_replicated(opt_state, step, history, n_devices, output_dir, keep):
  """Save state but given a possibly replicated opt_state."""
  if n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = layers.nested_map(opt_state, first_replica)
  # This line, while optional, allows JAX to transfer arrays from the device
  # to the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == "jax":
    opt_state = jax.device_get(opt_state)
  save_state(State(params=opt_state, step=step, history=history),
             output_dir, keep=keep)
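
# Illustrative sketch (not part of the library): the two unreplication
# strategies used in the versions of _save_replicated above. Assuming the
# parameter replicas are kept in sync across devices, taking the first
# replica (x[0]) and averaging them (x.mean(0)) give the same values, while
# x[0] avoids the extra reduction. Plain numpy stands in for replicated
# device arrays.
import numpy as onp

replicated = onp.broadcast_to(onp.arange(3.0), (4, 3))  # 4 identical replicas
assert onp.array_equal(replicated[0], replicated.mean(axis=0))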
def reset(self, output_dir):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
  """
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Reset the training stream.
  self._train_stream = self._inputs.train_stream()

  # Restore the training state.
  state = restore_state(output_dir)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._initialize()
    model_state = layers.nested_map(model_state, self._maybe_replicate)
  self._opt_state = OptState(
      *layers.nested_map(opt_state, self._maybe_replicate))
  self._model_state = model_state
  if not state.opt_state:
    self._maybe_save_state(keep=False)

  self.update_learning_rate()
def predict(x, params=(), state=(), rng=None):
  """Predict function jitted and parallelized as requested."""
  pred, state = mapped_predict(
      reshape_by_device(x, n_devices),
      params,
      state,
      jax_random.split(rng, n_devices))
  # Need to reduce the [device, per-device-batch, ...] tensors back to
  # a [batch, ...] tensor. The tensors may be nested.
  def combine(x):
    batch_size = x.shape[0] * x.shape[1]
    return np.reshape(x, [batch_size] + list(x.shape[2:]))
  return layers.nested_map(pred, combine), state
def _train_step(self, next_train_batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  opt_param_updates = layers.nested_map(
      self.optimizer_params, lambda x: self._maybe_replicate(np.array(x)))
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  (params, slots), self._model_state, self._rngs = self._jit_update_fn(
      self._step, opt_state, next_train_batch, self._model_state, self._rngs)
  self._opt_state = opt_state._replace(params=params, slots=slots)
  self._step += 1
def predict(x, params=(), state=(), rng=None):
  """Predict function jitted and parallelized as requested."""
  pred = mapped_predict(
      reshape_by_device(x, n_devices),
      params,
      state,
      jax_random.split(rng, n_devices))
  # Need to reduce the [device, per-device-batch, ...] tensors back to
  # a [batch, ...] tensor. The tensors may be nested.
  def combine(x):
    if len(x.shape) > 1:
      batch_size = x.shape[0] * x.shape[1]
      return np.reshape(x, [batch_size] + list(x.shape[2:]))
    # TODO(lukaszkaiser): is returning averages for scalars the right choice?
    # If it is only a scalar, return the average.
    return np.mean(x, axis=0)
  return layers.nested_map(pred, combine)
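
# Illustrative sketch (not part of the library): what combine does to a
# per-device prediction. A [n_devices, per_device_batch, ...] array is folded
# back into a [batch, ...] array, while a per-device scalar (e.g. a metric)
# is averaged across devices. Plain numpy stands in for the device output.
import numpy as onp

per_device = onp.arange(2 * 4 * 3).reshape((2, 4, 3))  # [devices, batch, dim]
combined = onp.reshape(per_device, (2 * 4, 3))         # back to [batch, dim]
assert combined.shape == (8, 3)

per_device_scalar = onp.array([0.5, 1.5])              # one value per device
assert onp.mean(per_device_scalar, axis=0) == 1.0      # averaged to a scalar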
def _train_step(self, next_train_batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  opt_param_updates = layers.nested_map(
      self.nontrainable_params, lambda x: self._maybe_replicate(np.array(x)))
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  (params, slots), self._model_state, self._rngs = self._jit_update_fn(
      self._step, opt_state, next_train_batch, self._model_state, self._rngs)
  self._model_state = self._map_to_state_dicts(self._state_dicts_update)
  self._opt_state = opt_state._replace(params=params, slots=slots)
  self._step += 1
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True, has_weights=False):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  self._has_weights = has_weights
  loss_fn = functools.partial(loss_fn, has_weights=self._has_weights)
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)
  model_target_shape = layers.nested_map(
      model_target_shape, lambda x: x if x else 1)

  def initialize(input_shape, input_dtype, target_shape, target_dtype, rng):
    """Helper to initialize the model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    full_type = list(input_dtype) + list(target_dtype)
    full_shape = list(input_shape) + list(target_shape)
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    params, state = model(mode="train").initialize(full_shape, full_type, rng)
    (slots, opt_params) = opt.tree_init(params)
    return (OptState(params, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    initialize = backend.jit(initialize, static_argnums=(0, 1, 2, 3))
  self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
      model_input_shape, self._inputs.input_dtype,
      model_target_shape, self._inputs.target_dtype, init_rng)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                 n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)
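
# Illustrative sketch (not part of the library): the input-shape handling in
# the constructors above. A [None] batch dimension is prepended to each
# per-example shape, and every None entry is then replaced by 1 so the shapes
# can be used to build dummy inputs for initialization. The shapes below are
# hypothetical; plain Python stands in for layers.nested_map.
input_shape = (32, 32, 3)                               # single input shape
model_input_shape = tuple([None] + list(input_shape))   # (None, 32, 32, 3)
model_input_shape = tuple(x if x else 1 for x in model_input_shape)
assert model_input_shape == (1, 32, 32, 3)

nested_shapes = ((128,), (128, 64))                     # tuple of input shapes
model_input_shapes = tuple(
    tuple([None] + list(shape)) for shape in nested_shapes)
assert model_input_shapes == ((None, 128), (None, 128, 64))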
def reshape_by_device(x, n_devices):
  """Reshape possibly nested x into a shape [n_devices, ...]."""
  return layers.nested_map(
      x, lambda x: _reshape_by_device_single(x, n_devices))
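
# Illustrative sketch (not part of the library): the per-leaf reshape that
# reshape_by_device is assumed to apply via _reshape_by_device_single (defined
# elsewhere). A [batch, ...] array becomes [n_devices, batch // n_devices, ...]
# so that pmapped functions see one shard per device; the batch size is
# assumed to divide evenly by n_devices. Plain numpy, hypothetical name.
import numpy as onp


def _reshape_by_device_single_sketch(x, n_devices):
  batch_size = x.shape[0]
  assert batch_size % n_devices == 0, "batch must divide evenly across devices"
  return onp.reshape(
      x, [n_devices, batch_size // n_devices] + list(x.shape[1:]))


sharded = _reshape_by_device_single_sketch(onp.zeros((8, 3)), n_devices=2)
assert sharded.shape == (2, 4, 3)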
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
             random_seed=None, n_devices=None, save_steps=None):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  self._output_dir = output_dir
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  # Create input streams.
  inputs = inputs(n_devices)
  self._inputs = inputs
  self._train_stream = inputs.train_stream()

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  self._lr_fn = lr_schedule(history)
  opt = optimizer(self._lr_fn)
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(
        model_input_shape, inputs.input_dtype, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(
      model_predict_eval, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._step = step
  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._optimizer = optimizer
  self._opt_state = opt_state
  self._history = history
  self._lr_schedule = lr_schedule
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied
      save steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    n_devices: how many devices to use (if None, default, use all available).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit
      for one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.

  Returns:
    trax.State
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)
  model_train = layers.Serial(model(mode="train"))
  model_predict_eval = layers.Serial(model(mode="eval"))

  # Setup state
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [-1] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        [tuple([-1] + list(shape)) for shape in inputs.input_shape])
  else:  # Otherwise just add [-1] to the input shape.
    model_input_shape = tuple([-1] + list(inputs.input_shape))
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(model_input_shape, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir,
                         True)

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fn(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters
    if step == 1:
      sizes = layers.sizes(opt_state[0])
      if n_devices > 1:
        unreplicate = lambda x: x.mean(0)
        single_params = layers.nested_map(opt_state[0], unreplicate)
        sizes = layers.sizes(single_params)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if save_graphs and step == 1 and n_devices == 1:
      params = opt_state[0]
      # Dump computation graphs to files.
      forward_computation = jax.xla_computation(model_predict_eval)(
          next_train_batch[0], params=params, rng=rng)
      with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
        f.write(forward_computation.GetHloText())
      with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
        f.write(forward_computation.GetHloDotGraph())
      backward_computation = jax.xla_computation(jit_update_fn)(
          step, opt_state, next_train_batch, rngs)
      with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
        f.write(backward_computation.GetHloText())
      if save_backward_graph:  # Backward graphs can be large so we guard it.
        with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
          f.write(backward_computation.GetHloDotGraph())

    # Save state
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)
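
# Illustrative sketch (not part of the library): the epoch_steps schedule
# built at the top of train(). With eval_frequency=100, the chain yields 1
# (so the first evaluation happens after a single step), then 99 (to reach
# step 100), then 100 forever, i.e. evaluation runs at steps 1, 100, 200, ...
import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1, eval_frequency - 1],
                              itertools.repeat(eval_frequency))
assert list(itertools.islice(epoch_steps, 4)) == [1, 99, 100, 100]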
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)

  def initialize(input_shape, input_dtype, init_rng):
    params = model_train.initialize(input_shape, input_dtype, init_rng)
    (slots, opt_params) = opt.tree_init(params)
    return OptState(params, slots, opt_params)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    initialize = backend.jit(initialize, static_argnums=(0, 1))
  self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
      model_input_shape, self._inputs.input_dtype, init_rng)

  # jit model_predict and update so they're fast
  self._jit_model_predict_eval = _jit_predict_fn(
      model_predict_eval, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None

  if output_dir is not None:
    self.reset(output_dir)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None):
  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save = should_save
  self._has_weights = has_weights
  self._mask_id = mask_id
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  self._n_devices = n_devices
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)
  model_target_shape = layers.nested_map(
      model_target_shape, lambda x: x if x else 1)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    full_type = list(input_dtype) + list(target_dtype)
    full_shape = list(input_shape) + list(target_shape)
    if self._has_weights:
      full_shape += list(target_shape)
      full_type += [np.float32 for _ in target_dtype]
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = layers.Serial([model(mode="train"), loss_fn])
    params, state = m.initialize_once(full_shape, full_type, rng)
    (slots, opt_params) = opt.tree_init(params)
    return (OptState(params, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # jit model_predict and update so they're fast
  # TODO(lukaszkaiser): the code below creates a layer computing
  # multiple metrics from a single model output; re-factor for clarity.
  dup_layer = layers.Dup3() if self._has_weights else layers.Dup2()

  def lower(layer):
    """Apply layer below the current inputs, targets, and possibly weights."""
    if self._has_weights:
      # Apply layer below inputs, targets, and loss weights.
      return layers.Parallel([], [], [], layer)
    else:
      # Apply layer below inputs and targets.
      return layers.Parallel([], [], layer)

  metrics_layer = []
  self._metrics = list(sorted(_METRICS.keys()))
  for i, m in enumerate(reversed(self._metrics)):
    metric = _METRICS[m](has_weights=self._has_weights,
                         mask_id=self._mask_id)
    if i != len(self._metrics) - 1:
      metrics_layer.append(dup_layer)
      metrics_layer.append(lower(metric))
    else:
      metrics_layer.append(metric)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shape = (((1, 2), (1,), (1,)) if self._has_weights
                 else ((1, 2), (1,)))
  dummy_type = [np.float32] * (3 if self._has_weights else 2)
  metrics_layer = layers.Serial(metrics_layer)
  metrics_params, metrics_state = metrics_layer.initialize_once(
      dummy_shape, tuple(dummy_type), init_rng)
  self._metrics_params = layers.nested_map(
      metrics_params, self._maybe_replicate)
  self._metrics_state = layers.nested_map(
      metrics_state, self._maybe_replicate)
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_layer,
                                   n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)