def test_initialize_rescale(self):
  """Rescaling a single layer at init scales only that layer's kernel."""
  input_shape = (28, 28, 1)
  output_shape = (10,)
  model_str = 'fully_connected'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': output_shape})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')
  rng, init_rng = jax.random.split(rng)

  # Baseline: initialize with no rescale.
  baseline_module, _ = trainer.initialize(
      model.flax_module_def,
      initializer,
      model.loss_fn,
      input_shape,
      output_shape,
      hps,
      init_rng,
      metrics_logger=None)
  utils.log_pytree_shape_and_statistics(baseline_module.params)

  # Re-initialize with the Dense_1 kernel scaled up by 100.
  factor = 100
  hps.layer_rescale_factors = {
      '/Dense_1/kernel': factor,
  }
  scaled_module, _ = trainer.initialize(
      model.flax_module_def,
      initializer,
      model.loss_fn,
      input_shape,
      output_shape,
      hps,
      init_rng,
      metrics_logger=None)

  # The targeted kernel must equal the baseline times the rescale factor.
  before = baseline_module.params['Dense_1']['kernel']
  after = scaled_module.params['Dense_1']['kernel']
  delta = np.linalg.norm(before.reshape(-1) * factor - after.reshape(-1))
  self.assertAlmostEqual(delta, 0.0)

  # All other variables must be untouched.
  before = baseline_module.params['Dense_2']['kernel']
  after = scaled_module.params['Dense_2']['kernel']
  delta = np.linalg.norm(before.reshape(-1) - after.reshape(-1))
  self.assertAlmostEqual(delta, 0.0)
def main(unused_argv):
  """Runs BLEU evaluation for a saved experiment.

  Resolves the dataset/model either from --experiment_config_filename or from
  the --dataset/--model flags, builds merged hyperparameters, and runs
  translation + BLEU computation over the checkpoints in --checkpoint_dir.
  """
  # Necessary to use the tfds loader.
  tf.enable_v2_behavior()
  if jax.process_count() > 1:
    # TODO(ankugarg): Add support for multihost inference.
    raise NotImplementedError(
        'BLEU eval does not support multihost inference.')
  rng = jax.random.PRNGKey(FLAGS.seed)

  mt_eval_config = json.loads(FLAGS.mt_eval_config)

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    # Use the process_* APIs consistently: jax.host_count()/jax.host_id()
    # are deprecated aliases of jax.process_count()/jax.process_index().
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model_class = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  hparam_overrides = None
  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)

  # NOTE(review): experiment_config is only bound when
  # --experiment_config_filename is set; the FLAGS.dataset/FLAGS.model branch
  # reaches this line with experiment_config unbound (NameError) — confirm
  # whether the initializer should come from a flag in that path.
  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=experiment_config['initializer'],
      dataset_name=dataset_name,
      hparam_file=FLAGS.trial_hparams_filename,
      hparam_overrides=hparam_overrides)
  if jax.process_index() == 0:
    logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

  evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                           rng, model_class, dataset_builder,
                                           dataset_meta_data, mt_eval_config)
  evaluator.translate_and_calculate_bleu()
def main(unused_argv):
  """Runs Hessian (lanczos) evaluation over a range of saved checkpoints.

  Loads the experiment configuration and trial hparams, applies any hparam
  overrides, and evaluates checkpoints in --checkpoint_dir between
  --min_global_step and --max_global_step.
  """
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()
  rng = jax.random.PRNGKey(FLAGS.seed)

  # Fall back to the default config when no override is supplied.
  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    # Bug fix: previously hparam_overrides was only assigned inside the
    # isinstance(str) branch, so a non-string flag value (e.g. an already
    # parsed dict) raised NameError at update_from_flattened_dict.
    hparam_overrides = FLAGS.hparam_overrides
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)

  run_lanczos.eval_checkpoints(FLAGS.checkpoint_dir, hps, rng,
                               FLAGS.eval_num_batches, model, dataset_builder,
                               dataset_meta_data, hessian_eval_config,
                               FLAGS.min_global_step, FLAGS.max_global_step)
def test_nqm(self):
  """NQM gradient must equal H x plus the mean of the shaped noise."""
  batch_size = 2
  dim = 10
  model_hps = config_dict.ConfigDict(
      dict(
          input_shape=(dim,),
          output_shape=(1,),
          rng_seed=-1,
          hessian_decay_power=1.0,
          noise_decay_power=1.0,
          nqm_mode='diagH_diagC',
          model_dtype='float32',
      ))
  model_cls = models.get_model('nqm')
  rng = jax.random.PRNGKey(0)
  model = model_cls(model_hps, {})
  noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
  rng, params_rng = jax.random.split(rng)
  _, flax_module = model.flax_module_def.create_by_shape(
      params_rng, [(batch_size, dim)])
  model_x = flax_module.params['x']

  def loss(module, inputs):
    return module(inputs)

  grad_loss = jax.grad(loss)

  # Diagonal curvature with spectrum 1 / i**hessian_decay_power.
  hessian = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.hessian_decay_power)
          for i in range(1, dim + 1)
      ]))
  # Noise-shaping matrix with spectrum 1 / i**(noise_decay_power / 2).
  noise_matrix = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
          for i in range(1, dim + 1)
      ]))
  noise = jnp.dot(noise_eps, noise_matrix)
  mean_noise = np.mean(noise, axis=0)

  # NQM gradient = Hx + eps where eps ~ N(0, C / batch_size).
  want_grad = np.dot(hessian, model_x) + mean_noise
  got_grad = grad_loss(flax_module, noise_eps).params['x']
  self.assertAlmostEqual(np.sum(np.abs(got_grad - want_grad)), 0.0, places=5)
