def test_initialize_rescale(self):
  """Test rescaling a single layer of a model."""
  input_shape = (28, 28, 1)
  output_shape = (10,)
  model_str = 'fully_connected'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': output_shape})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')
  rng, init_rng = jax.random.split(rng)

  # First initialize with no rescale.
  flax_module, _ = trainer.initialize(
      model.flax_module_def, initializer, model.loss_fn, input_shape,
      output_shape, hps, init_rng, metrics_logger=None)
  utils.log_pytree_shape_and_statistics(flax_module.params)

  # Now rescale a layer by 100.
  rescale_factor = 100
  hps.layer_rescale_factors = {
      '/Dense_1/kernel': rescale_factor,
  }
  rescaled_module, _ = trainer.initialize(
      model.flax_module_def, initializer, model.loss_fn, input_shape,
      output_shape, hps, init_rng, metrics_logger=None)

  # Check that the right variable is rescaled.
  v1 = flax_module.params['Dense_1']['kernel']
  v2 = rescaled_module.params['Dense_1']['kernel']
  diff = np.linalg.norm(v1.reshape(-1) * rescale_factor - v2.reshape(-1))
  self.assertAlmostEqual(diff, 0.0)

  # Check that other variables are unchanged.
  v1 = flax_module.params['Dense_2']['kernel']
  v2 = rescaled_module.params['Dense_2']['kernel']
  diff = np.linalg.norm(v1.reshape(-1) - v2.reshape(-1))
  self.assertAlmostEqual(diff, 0.0)
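# For context, a minimal sketch (an assumption, not the repo's actual
# implementation) of how `layer_rescale_factors` keyed by '/'-joined
# parameter paths could be applied to a Flax param pytree:
from flax import traverse_util


def rescale_layers(params, layer_rescale_factors):
  """Multiplies selected leaves of `params` by the given factors."""
  flat = traverse_util.flatten_dict(params)  # Keys are tuples of names.
  for path, factor in layer_rescale_factors.items():
    # '/Dense_1/kernel' -> ('Dense_1', 'kernel').
    key = tuple(path.strip('/').split('/'))
    flat[key] = flat[key] * factor
  return traverse_util.unflatten_dict(flat)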
def setUp(self):
  super(CheckpointTest, self).setUp()
  self.test_dir = tempfile.mkdtemp()
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_cls = models.get_model('fully_connected')
  model_hps = models.get_model_hparams('fully_connected')
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
  rng, params_rng = jax.random.split(rng)
  _, self.flax_module = model.flax_module_def.create(params_rng, xs)
def setUp(self):
  super(CheckpointTest, self).setUp()
  self.test_dir = tempfile.mkdtemp()
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_cls = models.get_model('fully_connected')
  model_hps = models.get_model_hparams('fully_connected')
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
  rng, params_rng = jax.random.split(rng)
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  self.params = init_dict['params']
def _load_model(model_name):
  """Load a test model."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_hps = models.get_model_hparams(model_name)
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  _, flax_module = model.flax_module_def.create_by_shape(
      rng, [input_shape], train=True)
  utils.log_pytree_shape_and_statistics(flax_module.params)
  return flax_module, input_shape
def test_graph_model(self):
  """Test forward pass of the GNN model."""
  edge_input_shape = (5,)
  node_input_shape = (5,)
  output_shape = (5,)
  model_str = 'gnn'
  model_hps = models.get_model_hparams(model_str)
  model_hps.update({
      'output_shape': output_shape,
      'latent_dim': 10,
      'hidden_dims': (10,),
      'batch_size': 5,
      'normalizer': 'batch_norm'
  })
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  dropout_rng, params_rng = jax.random.split(rng)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_classification_metrics'
  model = model_cls(model_hps, {}, loss, metrics)

  num_graphs = 5
  node_per_graph = 3
  edge_per_graph = 9
  inputs = jraph.get_fully_connected_graph(
      n_node_per_graph=node_per_graph,
      n_graph=num_graphs,
      node_features=np.ones((num_graphs * node_per_graph,) +
                            node_input_shape),
  )
  inputs = inputs._replace(
      edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape))
  padded_inputs = jraph.pad_with_graphs(inputs, 20, 50, 7)
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, padded_inputs)
  params = init_dict['params']
  batch_stats = init_dict['batch_stats']

  # Check that the forward pass works with mutated batch_stats.
  outputs, _ = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      padded_inputs,
      mutable=['batch_stats'],
      rngs={'dropout': dropout_rng},
      train=True)
  self.assertEqual(outputs.shape, (7,) + output_shape)
def _load_model(model_name):
  """Load a test model."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_hps = models.get_model_hparams(model_name)
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=True))
  init_dict = model_init_fn({'params': rng}, jnp.zeros(input_shape))
  # Trainable model parameters.
  params = init_dict['params']
  utils.log_pytree_shape_and_statistics(params)
  return model.flax_module, params, input_shape, hps
def test_classification_model(self, model_str):
  """Test forward pass of the image models."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE['classification']})
  rng = jax.random.PRNGKey(0)
  dropout_rng, params_rng = jax.random.split(rng)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE['classification']))
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  batch_stats = init_dict.get('batch_stats', {})

  # Check that the forward pass works with mutated batch_stats.
  outputs, new_batch_stats = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      xs,
      mutable=['batch_stats'],
      rngs={'dropout': dropout_rng},
      train=True)
  self.assertEqual(
      outputs.shape,
      (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1]))

  # If it's a batch norm model, check that the batch stats changed.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(
      outputs.shape,
      (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1]))
def test_autoencoder_model(self, model_str):
  """Test forward pass of the autoencoder models."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  params_rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  batch_stats = init_dict.get('batch_stats', {})

  # Check that the forward pass works with mutated batch_stats.
  outputs, new_batch_stats = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      xs,
      mutable=['batch_stats'],
      train=True)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))

  # If it's a batch norm model, check that the batch stats changed.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))
def test_autoencoder_model(self, model_str):
  """Test forward pass of the autoencoder models."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
  rng, params_rng = jax.random.split(rng)
  with nn.stateful() as batch_stats:
    with nn.stochastic(params_rng):
      _, flax_module = model.flax_module_def.create(params_rng, xs)

  # Check that the forward pass works with mutated batch_stats.
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(params_rng):
      outputs = flax_module(xs)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))

  # If it's a batch norm model, check that the batch stats changed.
  if batch_stats.as_dict():
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(
      outputs.shape,
      tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))
def test_dlrm_model_trainer(self):
  """Tests that DLRM model training decreases the loss."""
  rng = jax.random.PRNGKey(1337)
  model_str = 'dlrm'
  dataset_str = 'criteo1tb'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  dataset_hps = datasets.get_dataset_hparams(dataset_str)
  dataset_hps.update({
      'batch_size': model_hps.batch_size,
      'num_dense_features': model_hps.num_dense_features,
      'vocab_sizes': model_hps.vocab_sizes,
  })
  eval_num_batches = 5
  eval_batch_size = dataset_hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics'
  dataset, dataset_meta_data = _get_fake_dlrm_dataset(
      dataset_hps.batch_size, eval_num_batches, dataset_hps)
  hps = copy.copy(model_hps)
  hps.update({
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      'input_shape':
          (model_hps.num_dense_features + len(model_hps.vocab_sizes),),
      'output_shape': (1,),
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'num_device_prefetches': 0,
  })
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(
      os.path.join(self.test_dir, 'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    self.assertLess(train_loss[-1], train_loss[0])
def test_graph_model_trainer(self):
  """Tests that graph model training decreases the loss."""
  rng = jax.random.PRNGKey(1337)
  model_str = 'gnn'
  model_cls = models.get_model(model_str)
  hps = models.get_model_hparams(model_str)
  hps.update({
      'batch_size': 2,
      'input_edge_shape': (7,),
      'input_node_shape': (3,),
      'input_shape': (7, 3),
      'output_shape': (5,),
      'model_dtype': 'float32',
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      'num_message_passing_steps': 1,
      'normalizer': 'none',
      'dropout_rate': 0.0,
      'lr_hparams': {
          'base_lr': 0.001,
          'schedule': 'constant'
      },
      'num_device_prefetches': 0,
  })
  eval_num_batches = 5
  eval_batch_size = hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics_ogbg_map'
  dataset, dataset_meta_data = _get_fake_graph_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches, hps=hps)
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          # For some reason, moving from the deprecated to the linen Flax
          # model API made training less stable, so we eval more frequently
          # in order to get a `train_loss[0]` from earlier in training.
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(
      os.path.join(self.test_dir, 'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    self.assertLess(train_loss[-1], train_loss[0])
def build_hparams(model_name,
                  initializer_name,
                  dataset_name,
                  hparam_file,
                  hparam_overrides):
  """Build experiment hyperparameters.

  Args:
    model_name: the string model name.
    initializer_name: the string initializer name.
    dataset_name: the string dataset name.
    hparam_file: the string to the hyperparameter override file (possibly on
      CNS).
    hparam_overrides: a dict of hyperparameter override names/values, or a
      JSON string encoding of this hyperparameter override dict. Note that
      this is applied after the hyperparameter file overrides.

  Returns:
    A ConfigDict of experiment hyperparameters.
  """
  model_hps = models.get_model_hparams(model_name)
  initializer_hps = initializers.get_initializer_hparams(initializer_name)
  dataset_hps = datasets.get_dataset_hparams(dataset_name)

  merged_dict = {}
  hps_dicts = [
      hps.to_dict() for hps in [model_hps, initializer_hps, dataset_hps]
  ]
  total_hps = 0
  for hps_dict in hps_dicts:
    merged_dict.update(hps_dict)
    total_hps += len(hps_dict.keys())

  # Check that the provided hparam dicts have no overlapping keys.
  if total_hps != len(merged_dict.keys()):
    raise ValueError('There is overlap in the provided hparams.')

  # Convert to the Shallue and Lee label smoothing style.
  if merged_dict.get('use_shallue_label_smoothing', False):
    num_classes = merged_dict['output_shape'][-1]
    merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1)

  merged = config_dict.ConfigDict(merged_dict)
  merged.lock()

  # Subconfigs 'opt_hparams' and 'lr_hparams' are allowed to add new fields.
  for key in ['opt_hparams', 'lr_hparams']:
    if key not in merged:
      with merged.unlocked():
        merged[key] = config_dict.ConfigDict()
  for key in ['opt_hparams', 'lr_hparams']:
    merged[key].unlock()

  if hparam_file:
    logging.info('Loading hparams from %s', hparam_file)
    with gfile.GFile(hparam_file, 'r') as f:
      hparam_dict = json.load(f)
      merged.update_from_flattened_dict(hparam_dict)

  if hparam_overrides:
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)

    # If the user is changing the learning rate schedule or optimizer, we
    # must wipe all of the keys from the old dictionary.
    if ('lr_hparams.schedule' in hparam_overrides and
        merged['lr_hparams']['schedule'] !=
        hparam_overrides['lr_hparams.schedule']):
      merged['lr_hparams'] = {}
    if ('optimizer' in hparam_overrides and
        merged['optimizer'] != hparam_overrides['optimizer']):
      merged['opt_hparams'] = {}
    merged.update_from_flattened_dict(hparam_overrides)
  return merged
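# Example usage of `build_hparams`. The model and initializer names appear
# in the tests above; the dataset name 'mnist' and the override values are
# hypothetical. Overrides are applied after the hparam file, and because the
# override changes 'lr_hparams.schedule', the stale lr_hparams keys are
# wiped before the new values are merged in:
#
#   hps = build_hparams(
#       model_name='fully_connected',
#       initializer_name='noop',
#       dataset_name='mnist',
#       hparam_file=None,
#       hparam_overrides={
#           'lr_hparams.schedule': 'constant',
#           'lr_hparams.base_lr': 0.01,
#       })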
def test_accumulation(self):
  """Test simple gradient accumulation."""
  num_steps = 3
  per_step_batch_size = 16
  total_batch_size = 48
  virtual_batch_size = 8
  model_str = 'wide_resnet'  # Pick a model with batch norm.
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  dataset_name = 'cifar10'
  dataset_builder = datasets.get_dataset(dataset_name)
  hps = copy.copy(model_hps)
  hps.update(datasets.get_dataset_hparams(dataset_name))

  # Compute updates using gradient accumulation.
  hps.update({
      'batch_size': per_step_batch_size,
      'virtual_batch_size': virtual_batch_size,
      'normalizer': 'virtual_batch_norm',
      'total_accumulated_batch_size': total_batch_size,
  })
  grad_acc_params, grad_acc_batch_stats, grad_acc_training_cost = _init_model(
      model_cls, hps)
  total_dataset = dataset_builder(
      shuffle_rng=jax.random.PRNGKey(1),
      batch_size=total_batch_size,
      eval_batch_size=10,
      hps=hps)

  # Ensure we see the same exact batches.
  train_iter = total_dataset.train_iterator_fn()
  train_iter = itertools.islice(train_iter, 0, num_steps)
  train_iter = itertools.cycle(train_iter)

  def grad_acc_train_iter():
    for _ in range(num_steps):
      total_batch = next(train_iter)
      # Split each total batch into sub batches.
      num_sub_batches = total_batch_size // per_step_batch_size
      start_index = 0
      end_index = int(total_batch_size / num_sub_batches)
      for bi in range(num_sub_batches):
        yield jax.tree_map(
            lambda x: x[start_index:end_index],  # pylint: disable=cell-var-from-loop
            total_batch)
        start_index = end_index
        end_index = int(total_batch_size * (bi + 2) / num_sub_batches)

  lrs = jnp.array([1.0, 0.1, 1e-2])
  sgd_opt_init, sgd_opt_update = optax.sgd(
      learning_rate=lambda t: lrs.at[t].get())
  opt_init, opt_update = gradient_accumulator.accumulate_gradients(
      per_step_batch_size=per_step_batch_size,
      total_batch_size=total_batch_size,
      virtual_batch_size=virtual_batch_size,
      base_opt_init_fn=sgd_opt_init,
      base_opt_update_fn=sgd_opt_update)
  grad_acc_params, grad_acc_batch_stats = _optimize(
      # Run for 3x the number of steps to see the same number of examples.
      num_steps=3 * num_steps,
      params=grad_acc_params,
      batch_stats=grad_acc_batch_stats,
      training_cost=grad_acc_training_cost,
      train_iter=grad_acc_train_iter(),
      opt_init=opt_init,
      opt_update=opt_update)

  # Compute the same updates, but without gradient accumulation.
  hps.update({
      'batch_size': total_batch_size,
      'total_accumulated_batch_size': None,
  })
  params, batch_stats, training_cost = _init_model(model_cls, hps)
  params, batch_stats = _optimize(
      num_steps=num_steps,
      params=params,
      batch_stats=batch_stats,
      training_cost=training_cost,
      train_iter=train_iter,
      opt_init=sgd_opt_init,
      opt_update=sgd_opt_update)

  diffs_params = jax.tree_multimap(
      lambda a, b: jnp.mean(jnp.abs(a - b)), grad_acc_params, params)

  def batch_stats_reduce(a, b):
    if len(a.shape) > 0:  # pylint: disable=g-explicit-length-test
      return jnp.mean(jnp.abs(jnp.mean(a, axis=0) - jnp.mean(b, axis=0)))
    # The gradient accumulator counters are scalars.
    return a - b

  diffs_batch_stats = jax.tree_multimap(
      batch_stats_reduce, grad_acc_batch_stats, batch_stats)

  # We sometimes get small floating point errors in the gradients, so we
  # cannot test for the values being exactly the same.
  acceptable_params_diff = 1e-4
  acceptable_batch_stats_diff = 5e-3

  def check_closeness(root_name, d, max_diff):
    not_close_dict = {}
    for name, dd in d.items():
      new_name = root_name + '/' + name if root_name else name
      if isinstance(dd, (dict, core.FrozenDict)):
        not_close_dict.update(check_closeness(new_name, dd, max_diff))
      else:
        if dd > max_diff:
          not_close_dict[new_name] = dd
    return not_close_dict

  not_close_params = check_closeness('', diffs_params, acceptable_params_diff)
  self.assertEmpty(not_close_params)
  not_close_batch_stats = check_closeness(
      '', diffs_batch_stats, acceptable_batch_stats_diff)
  # Note that the variance variables in the batch stats collection can start
  # to diverge slightly over time (with a higher number of training steps),
  # likely due to numerical issues.
  self.assertEmpty(not_close_batch_stats)
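# A minimal, self-contained sketch (toy linear model, not the repo's
# accumulator) of why the test above expects near-equal results: for a
# mean-reduced loss, averaging the gradients of 3 sub-batches of 16 matches
# a single gradient on the full batch of 48, up to floating point error.
import jax
import jax.numpy as jnp


def _toy_accumulation_check():
  w = jnp.array([0.5, -1.0])
  x = jax.random.normal(jax.random.PRNGKey(0), (48, 2))
  y = jax.random.normal(jax.random.PRNGKey(1), (48,))
  loss = lambda w, x, y: jnp.mean((x @ w - y)**2)
  full_grad = jax.grad(loss)(w, x, y)
  acc_grad = sum(
      jax.grad(loss)(w, x[i:i + 16], y[i:i + 16])
      for i in range(0, 48, 16)) / 3
  # ~0 up to float error, mirroring `acceptable_params_diff` above.
  return jnp.max(jnp.abs(full_grad - acc_grad))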
def test_hessian_free_optimizer(self):
  """Tests the Hessian-free optimizer."""
  model_str = 'autoencoder'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  input_shape = (2, 2, 1)
  output_shape = (4,)
  hps = copy.copy(model_hps)
  hps.update({
      'optimizer': 'hessian_free',
      'opt_hparams': {
          'weight_decay': 0.0,
      },
      'hid_sizes': [2],
      'activation_function': ['id'],
      'input_shape': input_shape,
      'output_shape': output_shape
  })
  model = model_cls(hps, {}, loss, metrics)

  inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
  targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
  batch = {'inputs': inputs, 'targets': targets}

  def forward_fn(variables, inputs):
    return model.flax_module.apply(variables, inputs, train=True)

  def opt_cost(variables):
    return model.loss_fn(forward_fn(variables, inputs), targets)

  init_fn, update_fn = optimizers.get_optimizer(hps, model)

  params = {
      'Dense_0': {
          'kernel': jnp.array([[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
          'bias': jnp.array([0., 0.])
      },
      'Dense_1': {
          'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
          'bias': jnp.array([0., 0., 0., 0.])
      }
  }
  variables = {'params': params}

  grad_fn = jax.grad(opt_cost)
  grads = grad_fn(variables)['params']
  outputs = forward_fn(variables, batch['inputs'])

  n = inputs.shape[0]
  m = outputs.shape[-1]
  d = ravel_pytree(params)[0].shape[0]
  v = np.ones(d)

  state = init_fn(params)

  partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
  partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

  matmul_fn = partial(gvp, variables, outputs, state.inner_state.damping,
                      partial_forward_fn, partial_loss_fn)

  jacobian = jax.jacfwd(partial_forward_fn)(variables)['params']
  jacobian_tensor = np.concatenate(
      (jacobian['Dense_0']['bias'].reshape(n, m, -1),
       jacobian['Dense_0']['kernel'].reshape(n, m, -1),
       jacobian['Dense_1']['bias'].reshape(n, m, -1),
       jacobian['Dense_1']['kernel'].reshape(n, m, -1)),
      axis=2)

  ggn_matrix = 0
  for i in range(n):
    jacobian_matrix = jacobian_tensor[i]
    hessian = jax.hessian(partial_loss_fn)(outputs[i, None])[0, :, 0, :]
    ggn_matrix += np.transpose(jacobian_matrix) @ hessian @ jacobian_matrix
  ggn_matrix /= n
  ggn_matrix += state.inner_state.damping * np.identity(d)

  expected = ggn_matrix @ v

  # Test the gvp function.
  self.assertAlmostEqual(
      jnp.linalg.norm(matmul_fn(v) - expected), 0, places=4)

  update_pmapped = jax.pmap(
      update_fn, axis_name='batch', in_axes=(None, None, None, 0, None))
  batch_shard = data_utils.shard(batch)
  state.hyperparams['learning_rate'] = 1.0
  p, state = update_pmapped(grads, state, params, batch_shard, None)

  # Test the damping parameter update.
  self.assertEqual(state.inner_state.damping, 3 / 2)

  # Test the search direction.
  self.assertAlmostEqual(
      jnp.linalg.norm(
          ravel_pytree(p)[0] +
          jnp.linalg.inv(ggn_matrix) @ ravel_pytree(grads)[0]),
      0,
      places=4)
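# A minimal sketch of a matrix-free Gauss-Newton vector product in the
# spirit of the `gvp` function exercised above (assumed structure: a flat
# parameter vector `w`, a network `f(w)`, and a loss on the outputs; the
# repo's `gvp` signature differs). It computes (J^T H J + damping * I) v
# without materializing J or H, which is exactly what the dense
# `ggn_matrix` construction in the test verifies against:
import jax
import jax.numpy as jnp


def ggn_vector_product(f, loss, w, v, damping=0.0):
  out, jv = jax.jvp(f, (w,), (v,))    # Jv via forward mode.
  h_jv = jax.hessian(loss)(out) @ jv  # H(Jv); fine for small output dims.
  _, vjp_fn = jax.vjp(f, w)
  return vjp_fn(h_jv)[0] + damping * v  # J^T H Jv + damping * v.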
def test_cg_backtracking(self):
  """Tests CG backtracking."""
  model_str = 'autoencoder'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  input_shape = (2, 2, 1)
  output_shape = (4,)
  hps = copy.copy(model_hps)
  hps.update({
      'optimizer': 'hessian_free',
      'opt_hparams': {
          'weight_decay': 0.0,
      },
      'hid_sizes': [2],
      'activation_function': ['id'],
      'input_shape': input_shape,
      'output_shape': output_shape
  })
  model = model_cls(hps, {}, loss, metrics)

  inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
  targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
  batch = {'inputs': inputs, 'targets': targets}

  def forward_fn(variables, inputs):
    return model.flax_module.apply(variables, inputs, train=False)

  def opt_cost(params):
    return model.loss_fn(forward_fn(params, inputs), targets)

  params = {
      'Dense_0': {
          'kernel': jnp.array([[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
          'bias': jnp.array([0., 0.])
      },
      'Dense_1': {
          'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
          'bias': jnp.array([0., 0., 0., 0.])
      }
  }
  unravel_fn = ravel_pytree(params)[1]

  p1 = np.array([
      0.5, 0.2, 0.1, -0.4, -0.6, 0.4, 0.6, -0.7, 0.0, 0.5, -0.7, 0.2, 0.1,
      -0.2, 0.4, -0.6, -0.8, 0.7, 0.2, 0.9, -0.1, 0.5
  ])
  p2 = np.array([
      0.3, -0.1, -0.5, 0.2, -0.4, 0.8, -0.2, 0.0, 0.2, -0.4, 0.6, -0.2, -0.4,
      0.2, 0.3, 0.2, -0.2, -0.4, -0.5, 0.2, 0.2, -0.4
  ])
  p_arr = jnp.array([p1, p2])
  p_arr_idx = 1

  partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
  partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

  def obj_fn(variables):
    return partial_loss_fn(partial_forward_fn(variables))

  flattened_p, obj_val = cg_backtracking(p_arr, p_arr_idx, obj_fn,
                                         {'params': params}, unravel_fn)

  # Test the backtracking function.
  self.assertSameElements(flattened_p, p1)
  updated_params = apply_updates(params, unravel_fn(p1))
  self.assertEqual(opt_cost({'params': updated_params}), obj_val)
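# One plausible reading of the backtracking step tested above (a sketch
# under assumed semantics, not the repo's exact `cg_backtracking`): walk
# backwards through the saved CG iterates and return the flattened update
# whose objective is lowest, together with that objective value.
import optax


def backtrack_cg_iterates(p_arr, last_idx, obj_fn, variables, unravel_fn):
  best_p, best_obj = None, None
  for i in range(last_idx, -1, -1):
    candidate = p_arr[i]
    updated = optax.apply_updates(variables['params'], unravel_fn(candidate))
    obj = obj_fn({'params': updated})
    if best_obj is None or obj < best_obj:
      best_p, best_obj = candidate, obj
  return best_p, best_obj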