def test_early_stopping(self):
  """Test training early stopping on MNIST with a small model."""
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(dict(
      num_tf_data_prefetches=-1,
      num_device_prefetches=0,
      num_tf_data_map_parallel_calls=-1,
  ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    del args
    del kwargs
    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return num_examples fake
  # samples.
  with tfds.testing.mock_data(
      as_dataset_fn=as_dataset, num_examples=num_examples):
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  early_stopping_target_name = 'test/ce_loss'
  early_stopping_target_value = 0.005
  early_stopping_mode = 'less'
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          early_stopping_target_name=early_stopping_target_name,
          early_stopping_target_value=early_stopping_target_value,
          early_stopping_mode=early_stopping_mode,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  self.assertLen(epoch_reports, 3)
  self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                     early_stopping_target_value)
  self.assertLess(epoch_reports[-1][early_stopping_target_name],
                  early_stopping_target_value)
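# A minimal sketch (not the trainer's actual implementation) of the 'less'-mode
# early-stopping check that the assertions above rely on: training halts at the
# first evaluation whose target metric drops below the target value, so only
# the final epoch report is under the threshold. The helper name
# `_should_stop_early` is hypothetical and not used by the tests.
def _should_stop_early(report, target_name, target_value, mode):
  """Returns True if the report's target metric crosses the target value."""
  value = report[target_name]
  if mode == 'less':
    return value < target_value
  return value > target_value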
def test_trainer(self):
  """Test training for two epochs on MNIST with a small model."""
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(dict(
      num_tf_data_prefetches=-1,
      num_device_prefetches=0,
      num_tf_data_map_parallel_calls=-1,
  ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    del args
    del kwargs
    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return num_examples fake
  # samples.
  with tfds.testing.mock_data(
      as_dataset_fn=as_dataset, num_examples=num_examples):
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  # Check that the additional checkpoints are saved.
  checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
  saved_steps = []
  for f in tf.io.gfile.listdir(checkpoint_dir):
    if f[:5] == 'ckpt_':
      saved_steps.append(int(f[5:]))
  self.assertEqual(set(saved_steps), set(checkpoint_steps))

  self.assertLen(epoch_reports, num_train_steps / eval_every)
  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    self.assertEqual(df['preemption_count'].values[-1], 0)
    self.assertLess(train_err, 0.9)
    self.assertEqual(set(df.columns.values), set(get_column_names()))

  model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                    metrics_name)

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = 100
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps_reload,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  self.assertLen(epoch_reports,
                 (num_train_steps_reload - num_train_steps) / eval_every)

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    train_loss = df['train/ce_loss'].values[-1]
    self.assertLess(train_err, 0.35)
    self.assertLess(train_loss, 0.1)
    self.assertEqual(df['valid/num_examples'].values[-1],
                     eval_num_batches * eval_batch_size)
    self.assertEqual(df['preemption_count'].values[-1], 1)

    # Check that the correct learning rate was saved in the measurements file.
    final_learning_rate = df['learning_rate'].values[-1]
    final_step = df['global_step'].values[-1]
    self.assertEqual(num_train_steps_reload, final_step)

    # final_step will be one larger than the last step used to calculate the
    # lr_decay, hence we plug in (final_step - 1) to the decay formula.
    # Note that there is a small numerical difference here with np vs jnp.
    decay_factor = (1 + np.cos(
        (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
    self.assertEqual(float(final_learning_rate),
                     hps.lr_hparams['base_lr'] * decay_factor)

    self.assertEqual(set(df.columns.values), set(get_column_names()))
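# Illustrative helper mirroring the assertion above (a sketch, not part of the
# library): with a cosine schedule, the learning rate logged at `global_step` s
# is assumed to equal base_lr * 0.5 * (1 + cos((s - 1) / total_steps * pi)),
# since the logged step is one past the step the decay was computed for. The
# name `_expected_cosine_lr` is hypothetical and not used by the tests.
def _expected_cosine_lr(base_lr, global_step, total_steps):
  """Computes the cosine-decayed learning rate checked in test_trainer."""
  decay_factor = 0.5 * (1 + np.cos((global_step - 1) / total_steps * np.pi))
  return base_lr * decay_factor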
def test_text_model_trainer(self):
  """Test training of a small transformer model on fake data."""
  rng = jax.random.PRNGKey(42)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_cls = models.get_model('transformer')
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': _TEXT_BATCH_SIZE,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'input_shape': (64,),
      'output_shape': (_VOCAB_SIZE,),
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'layer_rescale_factors': {},
      'optimizer': 'momentum',
      'normalizer': 'layer_norm',
      'opt_hparams': {
          'momentum': 0.9,
      },
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'train_size': _TEXT_TRAIN_SIZE,
      'gradient_clipping': 0.0,
      'model_dtype': 'float32',
      'decode': False,
      'num_device_prefetches': 0,
  })
  initializer = initializers.get_initializer('noop')
  eval_num_batches = 5
  dataset, dataset_meta_data = _get_fake_text_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
  eval_batch_size = hps.batch_size

  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  eval_every = 10
  checkpoint_steps = []
  num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    # Note that upgrading to Linen made this fail at 0.6.
    self.assertLess(train_err, 0.7)
    self.assertEqual(set(df.columns.values), set(get_column_names()))

  prev_train_err = train_err

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps_reload,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    train_loss = df['train/ce_loss'].values[-1]
    # Note that upgrading to Linen made this fail at 0.45.
    self.assertLess(train_err, 0.67)
    self.assertLess(train_err, prev_train_err)
    # Note that upgrading to Linen made this fail at 0.9.
    self.assertLess(train_loss, 1.35)
    self.assertEqual(df['valid/num_examples'].values[-1],
                     eval_num_batches * eval_batch_size * _MAX_LEN)

    # Check that the correct learning rate was saved in the measurements file.
    final_step = df['global_step'].values[-1]
    self.assertEqual(num_train_steps_reload, final_step)

    self.assertEqual(set(df.columns.values), set(get_column_names()))
def test_dlrm_model_trainer(self):
  """Tests that dlrm model training decreases loss."""
  rng = jax.random.PRNGKey(1337)
  model_str = 'dlrm'
  dataset_str = 'criteo1tb'
  model_cls = models.get_model(model_str)
  model_hps = models.get_model_hparams(model_str)
  dataset_hps = datasets.get_dataset_hparams(dataset_str)
  dataset_hps.update({
      'batch_size': model_hps.batch_size,
      'num_dense_features': model_hps.num_dense_features,
      'vocab_sizes': model_hps.vocab_sizes,
  })
  eval_num_batches = 5
  eval_batch_size = dataset_hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics'
  dataset, dataset_meta_data = _get_fake_dlrm_dataset(
      dataset_hps.batch_size, eval_num_batches, dataset_hps)

  hps = copy.copy(model_hps)
  hps.update({
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      'input_shape':
          (model_hps.num_dense_features + len(model_hps.vocab_sizes),),
      'output_shape': (1,),
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'num_device_prefetches': 0,
  })
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    self.assertLess(train_loss[-1], train_loss[0])
def test_graph_model_trainer(self):
  """Tests that graph model training decreases loss."""
  rng = jax.random.PRNGKey(1337)
  model_str = 'gnn'
  model_cls = models.get_model(model_str)
  hps = models.get_model_hparams(model_str)
  hps.update({
      'batch_size': 2,
      'input_edge_shape': (7,),
      'input_node_shape': (3,),
      'input_shape': (7, 3),
      'output_shape': (5,),
      'model_dtype': 'float32',
      'train_size': 15,
      'valid_size': 10,
      'test_size': 10,
      'num_message_passing_steps': 1,
      'normalizer': 'none',
      'dropout_rate': 0.0,
      'lr_hparams': {
          'base_lr': 0.001,
          'schedule': 'constant'
      },
      'num_device_prefetches': 0,
  })
  eval_num_batches = 5
  eval_batch_size = hps.batch_size
  loss_name = 'sigmoid_binary_cross_entropy'
  metrics_name = 'binary_classification_metrics_ogbg_map'
  dataset, dataset_meta_data = _get_fake_graph_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches, hps=hps)
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  initializer = initializers.get_initializer('noop')

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=10,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          # Note that for some reason, moving from the deprecated to linen
          # Flax model API made training less stable so we need to eval more
          # frequently in order to get a `train_loss[0]` that is earlier in
          # training.
          eval_frequency=2,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_loss = df['train/ce_loss'].values
    self.assertLess(train_loss[-1], train_loss[0])
def test_shampoo_wrn(self):
  """Test distributed shampoo on fake dataset."""
  model_name = 'simple_cnn'
  model_cls = models.get_model(model_name)
  hparam_overrides = {
      'optimizer': 'distributed_shampoo',
      'batch_size': 1,
      'train_size': 10,
      'valid_size': 10,
      'input_shape': (32, 32, 3),
      'output_shape': (10,),
      'opt_hparams': {
          'block_size': 32,
          'beta1': 0.9,
          'beta2': 0.999,
          'diagonal_epsilon': 1e-10,
          'matrix_epsilon': 1e-6,
          'weight_decay': 0.0,
          'start_preconditioning_step': 5,
          'preconditioning_compute_steps': 1,
          'statistics_compute_steps': 1,
          'best_effort_shape_interpretation': True,
          'graft_type': distributed_shampoo.GraftingType.SGD,
          'nesterov': True,
          'exponent_override': 0,
          'batch_axis_name': 'batch',
          'num_devices_for_pjit': None,
          'shard_optimizer_states': False,
          'inverse_failure_threshold': 0.1,
          'clip_by_scaled_gradient_norm': None,
          'precision': lax.Precision.HIGHEST,
          'moving_average_for_momentum': False,
          'skip_preconditioning_dim_size_gt': 4096,
          'best_effort_memory_usage_reduction': False,
      },
  }
  input_pipeline_hps = config_dict.ConfigDict(dict(
      num_tf_data_prefetches=-1,
      num_device_prefetches=0,
      num_tf_data_map_parallel_calls=-1,
  ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name='noop',
      dataset_name='fake',
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  initializer = initializers.get_initializer('noop')
  dataset_builder = datasets.get_dataset('fake')
  dataset = dataset_builder(
      shuffle_rng=jax.random.PRNGKey(0),
      batch_size=hps.batch_size,
      eval_batch_size=hps.batch_size,
      hps=hps)

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  dataset_meta_data = datasets.get_dataset_meta_data('fake')
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  _ = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=1,
          hps=hps,
          rng=jax.random.PRNGKey(42),
          eval_batch_size=hps.batch_size,
          eval_num_batches=None,
          eval_train_num_batches=None,
          eval_frequency=10,
          checkpoint_steps=[],
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    valid_ce_loss = df['valid/ce_loss'].values[-1]
    self.assertLess(valid_ce_loss, 1e-3)