def test_nqm(self):
  """NQM gradient must equal H x plus the mean of the shaped noise."""
  batch_size = 2
  dim = 10
  model_hps = config_dict.ConfigDict(
      dict(
          input_shape=(dim,),
          output_shape=(1,),
          rng_seed=-1,
          hessian_decay_power=1.0,
          noise_decay_power=1.0,
          nqm_mode='diagH_diagC',
          model_dtype='float32',
      ))
  model_cls = models.get_model('nqm')
  params_rng = jax.random.PRNGKey(0)
  model = model_cls(model_hps, {}, None, None)
  noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))

  # Initialize params with a zero batch of the right shape.
  xs = np.zeros((batch_size, dim))
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  params = model_init_fn({'params': params_rng}, xs)['params']
  model_x = params['x']

  def loss(p, batch):
    # training_cost returns (cost, aux), hence has_aux below.
    return model.training_cost(p, batch=batch)

  grad_loss = jax.grad(loss, has_aux=True)

  # Diagonal curvature with spectrum 1 / i**hessian_decay_power.
  hessian = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.hessian_decay_power)
          for i in range(1, dim + 1)
      ]))
  # Noise-shaping matrix with spectrum 1 / i**(noise_decay_power / 2).
  noise_matrix = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
          for i in range(1, dim + 1)
      ]))
  noise = jnp.dot(noise_eps, noise_matrix)
  mean_noise = np.mean(noise, axis=0)

  # NQM gradient = Hx + eps where eps ~ N(0, C / batch_size).
  want_grad = np.dot(hessian, model_x) + mean_noise
  got_grad = grad_loss(params, {'inputs': noise_eps})[0]['x']
  self.assertAlmostEqual(np.sum(np.abs(got_grad - want_grad)), 0.0, places=5)
def setUp(self):
  """Creates a temp dir and a small fully-connected module for the tests."""
  super(CheckpointTest, self).setUp()
  self.test_dir = tempfile.mkdtemp()
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_cls = models.get_model('fully_connected')
  model_hps = models.get_model_hparams('fully_connected')
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
  rng, params_rng = jax.random.split(rng)
  # Instantiate the module once; the tests checkpoint/restore it.
  _, self.flax_module = model.flax_module_def.create(params_rng, xs)
def setUp(self):
  """Creates a temp dir and initial params for a small fully-connected net."""
  super(CheckpointTest, self).setUp()
  self.test_dir = tempfile.mkdtemp()
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model_cls = models.get_model('fully_connected')
  model_hps = models.get_model_hparams('fully_connected')
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss_name, metrics_name)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
  rng, params_rng = jax.random.split(rng)
  # Initialize in eval mode; the tests checkpoint/restore these params.
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  self.params = model_init_fn({'params': params_rng}, xs)['params']
def _load_model(model_name):
  """Builds a test model and returns its initialized module + input shape."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = copy.copy(models.get_model_hparams(model_name))
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  # Shape-based init in train mode.
  _, flax_module = model.flax_module_def.create_by_shape(
      rng, [input_shape], train=True)
  utils.log_pytree_shape_and_statistics(flax_module.params)
  return flax_module, input_shape
def test_graph_model(self):
  """A padded GNN forward pass yields one prediction per padded graph."""
  edge_input_shape = (5,)
  node_input_shape = (5,)
  output_shape = (5,)
  model_str = 'gnn'
  model_hps = models.get_model_hparams(model_str)
  model_hps.update({
      'output_shape': output_shape,
      'latent_dim': 10,
      'hidden_dims': (10,),
      'batch_size': 5,
      'normalizer': 'batch_norm'
  })
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  dropout_rng, params_rng = jax.random.split(rng)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_classification_metrics'
  model = model_cls(model_hps, {}, loss, metrics)

  # Build a batch of fully connected graphs with constant features.
  num_graphs = 5
  node_per_graph = 3
  edge_per_graph = 9
  graphs = jraph.get_fully_connected_graph(
      n_node_per_graph=node_per_graph,
      n_graph=num_graphs,
      node_features=np.ones((num_graphs * node_per_graph,) +
                            node_input_shape),
  )
  graphs = graphs._replace(
      edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape))
  # Pad to 20 nodes / 50 edges / 7 graphs.
  padded_graphs = jraph.pad_with_graphs(graphs, 20, 50, 7)

  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, padded_graphs)
  params = init_dict['params']
  batch_stats = init_dict['batch_stats']

  # Check that the forward pass works with mutated batch_stats.
  outputs, _ = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      padded_graphs,
      mutable=['batch_stats'],
      rngs={'dropout': dropout_rng},
      train=True)
  self.assertEqual(outputs.shape, (7,) + output_shape)
def _load_model(model_name):
  """Builds a test model and returns (module, params, input_shape, hps)."""
  rng = jax.random.PRNGKey(0)
  model_cls = models.get_model(model_name)
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = copy.copy(models.get_model_hparams(model_name))
  hps.update({'output_shape': OUTPUT_SHAPE})
  model = model_cls(hps, {}, loss_name, metrics_name)
  input_shape = (BATCH_SIZE,) + MODEL_TO_INPUT_SHAPE[model_name]
  # Initialize in train mode from a zero batch.
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=True))
  init_dict = model_init_fn({'params': rng}, jnp.zeros(input_shape))
  # Trainable model parameters.
  params = init_dict['params']
  utils.log_pytree_shape_and_statistics(params)
  return model.flax_module, params, input_shape, hps
def test_classification_model(self, model_str):
  """Forward pass of an image classifier in train and eval mode."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE['classification']})
  rng = jax.random.PRNGKey(0)
  dropout_rng, params_rng = jax.random.split(rng)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE['classification']))

  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  batch_stats = init_dict.get('batch_stats', {})

  want_shape = (INPUT_SHAPE['classification'][0],
                OUTPUT_SHAPE['classification'][-1])

  # Check that the forward pass works with mutated batch_stats.
  outputs, new_batch_stats = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      xs,
      mutable=['batch_stats'],
      rngs={'dropout': dropout_rng},
      train=True)
  self.assertEqual(outputs.shape, want_shape)

  # For batch-norm models, the batch stats must have been updated.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Inference mode must also produce the expected output shape.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(outputs.shape, want_shape)
