def test_pure_mlp_forward_shape(self):
  """Run the PureMLP model forward and check output shape."""
  input_signature = ShapeDtype((7, 28, 28, 3))
  model = mlp.PureMLP(layer_widths=(32, 16, 8))
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((7, 8), final_shape)

def test_lstm_cell(self):
  self._test_cell_runs(
      rnn.LSTMCell(9),
      input_signature=(ShapeDtype((8, 9)), ShapeDtype((8, 18))),
      output_shape=((8, 9), (8, 18)))

def test_rnnlm_forward_shape(self):
  """Runs the RNN LM forward and checks output shape."""
  input_signature = ShapeDtype((3, 28), dtype=math.numpy.int32)
  model = rnn.RNNLM(vocab_size=20, d_model=16)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((3, 28, 20), final_shape)

def test_branch_one_layer(self):
  layer = cb.Branch(divide_by(0.5))
  input_signature = ShapeDtype((3, 2))
  expected_shape = (3, 2)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_conv_gru_cell(self):
  self._test_cell_runs(
      rnn.ConvGRUCell(9, kernel_size=(3, 3)),
      input_signature=ShapeDtype((8, 1, 7, 9)),
      output_shape=(8, 1, 7, 9))

def test_serial_no_op_list(self):
  layer = cb.Serial([])
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_branch_noop_dup(self):
  layer = cb.Branch([], cb.Dup())
  input_signature = ShapeDtype((3, 2))
  expected_shape = ((3, 2), (3, 2), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_accuracy_scalar(self):
  input_signature = (ShapeDtype((29, 4, 4, 20)),
                     ShapeDtype((29, 4, 4)),
                     ShapeDtype((29, 4, 4)))
  result_shape = base.check_shape_agreement(
      metrics.AccuracyScalar(), input_signature)
  self.assertEqual(result_shape, ())

def test_resnet(self):
  input_signature = ShapeDtype((3, 256, 256, 3))
  model = resnet.Resnet50(d_hidden=8, n_output_classes=10)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((3, 10), final_shape)

def test_weighted_mean_shape(self):
  input_signature = (ShapeDtype((29, 4, 4, 20)),
                     ShapeDtype((29, 4, 4, 20)))
  result_shape = base.check_shape_agreement(
      metrics._WeightedMean(), input_signature)
  self.assertEqual(result_shape, ())

def test_cross_entropy_loss(self):
  input_signature = (ShapeDtype((29, 4, 4, 20)),
                     ShapeDtype((29, 4, 4)),
                     ShapeDtype((29, 4, 4)))
  result_shape = base.check_shape_agreement(
      metrics.CrossEntropyLoss(), input_signature)
  self.assertEqual(result_shape, ())

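# The three-part signature in the test above is (model activations, targets,
# weights), and the metric reduces everything to a scalar of shape ().
# Below is a minimal numpy sketch of that weighted reduction, using a
# hypothetical stand-alone helper for illustration -- this is NOT Trax's
# actual CrossEntropyLoss implementation:
import numpy as onp

def weighted_cross_entropy(activations, targets, weights):
  """Scalar weighted cross-entropy over categorical logits (sketch)."""
  # Stable log-softmax over the last (class) axis.
  shifted = activations - activations.max(axis=-1, keepdims=True)
  log_probs = shifted - onp.log(onp.exp(shifted).sum(axis=-1, keepdims=True))
  # Log-probability of each target class, shape matching `targets`.
  picked = onp.take_along_axis(
      log_probs, targets[..., None], axis=-1).squeeze(-1)
  # Weighted mean of negative log-likelihoods -> scalar, shape ().
  return -(picked * weights).sum() / weights.sum()
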
def test_layer_norm_shape(self):
  input_signature = ShapeDtype((29, 5, 7, 20))
  result_shape = base.check_shape_agreement(
      normalization.LayerNorm(), input_signature)
  self.assertEqual(result_shape, input_signature.shape)

def test_image_fec(self):
  input_signature = ShapeDtype((3, 256, 256, 3))
  model = ImageFEC(d_hidden=8, n_output_classes=10)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((3, 10), final_shape)

def test_mlp_forward_shape(self):
  """Run the MLP model forward and check output shape."""
  input_signature = ShapeDtype((3, 28, 28, 1))
  model = mlp.MLP(d_hidden=32, n_output_classes=10)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((3, 10), final_shape)

def test_drop(self):
  layer = cb.Drop()
  input_signature = ShapeDtype((3, 2))
  expected_shape = ()
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_wide_resnet(self):
  input_signature = ShapeDtype((3, 32, 32, 3))
  model = resnet.WideResnet(n_blocks=1, n_output_classes=10)
  final_shape = tl.check_shape_agreement(model, input_signature)
  self.assertEqual((3, 10), final_shape)

def test_swap(self):
  layer = cb.Swap()
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((4, 7), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_combined_loss(self):
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS

  make_net = lambda: ppo.policy_and_value_net(  # pylint: disable=g-long-lambda
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  net = make_net()

  input_signature = ShapeDtype(batch_observation_shape)
  old_params, _ = net.init(input_signature)
  new_params, state = make_net().init(input_signature)

  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)

  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, weights=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, weights=old_params, state=state))

  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  value_weight = 1.0
  entropy_weight = 0.01

  nontrainable_params = {
      'gamma': gamma,
      'lambda': lambda_,
      'epsilon': epsilon,
      'value_weight': value_weight,
      'entropy_weight': entropy_weight,
  }

  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs, old_log_probabs, value_predictions_old, actions,
      rewards_to_actions, rewards, mask, gamma=gamma, lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params, old_log_probabs, value_predictions_old,
                        net, observations, actions, rewards_to_actions,
                        rewards, mask,
                        nontrainable_params=nontrainable_params,
                        state=state))

  # Test that these compute at all and are self consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (value_weight * value_loss_2) -
      (entropy_weight * entropy_bonus),
      1e-6)

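# The final assertion above pins down the composition that `ppo.combined_loss`
# is expected to satisfy:
#
#   combined = ppo_loss + value_weight * value_loss - entropy_weight * entropy_bonus
#
# A minimal stand-alone sketch of that composition, as a hypothetical helper
# for illustration only (not the function Trax ships):
def combine_ppo_losses(ppo_loss, value_loss, entropy_bonus,
                       value_weight=1.0, entropy_weight=0.01):
  """Combines the PPO surrogate, value, and entropy terms into one loss."""
  # The entropy bonus is subtracted: higher policy entropy lowers the loss,
  # which encourages exploration.
  return ppo_loss + value_weight * value_loss - entropy_weight * entropy_bonus
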
def test_serial_div_div(self):
  layer = cb.Serial(divide_by(2.0), divide_by(5.0))
  input_signature = ShapeDtype((3, 2))
  expected_shape = (3, 2)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             save_steps=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):
  if backend.get_name() == 'jax':
    self._host_id = jax.host_id()
    self._host_count = jax.host_count()
  else:
    self._host_id = 0
    self._host_count = 1
  self._is_chief = (self._host_id == 0)

  if save_steps is None:
    save_steps = []
  self._save_steps = save_steps
  self._should_save_checkpoints = should_save_checkpoints
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = _METRICS if metrics is None else metrics
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  device_count = backend.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and backend.get_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))
  self._n_devices = n_devices

  # Simple differential seeding of RNG across hosts by host_id and time.
  if random_seed is None and self._host_count > 1:
    _, random_seed = divmod(
        int(time.time() * 1e6) + int(self._host_id * 1e6), 2**32)
  rng = get_random_number_generator_and_set_seed(random_seed)
  inputs = inputs(n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    if self._has_weights:
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(ShapeDtype(s, d)
                            for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Jit model_predict and update so they're fast.
  # TODO(lukaszkaiser): the code below creates a layer computing
  # multiple metrics from a single model output; re-factor for clarity.
  dup_layer = tl.Dup3() if self._has_weights else tl.Dup2()

  def lower(layer):
    """Apply layer below the current inputs, targets, and possibly weights."""
    if self._has_weights:
      # Apply layer below inputs, targets, and loss weights.
      return tl.Parallel([], [], [], layer)
    else:
      # Apply layer below inputs and targets.
      return tl.Parallel([], [], layer)

  metrics_layer = []
  self._metrics = list(sorted(self._metrics_dict.keys()))
  for i, m in enumerate(reversed(self._metrics)):
    metric = self._metrics_dict[m](has_weights=self._has_weights,
                                   mask_id=self._mask_id)
    if i != len(self._metrics) - 1:
      metrics_layer.append(dup_layer)
      metrics_layer.append(lower(metric))
    else:
      metrics_layer.append(metric)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
  dummy_dtypes = [np.float32] * (3 if self._has_weights else 2)
  dummy_signature = tuple(ShapeDtype(s, d)
                          for s, d in zip(dummy_shapes, dummy_dtypes))
  metrics_layer = tl.Serial(metrics_layer)
  metrics_layer._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  metrics_weights, metrics_state = metrics_layer.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(metrics_weights)
  self._metrics_state = self._for_n_devices(metrics_state)
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_layer, n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)

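# The metrics loop above builds one Serial layer that computes every metric
# from a single model output: each Dup2/Dup3 copies the (output, targets[,
# weights]) group on the data stack, and `lower` slides a metric beneath the
# fresh copy so later metrics still see unconsumed inputs. A rough
# plain-Python emulation of that stack discipline, illustrative only and not
# Trax's actual combinator semantics:
def run_metrics_stack(metric_fns, output, targets, weights):
  """Emulates the Dup3 / lower(metric) chain with a list as the stack."""
  stack = [output, targets, weights]  # index 0 is the top of the stack
  for i, metric in enumerate(metric_fns):
    if i != len(metric_fns) - 1:
      stack = stack[:3] + stack  # Dup3: copy the (out, tgt, w) group
      # lower(metric): keep the top copy, run the metric on the one below.
      stack = stack[:3] + [metric(*stack[3:6])] + stack[6:]
    else:
      stack = [metric(*stack[:3])] + stack[3:]  # last metric runs on top
  return stack  # one value per metric
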
def test_branch_add_div(self):
  layer = cb.Branch(cb.Add(), divide_by(0.5))
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
  expected_shape = ((3, 2), (3, 2))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):

  self._is_chief, self._n_devices, rng = (
      self._init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._has_weights = has_weights
  self._id_to_mask = id_to_mask
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                             (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(self._inputs.input_shape))
  # Same for targets.
  if self._inputs.target_shape and isinstance(self._inputs.target_shape[0],
                                              (list, tuple)):
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.target_shape)
  else:
    model_target_shape = tuple([None] + list(self._inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = math.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = math.nested_map(lambda x: x or 1, model_target_shape)

  def new_opt_state_and_model_state(shape_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    shapes, dtypes = shape_dtype
    input_signature = tuple(ShapeDtype(s, d)
                            for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = math.jit(new_opt_state_and_model_state,
                                             static_argnums=(0,))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          self._inputs.example_shape_dtype, init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                          id_to_mask=self._id_to_mask)
                    for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  example_signature = tuple(
      ShapeDtype(s, d)
      for (s, d) in zip(*self._inputs.example_shape_dtype))
  model_predict_eval.init(example_signature)
  output_signature = model_predict_eval.output_signature(example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)

def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):

  self._is_chief, _, self._n_devices, rng = (
      training.init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
  self._should_write_summaries = should_write_summaries
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')
  self._model_with_loss = tl.Serial(model_train, loss_fn)

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  shapes, dtypes = self._inputs.example_shape_dtype
  input_signature = tuple(ShapeDtype(s, d)
                          for (s, d) in zip(shapes, dtypes))

  def new_opt_state_and_model_state(rng):
    """Returns optimizer and model states suitable for training a model."""
    weights, state = self._model_with_loss.init(input_signature, rng=rng)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if fastmath.is_backend(fastmath.Backend.JAX):
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = (
        fastmath.jit(new_opt_state_and_model_state))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m] for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel.rng = init_rng
  example_signature = tuple(
      ShapeDtype(s, d)
      for (s, d) in zip(*self._inputs.example_shape_dtype))
  model_predict_eval.init(example_signature)
  self._input_signature = example_signature
  output_signature = model_predict_eval.output_signature(example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  self._lr_schedule = lr_schedule

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)

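# Why JIT the initializer: as the comment in the constructor notes, running
# weight initialization inside a compiled function avoids memory
# fragmentation from many small device allocations. A small self-contained
# JAX sketch of the same pattern (illustrative; `init_params` is a toy stand-in
# for the model-plus-optimizer init above):
import jax

def init_params(rng):
  """Builds a toy weight tree; the Trainer initializes model + optimizer here."""
  k1, k2 = jax.random.split(rng)
  return {'w': jax.random.normal(k1, (128, 128)),
          'b': jax.random.normal(k2, (128,))}

jit_init = jax.jit(init_params)  # compile once, reuse on every reset/restart
params = jit_init(jax.random.PRNGKey(0))
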
def test_parallel_dup_dup(self):
  layer = cb.Parallel(cb.Dup(), cb.Dup())
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (3, 2), (4, 7), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_gru_cell(self):
  self._test_cell_runs(
      rnn.GRUCell(9),
      input_signature=(ShapeDtype((8, 7, 9)), ShapeDtype((8, 7, 9))),
      output_shape=((8, 7, 9), (8, 7, 9)))

def test_parallel_div_div(self):
  layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_sru(self):
  self._test_cell_runs(
      rnn.SRU(7),
      input_signature=ShapeDtype((8, 9, 7)),
      output_shape=(8, 9, 7))

def test_parallel_no_ops(self):
  layer = cb.Parallel([], None)
  input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
  expected_shape = ((3, 2), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)

def test_constructor_and_read_properties(self):
  sd = ShapeDtype((2, 3), onp.int32)
  self.assertEqual(sd.shape, (2, 3))
  self.assertEqual(sd.dtype, onp.int32)

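# ShapeDtype is just a (shape, dtype) signature pair, used throughout this
# code to initialize layers without real data. A small usage sketch; the
# float32 default is an assumption inferred from the bare `ShapeDtype(s)`
# calls elsewhere in this code, so verify it against your Trax version:
sd_int = ShapeDtype((2, 3), onp.int32)  # explicit dtype
sd_default = ShapeDtype((4, 5))         # dtype presumably defaults to float32
assert sd_default.shape == (4, 5)
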
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):

  self._is_chief, self._n_devices, rng = (
      self._init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  inputs = inputs(self._n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    if self._has_weights:
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(ShapeDtype(s, d)
                            for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation.
    new_opt_state_and_model_state = backend.jit(
        new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Arrange and initialize metrics layers.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                          mask_id=self._mask_id)
                    for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
  dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  m_weights, m_state = metrics_in_parallel.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)