def test_autoencoder_model(self, model_str):
  """Forward pass of an autoencoder in train and eval mode."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  params_rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))

  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  batch_stats = init_dict.get('batch_stats', {})

  want_shape = (INPUT_SHAPE[model_str][0],) + tuple(OUTPUT_SHAPE[model_str])

  # Check that the forward pass works with mutated batch_stats.
  outputs, new_batch_stats = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats},
      xs,
      mutable=['batch_stats'],
      train=True)
  self.assertEqual(outputs.shape, want_shape)

  # For batch-norm models, the batch stats must have been updated.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Inference mode must also produce the expected output shape.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(outputs.shape, want_shape)
def test_autoencoder_model(self, model_str):
  """Forward pass of an autoencoder (pre-Linen API) in train and eval mode."""
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  hps = copy.copy(model_hps)
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  rng = jax.random.PRNGKey(0)
  model = model_cls(hps, {}, loss, metrics)
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
  rng, params_rng = jax.random.split(rng)

  with nn.stateful() as batch_stats:
    with nn.stochastic(params_rng):
      _, flax_module = model.flax_module_def.create(params_rng, xs)

  want_shape = (INPUT_SHAPE[model_str][0],) + tuple(OUTPUT_SHAPE[model_str])

  # Check that the forward pass works with mutated batch_stats.
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(params_rng):
      outputs = flax_module(xs)
  self.assertEqual(outputs.shape, want_shape)

  # For batch-norm models, the batch stats must have been updated.
  if batch_stats.as_dict():
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Inference mode with frozen batch stats.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(outputs.shape, want_shape)
def test_text_model_trainer(self):
  """Trains a small transformer on fake data, then resumes from checkpoint."""
  rng = jax.random.PRNGKey(42)
  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)
  model_cls = models.get_model('transformer')
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': _TEXT_BATCH_SIZE,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'input_shape': (64,),
      'output_shape': (_VOCAB_SIZE,),
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'layer_rescale_factors': {},
      'optimizer': 'momentum',
      'normalizer': 'layer_norm',
      'opt_hparams': {
          'momentum': 0.9,
      },
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'train_size': _TEXT_TRAIN_SIZE,
      'gradient_clipping': 0.0,
      'model_dtype': 'float32',
      'decode': False,
      'num_device_prefetches': 0,
  })
  initializer = initializers.get_initializer('noop')
  eval_num_batches = 5
  dataset, dataset_meta_data = _get_fake_text_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
  eval_batch_size = hps.batch_size
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  eval_interval = 10
  checkpoint_steps = []
  num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)

  # First run: train for num_train_steps from scratch.
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_interval,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir, 'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    # Note that upgrading to Linen made this fail at 0.6.
    self.assertLess(train_err, 0.7)
    self.assertEqual(set(df.columns.values), set(get_column_names()))
    prev_train_err = train_err

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps_reload,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_interval,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir, 'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    train_loss = df['train/ce_loss'].values[-1]
    # Note that upgrading to Linen made this fail at 0.45.
    self.assertLess(train_err, 0.67)
    self.assertLess(train_err, prev_train_err)
    # Note that upgrading to Linen made this fail at 0.9.
    self.assertLess(train_loss, 1.35)
    self.assertEqual(df['valid/num_examples'].values[-1],
                     eval_num_batches * eval_batch_size * _MAX_LEN)
    # Check that the correct learning rate was saved in the measurements file.
    final_step = df['global_step'].values[-1]
    self.assertEqual(num_train_steps_reload, final_step)
    self.assertEqual(set(df.columns.values), set(get_column_names()))
def test_shampoo_wrn(self):
  """Trains a tiny CNN with distributed Shampoo on the fake dataset."""
  model_name = 'simple_cnn'
  model_cls = models.get_model(model_name)
  hparam_overrides = {
      'optimizer': 'distributed_shampoo',
      'batch_size': 1,
      'train_size': 10,
      'valid_size': 10,
      'input_shape': (32, 32, 3),
      'output_shape': (10,),
      'opt_hparams': {
          'block_size': 32,
          'beta1': 0.9,
          'beta2': 0.999,
          'diagonal_epsilon': 1e-10,
          'matrix_epsilon': 1e-6,
          'weight_decay': 0.0,
          'start_preconditioning_step': 5,
          'preconditioning_compute_steps': 1,
          'statistics_compute_steps': 1,
          'best_effort_shape_interpretation': True,
          'graft_type': distributed_shampoo.GraftingType.SGD,
          'nesterov': True,
          'exponent_override': 0,
          'batch_axis_name': 'batch',
          'num_devices_for_pjit': None,
          'shard_optimizer_states': False,
          'inverse_failure_threshold': 0.1,
          'clip_by_scaled_gradient_norm': None,
          'precision': lax.Precision.HIGHEST,
          'moving_average_for_momentum': False,
          'skip_preconditioning_dim_size_gt': 4096,
          'best_effort_memory_usage_reduction': False,
      },
  }
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=-1,
          num_device_prefetches=0,
          num_tf_data_map_parallel_calls=-1,
      ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name='noop',
      dataset_name='fake',
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  initializer = initializers.get_initializer('noop')
  dataset_builder = datasets.get_dataset('fake')
  dataset = dataset_builder(
      shuffle_rng=jax.random.PRNGKey(0),
      batch_size=hps.batch_size,
      eval_batch_size=hps.batch_size,
      hps=hps)

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  dataset_meta_data = datasets.get_dataset_meta_data('fake')
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  # Train for a single step; we only check the eval loss below.
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=1,
          hps=hps,
          rng=jax.random.PRNGKey(42),
          eval_batch_size=hps.batch_size,
          eval_num_batches=None,
          eval_train_num_batches=None,
          eval_frequency=10,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir, 'measurements.csv')) as f:
    df = pandas.read_csv(f)
    valid_ce_loss = df['valid/ce_loss'].values[-1]
    self.assertLess(valid_ce_loss, 1e-3)
def test_cg_backtracking(self):
  """CG backtracking picks the candidate with the lower objective value."""
  model_str = 'autoencoder'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  input_shape = (2, 2, 1)
  output_shape = (4,)
  hps = copy.copy(model_hps)
  hps.update({
      'optimizer': 'hessian_free',
      'opt_hparams': {
          'weight_decay': 0.0,
      },
      'hid_sizes': [2],
      'activation_function': ['id'],
      'input_shape': input_shape,
      'output_shape': output_shape
  })
  model = model_cls(hps, {}, loss, metrics)
  inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
  targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
  batch = {'inputs': inputs, 'targets': targets}

  def forward_fn(variables, inputs):
    return model.flax_module.apply(variables, inputs, train=False)

  def opt_cost(params):
    return model.loss_fn(forward_fn(params, inputs), targets)

  # Fixed weights so the objective values are reproducible.
  params = {
      'Dense_0': {
          'kernel': jnp.array([[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
          'bias': jnp.array([0., 0.])
      },
      'Dense_1': {
          'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
          'bias': jnp.array([0., 0., 0., 0.])
      }
  }
  unravel_fn = ravel_pytree(params)[1]

  # Two flattened update candidates; backtracking starts at index 1.
  cand_a = np.array([
      0.5, 0.2, 0.1, -0.4, -0.6, 0.4, 0.6, -0.7, 0.0, 0.5, -0.7, 0.2, 0.1,
      -0.2, 0.4, -0.6, -0.8, 0.7, 0.2, 0.9, -0.1, 0.5
  ])
  cand_b = np.array([
      0.3, -0.1, -0.5, 0.2, -0.4, 0.8, -0.2, 0.0, 0.2, -0.4, 0.6, -0.2,
      -0.4, 0.2, 0.3, 0.2, -0.2, -0.4, -0.5, 0.2, 0.2, -0.4
  ])
  p_arr = jnp.array([cand_a, cand_b])
  p_arr_idx = 1

  partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
  partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

  def obj_fn(variables):
    return partial_loss_fn(partial_forward_fn(variables))

  flattened_p, obj_val = cg_backtracking(p_arr, p_arr_idx, obj_fn,
                                         {'params': params}, unravel_fn)

  # Backtracking must select the first candidate and report its objective.
  self.assertSameElements(flattened_p, cand_a)
  updated_params = apply_updates(params, unravel_fn(cand_a))
  self.assertEqual(opt_cost({'params': updated_params}), obj_val)
def test_translate_model(self):
  """Train-mode forward pass of the translation transformer."""
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'share_embeddings': False,
      'logits_via_embedding': False,
      'emb_dim': 32,
      'num_heads': 2,
      'enc_num_layers': 2,
      'dec_num_layers': 2,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'normalizer': 'pre_layer_norm',
      'max_predict_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'enc_self_attn_kernel_init': 'xavier_uniform',
      'dec_self_attn_kernel_init': 'xavier_uniform',
      'dec_cross_attn_kernel_init': 'xavier_uniform',
      'decode': False,
  })
  text_src_input_shape = (32, 64)  # batch_size, max_source_length
  text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
  model_cls = models.get_model('xformer_translate')
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'shift_outputs': True,
      'causal': True
  }, loss, metrics)

  # Random token ids in [1, vocab_size) for source and target.
  xs = jnp.array(
      np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
  ys = jnp.array(
      np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
  dropout_rng, params_rng = jax.random.split(rng)
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs, ys)
  params = init_dict['params']

  # Test forward pass.
  @jax.jit
  def forward_pass(params, xs, ys, dropout_rng):
    return model.flax_module.apply(
        {'params': params},
        xs,
        ys,
        rngs={'dropout': dropout_rng},
        train=True)

  logits = forward_pass(params, xs, ys, dropout_rng)
  # Testing only train mode
  # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
  self.assertEqual(
      logits.shape,
      (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
def test_text_models(self, model_str):
  """Forward pass of a text transformer in train and eval mode."""
  # TODO(gilmer): Find a clean way to handle small test hparams.
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'normalizer': 'layer_norm',
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      'model_dtype': 'float32',
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'decode': False,
  })
  text_input_shape = (32, 64)  # batch_size, max_target_length
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'max_len': 64,
      'shift_inputs': True,
      'causal': True
  }, loss, metrics)

  # Random token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_input_shape, low=1, high=vocab_size))
  dropout_rng, params_rng = jax.random.split(rng)
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  batch_stats = init_dict.get('batch_stats', {})

  # Check that the forward pass works with mutated batch_stats.
  # Due to a bug in flax, this jit is required, otherwise the model errors.
  @jax.jit
  def forward_pass(params, xs, dropout_rng):
    return model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        xs,
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)

  outputs, new_batch_stats = forward_pass(params, xs, dropout_rng)
  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))

  # For batch-norm models, the batch stats must have been updated.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Inference mode must also produce the expected output shape.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))
def test_text_models(self, model_str):
  """Forward pass of a text transformer (pre-Linen API)."""
  # TODO(gilmer): Find a clean way to handle small test hparams.
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'normalizer': 'layer_norm',
      'lr_hparams': {
          'initial_value': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      # Training HParams.
      'l2_decay_factor': 1e-4
  })
  text_input_shape = (32, 64)  # batch_size, max_target_length
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'max_len': 64,
      'shift_inputs': True,
      'causal': True
  }, loss, metrics)

  # Random token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_input_shape, low=1, high=vocab_size))
  rng, params_rng = jax.random.split(rng)
  rng, dropout_rng = jax.random.split(rng)
  with nn.stateful() as batch_stats:
    _, flax_module = model.flax_module_def.create_by_shape(
        params_rng, [(text_input_shape, jnp.float32)], train=False)

  # Check that the forward pass works with mutated batch_stats.
  # Due to a bug in flax, this jit is required, otherwise the model errors.
  @jax.jit
  def forward_pass(flax_module, xs, dropout_rng):
    with batch_stats.mutate() as new_batch_stats:
      with nn.stochastic(dropout_rng):
        return flax_module(xs, train=True), new_batch_stats

  outputs, new_batch_stats = forward_pass(flax_module, xs, dropout_rng)
  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))

  # For batch-norm models, the batch stats must have been updated.
  if batch_stats.as_dict():
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Inference mode with frozen batch stats.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))
def test_hessian_free_optimizer(self):
  """Tests the Hessian-free optimizer.

  On a tiny hand-initialized autoencoder this checks three things:
    1. the Gauss-Newton vector product `gvp` matches an explicitly
       assembled (damped) GGN matrix applied to a ones vector,
    2. the damping parameter is updated to the expected value after one
       pmapped optimizer step, and
    3. the returned search direction matches -GGN^{-1} grad.
  """
  model_str = 'autoencoder'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  loss = 'sigmoid_binary_cross_entropy'
  metrics = 'binary_autoencoder_metrics'
  input_shape = (2, 2, 1)
  output_shape = (4,)
  hps = copy.copy(model_hps)
  hps.update({
      'optimizer': 'hessian_free',
      'opt_hparams': {
          'weight_decay': 0.0,
      },
      'hid_sizes': [2],
      'activation_function': ['id'],
      'input_shape': input_shape,
      'output_shape': output_shape
  })
  model = model_cls(hps, {}, loss, metrics)

  # Autoencoder targets are the (flattened) inputs themselves.
  inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
  targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
  batch = {'inputs': inputs, 'targets': targets}

  def forward_fn(variables, inputs):
    logits = model.flax_module.apply(variables, inputs, train=True)
    return logits

  def opt_cost(variables):
    return model.loss_fn(forward_fn(variables, inputs), targets)

  init_fn, update_fn = optimizers.get_optimizer(hps, model)
  # Fixed params so the expected GGN matrix below is deterministic.
  params = {
      'Dense_0': {
          'kernel': jnp.array([[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
          'bias': jnp.array([0., 0.])
      },
      'Dense_1': {
          'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
          'bias': jnp.array([0., 0., 0., 0.])
      }
  }
  variables = {'params': params}

  grad_fn = jax.grad(opt_cost)
  grads = grad_fn(variables)['params']
  outputs = forward_fn(variables, batch['inputs'])

  n = inputs.shape[0]   # number of examples
  m = outputs.shape[-1]  # output dimension
  d = ravel_pytree(params)[0].shape[0]  # total number of parameters
  v = np.ones(d)

  state = init_fn(params)

  partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
  partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

  # Matrix-free damped GGN product, as used by the optimizer.
  matmul_fn = partial(gvp, variables, outputs, state.inner_state.damping,
                      partial_forward_fn, partial_loss_fn)

  # Assemble the dense per-example Jacobian; the column order (Dense_0 bias,
  # Dense_0 kernel, Dense_1 bias, Dense_1 kernel) must match the flattening
  # order of ravel_pytree on `params`.
  jacobian = jax.jacfwd(partial_forward_fn)(variables)['params']
  jacobian_tensor = np.concatenate(
      (jacobian['Dense_0']['bias'].reshape(n, m, -1),
       jacobian['Dense_0']['kernel'].reshape(n, m, -1),
       jacobian['Dense_1']['bias'].reshape(n, m, -1),
       jacobian['Dense_1']['kernel'].reshape(n, m, -1)),
      axis=2)

  # GGN = (1/n) sum_i J_i^T H_i J_i, plus damping * I.
  ggn_matrix = 0
  for i in range(n):
    jacobian_matrix = jacobian_tensor[i]
    hessian = jax.hessian(partial_loss_fn)(outputs[i, None])[0, :, 0, :]
    ggn_matrix += np.transpose(jacobian_matrix) @ hessian @ jacobian_matrix
  ggn_matrix /= n
  ggn_matrix += state.inner_state.damping * np.identity(d)

  expected = ggn_matrix @ v

  # Test the gvp function against the explicit matrix product.
  self.assertAlmostEqual(
      jnp.linalg.norm(matmul_fn(v) - expected), 0, places=4)

  update_pmapped = jax.pmap(
      update_fn, axis_name='batch', in_axes=(None, None, None, 0, None))

  batch_shard = data_utils.shard(batch)
  state.hyperparams['learning_rate'] = 1.0
  p, state = update_pmapped(grads, state, params, batch_shard, None)

  # Test the damping parameter update (expected value for this fixture).
  self.assertEqual(state.inner_state.damping, 3 / 2)

  # Test the search direction: p should equal -GGN^{-1} grad, so their sum
  # should vanish.
  self.assertAlmostEqual(
      jnp.linalg.norm(
          ravel_pytree(p)[0] +
          jnp.linalg.inv(ggn_matrix) @ ravel_pytree(grads)[0]),
      0,
      places=4)
def test_accumulation(self):
  """Test simple gradient accumulation.

  Trains the same model twice on the same stream of examples:
    1. with gradient accumulation (small per-step batches, 3x the steps), and
    2. without accumulation (one full-size batch per step),
  then checks that the resulting params and batch stats agree to within a
  small floating-point tolerance.
  """
  num_steps = 3
  per_step_batch_size = 16
  total_batch_size = 48
  virtual_batch_size = 8
  model_str = 'wide_resnet'  # Pick a model with batch norm.
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  dataset_name = 'cifar10'
  dataset_builder = datasets.get_dataset(dataset_name)
  hps = copy.copy(model_hps)
  hps.update(datasets.get_dataset_hparams(dataset_name))
  # Compute updates using gradient accumulation.
  hps.update({
      'batch_size': per_step_batch_size,
      'virtual_batch_size': virtual_batch_size,
      'normalizer': 'virtual_batch_norm',
      'total_accumulated_batch_size': total_batch_size,
  })
  grad_acc_params, grad_acc_batch_stats, grad_acc_training_cost = _init_model(
      model_cls, hps)
  total_dataset = dataset_builder(
      shuffle_rng=jax.random.PRNGKey(1),
      batch_size=total_batch_size,
      eval_batch_size=10,
      hps=hps)
  # Ensure we see the same exact batches: take the first num_steps batches
  # and cycle them so both training runs consume identical data.
  train_iter = total_dataset.train_iterator_fn()
  train_iter = itertools.islice(train_iter, 0, num_steps)
  train_iter = itertools.cycle(train_iter)

  def grad_acc_train_iter():
    # Re-slices each full-size batch into per-step sub batches so the
    # accumulating run sees the same examples in the same order.
    for _ in range(num_steps):
      total_batch = next(train_iter)
      # Split each total batch into sub batches.
      num_sub_batches = total_batch_size // per_step_batch_size
      start_index = 0
      end_index = int(total_batch_size / num_sub_batches)
      for bi in range(num_sub_batches):
        # The lambda is called immediately by tree_map, so closing over the
        # loop-mutated indices is safe here despite the lint warning.
        yield jax.tree_map(lambda x: x[start_index:end_index], total_batch)  # pylint: disable=cell-var-from-loop
        start_index = end_index
        end_index = int(total_batch_size * (bi + 2) / num_sub_batches)

  # One learning rate per *accumulated* step.
  lrs = jnp.array([1.0, 0.1, 1e-2])
  sgd_opt_init, sgd_opt_update = optax.sgd(
      learning_rate=lambda t: lrs.at[t].get())
  opt_init, opt_update = gradient_accumulator.accumulate_gradients(
      per_step_batch_size=per_step_batch_size,
      total_batch_size=total_batch_size,
      virtual_batch_size=virtual_batch_size,
      base_opt_init_fn=sgd_opt_init,
      base_opt_update_fn=sgd_opt_update)
  grad_acc_params, grad_acc_batch_stats = _optimize(
      # Run for 3x the number of steps to see the same number of examples.
      num_steps=3 * num_steps,
      params=grad_acc_params,
      batch_stats=grad_acc_batch_stats,
      training_cost=grad_acc_training_cost,
      train_iter=grad_acc_train_iter(),
      opt_init=opt_init,
      opt_update=opt_update)

  # Compute the same updates, but without gradient accumulation.
  hps.update({
      'batch_size': total_batch_size,
      'total_accumulated_batch_size': None,
  })
  params, batch_stats, training_cost = _init_model(model_cls, hps)
  params, batch_stats = _optimize(
      num_steps=num_steps,
      params=params,
      batch_stats=batch_stats,
      training_cost=training_cost,
      train_iter=train_iter,
      opt_init=sgd_opt_init,
      opt_update=sgd_opt_update)

  # Mean absolute difference per leaf. (tree_multimap is the multi-tree map
  # on the JAX version this file targets; do not fold into tree_map here.)
  diffs_params = jax.tree_multimap(
      lambda a, b: jnp.mean(jnp.abs(a - b)), grad_acc_params, params)

  def batch_stats_reduce(a, b):
    # Per-leaf difference of batch stats, averaging over the leading
    # (virtual batch) axis for array-valued stats.
    if len(a.shape) > 0:  # pylint: disable=g-explicit-length-test
      return jnp.mean(jnp.abs(jnp.mean(a, axis=0) - jnp.mean(b, axis=0)))
    # The gradient accumulator counters are scalars.
    return a - b

  diffs_batch_stats = jax.tree_multimap(
      batch_stats_reduce, grad_acc_batch_stats, batch_stats)
  # We sometimes get small floating point errors in the gradients, so we
  # cannot test for the values being exactly the same.
  acceptable_params_diff = 1e-4
  acceptable_batch_stats_diff = 5e-3

  def check_closeness(root_name, d, max_diff):
    # Recursively collect {path: diff} for every leaf diff above max_diff.
    not_close_dict = {}
    for name, dd in d.items():
      new_name = root_name + '/' + name if root_name else name
      if isinstance(dd, (dict, core.FrozenDict)):
        not_close_dict.update(check_closeness(new_name, dd, max_diff))
      else:
        if dd > max_diff:
          not_close_dict[new_name] = dd
    return not_close_dict

  not_close_params = check_closeness('', diffs_params, acceptable_params_diff)
  self.assertEmpty(not_close_params)
  not_close_batch_stats = check_closeness(
      '', diffs_batch_stats, acceptable_batch_stats_diff)
  # Note that for the variance variables in the batch stats collection, they
  # sometimes can start to diverge slightly over time (with a higher number of
  # training steps), likely due to numerical issues.
  self.assertEmpty(not_close_batch_stats)
def test_translate_model(self):
  """Test forward pass of the translate model.

  Builds a small encoder-decoder transformer ('xformer_translate'), runs a
  single jitted training-mode forward pass on random source/target token ids,
  and checks the logits have shape (batch, target_length, vocab_size).
  """
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'share_embeddings': False,
      'logits_via_embedding': False,
      'emb_dim': 32,
      'num_heads': 2,
      'enc_num_layers': 2,
      'dec_num_layers': 2,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'normalizer': 'pre_layer_norm',
      'max_predict_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'lr_hparams': {
          'initial_value': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      # Training HParams.
      'l2_decay_factor': 1e-4
  })
  text_src_input_shape = (32, 64)  # batch_size, max_source_length
  text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
  model_cls = models.get_model('xformer_translate')
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'shift_outputs': True,
      'causal': True
  }, loss, metrics)
  # Fake token ids; 0 is excluded (presumably the padding id — TODO confirm).
  xs = jnp.array(
      np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
  ys = jnp.array(
      np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
  rng, params_rng = jax.random.split(rng)
  rng, dropout_rng = jax.random.split(rng)
  # Pre-Linen flax API: initialize params by shape (source and target inputs).
  with nn.stateful() as batch_stats:
    _, flax_module = model.flax_module_def.create_by_shape(
        params_rng, [(text_src_input_shape, jnp.float32),
                     (text_tgt_input_shape, jnp.float32)],
        train=False)

  # Test forward pass.
  @jax.jit
  def forward_pass(flax_module, xs, ys, dropout_rng):
    with batch_stats.mutate() as new_batch_stats:
      with nn.stochastic(dropout_rng):
        return flax_module(xs, ys, train=True), new_batch_stats

  logits, _ = forward_pass(flax_module, xs, ys, dropout_rng)
  # Testing only train mode
  # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
  self.assertEqual(
      logits.shape,
      (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
def test_early_stopping(self):
  """Test training early stopping on MNIST with a small model.

  Runs trainer.train with an early-stopping target on 'test/ce_loss' and
  checks that training halts once the target is crossed: the run yields only
  3 eval reports (instead of num_train_steps / eval_frequency = 4), with the
  second-to-last report above the target and the last one below it.
  """
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=-1,
          num_device_prefetches=0,
          num_tf_data_map_parallel_calls=-1,
      ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    # Replacement for tfds' as_dataset: constant all-ones images labeled 9.
    del args
    del kwargs

    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return num_examples fake
  # samples.
  with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                              num_examples=num_examples):
    dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                              batch_size=hps.batch_size,
                              eval_batch_size=eval_batch_size,
                              hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  early_stopping_target_name = 'test/ce_loss'
  early_stopping_target_value = 0.005
  early_stopping_mode = 'less'
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          early_stopping_target_name=early_stopping_target_name,
          early_stopping_target_value=early_stopping_target_value,
          early_stopping_mode=early_stopping_mode,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  # Training stopped after the third eval instead of running all 4 evals.
  self.assertLen(epoch_reports, 3)
  self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                     early_stopping_target_value)
  self.assertLess(epoch_reports[-1][early_stopping_target_name],
                  early_stopping_target_value)
def test_text_model(self):
  """Test gradient accumulator training of a small transformer.

  Trains a small transformer on fake text data for 3 epochs worth of steps
  and checks that the final training error rate recorded in measurements.csv
  dropped below 0.7.
  """
  rng = jax.random.PRNGKey(42)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_cls = models.get_model('transformer')
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  batch_size = 16
  train_size = 20 * batch_size
  hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': batch_size,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'input_shape': (64,),
      'output_shape': (4,),
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'layer_rescale_factors': {},
      'optimizer': 'momentum',
      'normalizer': 'layer_norm',
      'opt_hparams': {
          'momentum': 0.9,
      },
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'train_size': train_size,
      'gradient_clipping': 0.0,
      'model_dtype': 'float32',
      'decode': False,
  })
  initializer = initializers.get_initializer('noop')
  eval_num_batches = 5
  dataset, dataset_meta_data = _get_fake_text_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
  eval_batch_size = hps.batch_size

  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  eval_every = 10
  checkpoint_steps = []
  # 3 passes over the (fake) training set.
  num_train_steps = train_size // batch_size * 3

  # Use utils.set_up_loggers like the other tests in this file (it was
  # previously reached via the trainer module).
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    # Note that upgrading to Linen made this fail at 0.6.
    self.assertLess(train_err, 0.7)
def test_run_lanczos(self):
  """Test training for two epochs on MNIST with a small model.

  After training (with checkpoints at fixed steps), runs
  run_lanczos.eval_checkpoints over the checkpoint directory and verifies
  that the saved Hessian-eval metrics contain exactly the checkpointed steps.
  """
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(dict(
      num_tf_data_prefetches=-1,
      num_device_prefetches=0,
      num_tf_data_map_parallel_calls=-1,
  ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)
  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    # Replacement for tfds' as_dataset: constant all-ones images labeled 9.
    del args
    del kwargs

    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return num_examples fake
  # samples.
  with tfds.testing.mock_data(
      as_dataset_fn=as_dataset, num_examples=num_examples):
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)

  num_train_steps = 41
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [10, 30, 40]
  metrics_logger, init_logger = None, None
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
  rng = jax.random.PRNGKey(0)

  run_lanczos.eval_checkpoints(
      checkpoint_dir,
      hps,
      rng,
      eval_num_batches,
      model_cls=model_cls,
      dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
      dataset_meta_data=datasets.get_dataset_meta_data(dataset_name),
      hessian_eval_config=hessian_eval.DEFAULT_EVAL_CONFIG,
  )

  # Load the saved file.
  hessian_dir = os.path.join(checkpoint_dir, 'hessian', 'training_metrics')
  pytree_list = checkpoint.load_pytree(hessian_dir)

  # Convert to a regular list (checkpointer will have converted the saved
  # list to a dict of keys '0', '1', ...
  pytree_list = [pytree_list[str(i)] for i in range(len(pytree_list))]
  # Test that the logged steps are correct.
  saved_steps = [row['step'] for row in pytree_list]
  self.assertEqual(saved_steps, checkpoint_steps)
def test_dlrm_model_trainer(self):
  """Tests that dlrm model training decreases loss.

  Trains a DLRM model on a fake criteo1tb-shaped dataset for 10 steps and
  checks the last recorded train/ce_loss is below the first.
  """
  rng = jax.random.PRNGKey(1337)
  model_str = 'dlrm'
  dataset_str = 'criteo1tb'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  dataset_hps = datasets.get_dataset_hparams(dataset_str)
  # Keep dataset shapes consistent with the model hparams.
  dataset_hps.update({
      'batch_size': model_hps.batch_size,
      'num_dense_features': model_hps.num_dense_features,
      'vocab_sizes': model_hps.vocab_sizes,
  })
  eval_num_batches = 5
  eval_batch_size = dataset_hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics'
  dataset, dataset_meta_data = _get_fake_dlrm_dataset(
      dataset_hps.batch_size, eval_num_batches, dataset_hps)
  hps = copy.copy(model_hps)
  hps.update({
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      # One input column per dense feature plus one per categorical vocab.
      'input_shape':
          (model_hps.num_dense_features + len(model_hps.vocab_sizes),),
      'output_shape': (1,),
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'num_device_prefetches': 0,
  })
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    # Loss should have decreased over the course of training.
    self.assertLess(train_loss[-1], train_loss[0])
def test_graph_model_trainer(self):
  """Tests that graph model training decreases loss.

  Trains a GNN on a fake graph dataset for 10 steps and checks the last
  recorded train/ce_loss is below the first.
  """
  rng = jax.random.PRNGKey(1337)
  model_str = 'gnn'
  model_cls = models.get_model(model_str)
  hps = models.get_model_hparams(model_str)
  hps.update({
      'batch_size': 2,
      'input_edge_shape': (7,),
      'input_node_shape': (3,),
      'input_shape': (7, 3),
      'output_shape': (5,),
      'model_dtype': 'float32',
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      'num_message_passing_steps': 1,
      'normalizer': 'none',
      'dropout_rate': 0.0,
      'lr_hparams': {
          'base_lr': 0.001,
          'schedule': 'constant'
      },
      'num_device_prefetches': 0,
  })
  eval_num_batches = 5
  eval_batch_size = hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics_ogbg_map'
  dataset, dataset_meta_data = _get_fake_graph_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches, hps=hps)
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          # Note that for some reason, moving from the deprecated to linen
          # Flax model API made training less stable so we need to eval more
          # frequently in order to get a `train_loss[0]` that is earlier in
          # training.
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    # Loss should have decreased over the course of training.
    self.assertLess(train_loss[-1], train_loss[0])
def _run(train_fn, dataset_name, eval_batch_size, eval_num_batches,
         eval_train_num_batches, eval_frequency, checkpoint_steps,
         num_tf_data_prefetches, num_device_prefetches,
         num_tf_data_map_parallel_calls, early_stopping_target_name,
         early_stopping_target_value, early_stopping_mode, eval_steps,
         hparam_file, hparam_overrides, initializer_name, model_name,
         loss_name, metrics_name, num_train_steps, experiment_dir, worker_id,
         training_metrics_config, callback_configs, external_checkpoint_path):
  """Function that runs a Jax experiment. See flag definitions for args.

  Builds the model/dataset/hparams, runs `train_fn`, and records per-trial
  artifacts under <experiment_dir>/<worker_id>: hparams.json, the metric
  logs (host 0 only), and a meta_data.json whose 'status' field ends up as
  'done', 'diverged', or stays 'incomplete' on any other failure.
  """
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
  input_pipeline_hps = config_dict.ConfigDict(dict(
      num_tf_data_prefetches=num_tf_data_prefetches,
      num_device_prefetches=num_device_prefetches,
      num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls,
  ))

  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=initializer_name,
      dataset_name=dataset_name,
      hparam_file=hparam_file,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  # Note that one should never tune an RNG seed!!! The seed is only included
  # in the hparams for convenience of running hparam trials with multiple
  # seeds per point.
  rng_seed = merged_hps.rng_seed
  if merged_hps.rng_seed < 0:
    # A negative seed means "pick one"; the seed must agree across hosts.
    rng_seed = _create_synchronized_rng_seed()
  # NOTE(review): xm_experiment is assigned but never used below — looks like
  # a leftover from an XManager integration; confirm before removing.
  xm_experiment = None
  xm_work_unit = None
  if jax.process_index() == 0:
    logging.info('Running with seed %d', rng_seed)
  rng = jax.random.PRNGKey(rng_seed)

  # Build the loss_fn, metrics_bundle, and flax_module.
  model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
  trial_dir = os.path.join(experiment_dir, str(worker_id))
  meta_data_path = os.path.join(trial_dir, 'meta_data.json')
  meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
  if jax.process_index() == 0:
    logging.info('rng: %s', rng)
    gfile.makedirs(trial_dir)
    # Set up the metric loggers for host 0.
    metrics_logger, init_logger = utils.set_up_loggers(
        trial_dir, xm_work_unit)
    hparams_fname = os.path.join(trial_dir, 'hparams.json')
    logging.info('saving hparams to %s', hparams_fname)
    with gfile.GFile(hparams_fname, 'w') as f:
      f.write(merged_hps.to_json())
    _write_trial_meta_data(meta_data_path, meta_data)
  else:
    # Non-zero hosts do not log metrics.
    metrics_logger = None
    init_logger = None
  try:
    epoch_reports = list(
        train_fn(
            trial_dir,
            model,
            dataset_builder,
            initializer,
            num_train_steps,
            merged_hps,
            rng,
            eval_batch_size,
            eval_num_batches,
            eval_train_num_batches,
            eval_frequency,
            checkpoint_steps,
            early_stopping_target_name,
            early_stopping_target_value,
            early_stopping_mode,
            eval_steps,
            metrics_logger,
            init_logger,
            training_metrics_config=training_metrics_config,
            callback_configs=callback_configs,
            external_checkpoint_path=external_checkpoint_path))
    logging.info(epoch_reports)
    meta_data['status'] = 'done'
  except utils.TrainingDivergedError as err:
    meta_data['status'] = 'diverged'
    raise err
  finally:
    # Always persist the final trial status, even on failure (host 0 only).
    if jax.process_index() == 0:
      _write_trial_meta_data(meta_data_path, meta_data)
def test_trainer(self):
  """Test training for two epochs on MNIST with a small model.

  Trains on mocked MNIST data, then checks: checkpoints were written at the
  requested steps, the measurements.csv columns and values are as expected,
  and a second trainer.train call with a larger num_train_steps resumes from
  the saved checkpoint (preemption_count becomes 1) and finishes with the
  correct final cosine learning rate.
  """
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=-1,
          num_device_prefetches=0,
          num_tf_data_map_parallel_calls=-1,
      ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    # Replacement for tfds' as_dataset: constant all-ones images labeled 9.
    del args
    del kwargs

    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return num_examples fake
  # samples.
  with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                              num_examples=num_examples):
    dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                              batch_size=hps.batch_size,
                              eval_batch_size=eval_batch_size,
                              hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  # Check that the additional checkpoints are saved (files named 'ckpt_<N>').
  checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
  saved_steps = []
  for f in tf.io.gfile.listdir(checkpoint_dir):
    if f[:5] == 'ckpt_':
      saved_steps.append(int(f[5:]))

  self.assertEqual(set(saved_steps), set(checkpoint_steps))

  self.assertLen(epoch_reports, num_train_steps / eval_every)
  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    # Fresh run: no preemptions/restarts yet.
    self.assertEqual(df['preemption_count'].values[-1], 0)
    self.assertLess(train_err, 0.9)
    self.assertEqual(set(df.columns.values), set(get_column_names()))

  model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                    metrics_name)

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = 100
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps_reload,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  # Only the steps after the reload point produce new reports.
  self.assertLen(
      epoch_reports, (num_train_steps_reload - num_train_steps) / eval_every)
  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    train_loss = df['train/ce_loss'].values[-1]
    self.assertLess(train_err, 0.35)
    self.assertLess(train_loss, 0.1)

    self.assertEqual(df['valid/num_examples'].values[-1],
                     eval_num_batches * eval_batch_size)
    # The restarted run counts as one preemption.
    self.assertEqual(df['preemption_count'].values[-1], 1)
    # Check that the correct learning rate was saved in the measurements file.
    final_learning_rate = df['learning_rate'].values[-1]
    final_step = df['global_step'].values[-1]
    self.assertEqual(num_train_steps_reload, final_step)

    # final_step will be one larger than the last step used to calculate the
    # lr_decay, hence we plug in (final_step - 1) to the decay formula.
    # Note that there is a small numerical difference here with np vs jnp.
    decay_factor = (1 + np.cos(
        (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
    self.assertEqual(float(final_learning_rate),
                     hps.lr_hparams['base_lr'] * decay_factor)

    self.assertEqual(set(df.columns.values), set(get_column_names()))