def test_cifar10(self):
  """Test example generation in CIFAR10 is reproducible."""
  hps = config_dict.ConfigDict(
      dict(
          flip_probability=0.5,
          alpha=1.0,
          crop_num_pixels=4,
          use_mixup=True,
          train_size=45000,
          valid_size=5000,
          test_size=10000,
          include_example_keys=True,
          input_shape=(32, 32, 3),
          output_shape=(10,)))
  dataset = small_image_datasets.get_cifar10(random.PRNGKey(0), 1, 1, hps)

  # The validation split starts right after the 45000 training examples, so
  # the first ten example keys are indices 45000..45009 of the train shard.
  first_ten_batches = itertools.islice(dataset.valid_epoch(), 10)
  observed_keys = [batch['example_key'][0] for batch in first_ten_batches]
  expected_keys = [
      b'cifar10-train.tfrecord-00000-of-00001__%d' % index
      for index in range(45000, 45010)
  ]
  self.assertEqual(observed_keys, expected_keys)
def test_schedule_stretching(self):
  """Test that schedules can be properly stretched."""
  max_training_steps = 100
  stretch_factor = 3
  lr_hparams = config_dict.ConfigDict({
      'schedule': 'mlperf_polynomial',
      'base_lr': 10.0,
      'warmup_steps': 10,
      'decay_end': -1,
      'end_lr': 1e-4,
      'power': 2.0,
      'start_lr': 0.0,
      'warmup_power': 1.0,
  })
  base_fn = schedules.get_schedule_fn(lr_hparams, max_training_steps)
  stretched_fn = schedules.get_schedule_fn(
      lr_hparams, max_training_steps, stretch_factor=stretch_factor)

  base_lrs = [base_fn(step) for step in range(max_training_steps)]
  stretched_lrs = [
      stretched_fn(step)
      for step in range(stretch_factor * max_training_steps)
  ]

  # Every phase offset within a stretch window must reproduce the base
  # schedule exactly.
  for offset in range(stretch_factor):
    self.assertEqual(base_lrs, stretched_lrs[offset::stretch_factor])

  # Assert that the stretched schedule has proper staircase behavior: each
  # base value is held constant for stretch_factor consecutive steps.
  for update_step, lr in enumerate(base_lrs):
    window_start = update_step * stretch_factor
    window = stretched_lrs[window_start:window_start + stretch_factor]
    self.assertEqual(window, [lr] * stretch_factor)
def test_raises(self):
  """Test that an exception is raised with extra hparams."""
  good_hps = config_dict.ConfigDict(
      dict(
          lr_hparams={
              'schedule': 'mlperf_polynomial',
              'base_lr': .1,
              'warmup_steps': 200,
              'decay_end': -1,
              'end_lr': 1e-4,
              'power': 2.0,
              'start_lr': 0.0,
              'warmup_power': 1.0,
          }))
  # Same schedule as above, but carrying an unknown extra key.
  extra_key_hps = config_dict.ConfigDict(
      dict(
          lr_hparams={
              'schedule': 'mlperf_polynomial',
              'base_lr': .1,
              'warmup_steps': 200,
              'initial_value': .1,
              'decay_end': -1,
              'end_lr': 1e-4,
              'power': 2.0,
              'start_lr': 0.0,
          }))
  # Polynomial schedule given both decay_steps and decay_steps_factor, which
  # are mutually exclusive.
  exclusive_hps = config_dict.ConfigDict(
      dict(
          lr_hparams={
              'schedule': 'polynomial',
              'power': 2.0,
              'initial_value': .1,
              'end_factor': .01,
              'decay_steps': 200,
              'decay_steps_factor': 0.5
          }))

  # The well-formed hparams should build without raising.
  schedules.get_schedule_fn(good_hps.lr_hparams, 1)
  # The extra hparam must be rejected.
  with self.assertRaises(ValueError):
    schedules.get_schedule_fn(extra_key_hps.lr_hparams, 1)
  # The mutually exclusive hparams must be rejected.
  with self.assertRaises(ValueError):
    schedules.get_schedule_fn(exclusive_hps.lr_hparams, 1)
def main(unused_argv):
  """Run Hessian (Lanczos) evaluation over saved checkpoints.

  Resolves the dataset/model names either from an experiment config file or
  from the --dataset/--model flags, loads the trial hparams (plus optional
  JSON overrides), and delegates to run_lanczos.eval_checkpoints.
  """
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()
  rng = jax.random.PRNGKey(FLAGS.seed)
  # Hessian-eval settings come from a JSON flag, else the library defaults.
  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  # Only the first host logs run metadata.
  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  # Flag overrides are applied on top of the trial hparams file.
  if FLAGS.hparam_overrides:
    # BUG FIX: bind the overrides before the isinstance check; previously a
    # non-string (already-parsed) value left `hparam_overrides` unbound and
    # raised NameError on the update call below.
    hparam_overrides = FLAGS.hparam_overrides
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)

  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
def _get_dataset(self, hps, rng):
  """Sets ups dataset builders."""
  # Merge the callback config on top of the run hparams; callback settings
  # win on key collisions.
  merged = hps.to_dict()
  merged.update(self.callback_config)
  merged_hps = config_dict.ConfigDict(merged)

  builder = datasets.get_dataset(self.callback_config['dataset_name'])
  return builder(
      rng,
      merged_hps.batch_size,
      eval_batch_size=self.callback_config['eval_batch_size'],
      hps=merged_hps)
def test_nqm(self):
  """Test the noisy quadratic model."""
  batch_size = 2
  dim = 10
  model_hps = config_dict.ConfigDict(
      dict(
          input_shape=(dim, ),
          output_shape=(1, ),
          rng_seed=-1,
          hessian_decay_power=1.0,
          noise_decay_power=1.0,
          nqm_mode='diagH_diagC',
          model_dtype='float32',
      ))
  model_cls = models.get_model('nqm')
  rng = jax.random.PRNGKey(0)
  model = model_cls(model_hps, {})
  # Standard-normal draws fed to the model; the model shapes this noise via
  # its internal noise matrix.
  noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
  rng, params_rng = jax.random.split(rng)
  # Legacy (pre-Linen) flax API: create the module from input shapes.
  _, flax_module = model.flax_module_def.create_by_shape(
      params_rng, [(batch_size, dim)])
  model_x = flax_module.params['x']

  def loss(model, inputs):
    return model(inputs)

  grad_loss = jax.grad(loss)

  # Diagonal Hessian with entries 1 / i**hessian_decay_power, i = 1..dim.
  hessian = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.hessian_decay_power)
          for i in range(1, dim + 1)
      ]))
  # Noise factor with entries 1 / i**(noise_decay_power / 2).
  noise_matrix = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
          for i in range(1, dim + 1)
      ]))

  noise = jnp.dot(noise_eps, noise_matrix)
  mean_noise = np.mean(noise, axis=0)

  # NQM gradient = Hx + eps where eps ~ N(0, C / batch_size).
  expected_grad = np.dot(hessian, model_x) + mean_noise

  g = grad_loss(flax_module, noise_eps).params['x']

  grad_error = np.sum(np.abs(g - expected_grad))
  self.assertAlmostEqual(grad_error, 0.0, places=5)
def test_nqm(self):
  """Test the noisy quadratic model."""
  batch_size = 2
  dim = 10
  model_hps = config_dict.ConfigDict(
      dict(
          input_shape=(dim,),
          output_shape=(1,),
          rng_seed=-1,
          hessian_decay_power=1.0,
          noise_decay_power=1.0,
          nqm_mode='diagH_diagC',
          model_dtype='float32',
      ))
  model_cls = models.get_model('nqm')
  params_rng = jax.random.PRNGKey(0)
  model = model_cls(model_hps, {}, None, None)
  # Standard-normal draws fed as the batch; the model shapes this noise via
  # its internal noise matrix.
  noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
  xs = np.zeros((batch_size, dim))
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  params = model_init_fn({'params': params_rng}, xs)['params']
  model_x = params['x']

  def loss(params, inputs):
    return model.training_cost(params, batch=inputs)

  # has_aux=True because training_cost is used here as returning (cost, aux);
  # grad_loss therefore returns (grads, aux).
  grad_loss = jax.grad(loss, has_aux=True)

  # Diagonal Hessian with entries 1 / i**hessian_decay_power, i = 1..dim.
  hessian = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.hessian_decay_power)
          for i in range(1, dim + 1)
      ]))
  # Noise factor with entries 1 / i**(noise_decay_power / 2).
  noise_matrix = np.diag(
      np.array([
          1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
          for i in range(1, dim + 1)
      ]))

  noise = jnp.dot(noise_eps, noise_matrix)
  mean_noise = np.mean(noise, axis=0)

  # NQM gradient = Hx + eps where eps ~ N(0, C / batch_size).
  expected_grad = np.dot(hessian, model_x) + mean_noise

  g = grad_loss(params, {'inputs': noise_eps})[0]['x']

  grad_error = np.sum(np.abs(g - expected_grad))
  self.assertAlmostEqual(grad_error, 0.0, places=5)
def _get_dataset(shuffle_seed, additional_hps=None):
  """Loads the ogbg-molpcba dataset using mock data."""
  with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
    dataset_name = 'ogbg_molpcba'
    builder = get_dataset(dataset_name)

    hps_dict = get_dataset_hparams(dataset_name).to_dict()
    if additional_hps is not None:
      hps_dict.update(additional_hps)
    hps = config_dict.ConfigDict(hps_dict)

    # Shrink the splits and padding budgets so the test runs on mock data.
    hps.train_size = 4
    hps.valid_size = 4
    hps.test_size = 4
    hps.max_nodes_multiplier = NODES_SIZE_MULTIPLIER
    hps.max_edges_multiplier = EDGES_SIZE_MULTIPLIER

    return builder(
        shuffle_rng=shuffle_seed,
        batch_size=BATCH_SIZE,
        eval_batch_size=BATCH_SIZE,
        hps=hps)
def test_polynomial_decay_decay_steps(self):
  """Test polynomial schedule works correctly with decay_steps."""
  hps = config_dict.ConfigDict(
      dict(
          lr_hparams={
              'schedule': 'polynomial',
              'power': 2.0,
              'initial_value': .1,
              'end_factor': .01,
              'decay_steps': 200,
          }))
  max_training_steps = 400
  lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)

  lr_hps = hps.lr_hparams
  end_value = lr_hps['end_factor'] * lr_hps['initial_value']
  # The schedule must agree with tf.train.polynomial_decay at every step.
  for step in range(max_training_steps):
    expected_learning_rate = tf.train.polynomial_decay(
        lr_hps['initial_value'],
        step,
        lr_hps['decay_steps'],
        end_value,
        power=lr_hps['power'])().numpy()
    self.assertAlmostEqual(lr_fn(step), expected_learning_rate)
def test_text_models(self, model_str):
  """Test forward pass of the transformer model."""
  # TODO(gilmer): Find a clean way to handle small test hparams.
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'normalizer': 'layer_norm',
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      'model_dtype': 'float32',
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'decode': False,
  })
  text_input_shape = (32, 64)  # batch_size, max_target_length
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'max_len': 64,
      'shift_inputs': True,
      'causal': True
  }, loss, metrics)
  # Random token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_input_shape, low=1, high=vocab_size))
  dropout_rng, params_rng = jax.random.split(rng)

  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs)
  params = init_dict['params']
  # batch_stats is only present for batch-norm models.
  batch_stats = init_dict.get('batch_stats', {})

  # Check that the forward pass works with mutated batch_stats.
  # Due to a bug in flax, this jit is required, otherwise the model errors.
  @jax.jit
  def forward_pass(params, xs, dropout_rng):
    outputs, new_batch_stats = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        xs,
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)
    return outputs, new_batch_stats

  outputs, new_batch_stats = forward_pass(params, xs, dropout_rng)

  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))

  # If it's a batch norm model check the batch stats changed.
  if batch_stats:
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  outputs = model.flax_module.apply(
      {'params': params, 'batch_stats': batch_stats}, xs, train=False)
  self.assertEqual(outputs.shape,
                   (text_input_shape[0], text_input_shape[1], vocab_size))
def test_translate_model(self):
  """Test forward pass of the translate model."""
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'share_embeddings': False,
      'logits_via_embedding': False,
      'emb_dim': 32,
      'num_heads': 2,
      'enc_num_layers': 2,
      'dec_num_layers': 2,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'normalizer': 'pre_layer_norm',
      'max_predict_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size,),
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'enc_self_attn_kernel_init': 'xavier_uniform',
      'dec_self_attn_kernel_init': 'xavier_uniform',
      'dec_cross_attn_kernel_init': 'xavier_uniform',
      'decode': False,
  })
  text_src_input_shape = (32, 64)  # batch_size, max_source_length
  text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
  model_cls = models.get_model('xformer_translate')
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'shift_outputs': True,
      'causal': True
  }, loss, metrics)
  # Random source/target token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
  ys = jnp.array(
      np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
  dropout_rng, params_rng = jax.random.split(rng)
  model_init_fn = jax.jit(
      functools.partial(model.flax_module.init, train=False))
  init_dict = model_init_fn({'params': params_rng}, xs, ys)
  params = init_dict['params']

  # Test forward pass.
  @jax.jit
  def forward_pass(params, xs, ys, dropout_rng):
    outputs = model.flax_module.apply(
        {'params': params},
        xs,
        ys,
        rngs={'dropout': dropout_rng},
        train=True)
    return outputs

  logits = forward_pass(params, xs, ys, dropout_rng)
  # Testing only train mode
  # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
  self.assertEqual(
      logits.shape,
      (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
from flax import optim as optimizers
from flax.core import unfreeze
from init2winit.model_lib import model_utils
import jax
from jax import jvp
import jax.numpy as jnp
from ml_collections.config_dict import config_dict
import numpy as np
import optax

# Small hparams for quicker tests.
DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    meta_learning_rate=.1,
    meta_steps=50,
    meta_batch_size=8,
    epsilon=1e-5,
    meta_momentum=0.5,
))


def _count_params(tree):
  """Return the total number of scalar entries across all leaves of tree."""
  # NOTE(review): relies on `operator` being imported elsewhere in this
  # module — confirm the import exists; it is not in the import block above.
  return jax.tree_util.tree_reduce(operator.add,
                                   jax.tree_map(lambda x: x.size, tree))


def scale_params(params, scalars):
  """Elementwise-scale each parameter leaf by the matching scalar leaf."""
  return jax.tree_multimap(lambda w, scale: w * scale, params, scalars)


def meta_loss(params_to_loss, scalars, normalized_params, epsilon):
def test_text_models(self, model_str):
  """Test forward pass of the transformer model."""
  # TODO(gilmer): Find a clean way to handle small test hparams.
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'normalizer': 'layer_norm',
      'lr_hparams': {
          'initial_value': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size, ),
      # Training HParams.
      'l2_decay_factor': 1e-4
  })
  text_input_shape = (32, 64)  # batch_size, max_target_length
  model_cls = models.get_model(model_str)
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'max_len': 64,
      'shift_inputs': True,
      'causal': True
  }, loss, metrics)
  # Random token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_input_shape, low=1, high=vocab_size))
  rng, params_rng = jax.random.split(rng)
  rng, dropout_rng = jax.random.split(rng)
  # Legacy (pre-Linen) flax: nn.stateful captures batch-norm statistics
  # created during module construction.
  with nn.stateful() as batch_stats:
    _, flax_module = model.flax_module_def.create_by_shape(
        params_rng, [(text_input_shape, jnp.float32)], train=False)

  # Check that the forward pass works with mutated batch_stats.
  # Due to a bug in flax, this jit is required, otherwise the model errors.
  @jax.jit
  def forward_pass(flax_module, xs, dropout_rng):
    with batch_stats.mutate() as new_batch_stats:
      with nn.stochastic(dropout_rng):
        return flax_module(xs, train=True), new_batch_stats

  outputs, new_batch_stats = forward_pass(flax_module, xs, dropout_rng)

  self.assertEqual(
      outputs.shape, (text_input_shape[0], text_input_shape[1], vocab_size))

  # If it's a batch norm model check the batch stats changed.
  if batch_stats.as_dict():
    bflat, _ = ravel_pytree(batch_stats)
    new_bflat, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(bflat, new_bflat))

  # Test batch_norm in inference mode.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(
      outputs.shape, (text_input_shape[0], text_input_shape[1], vocab_size))
# small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict( dict(blocks_per_group=3, channel_multiplier=2, lr_hparams={ 'base_lr': 0.001, 'schedule': 'cosine' }, normalizer='batch_norm', layer_rescale_factors={}, conv_kernel_scale=1.0, dense_kernel_scale=1.0, dropout_rate=0.0, conv_kernel_init='lecun_normal', dense_kernel_init='lecun_normal', optimizer='momentum', opt_hparams={ 'momentum': 0.9, }, batch_size=128, virtual_batch_size=None, total_accumulated_batch_size=None, l2_decay_factor=0.0001, l2_decay_rank_threshold=2, label_smoothing=None, rng_seed=-1, use_shallue_label_smoothing=False, model_dtype='float32', grad_clip=None, activation_function='relu', group_strides=[(1, 1), (2, 2), (2, 2)]))
def test_translate_model(self):
  """Test forward pass of the translate model."""
  vocab_size = 16
  small_hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': 16,
      'share_embeddings': False,
      'logits_via_embedding': False,
      'emb_dim': 32,
      'num_heads': 2,
      'enc_num_layers': 2,
      'dec_num_layers': 2,
      'qkv_dim': 32,
      'label_smoothing': 0.1,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'normalizer': 'pre_layer_norm',
      'max_predict_length': 64,
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'momentum': 0.9,
      'lr_hparams': {
          'initial_value': 0.005,
          'schedule': 'constant'
      },
      'output_shape': (vocab_size, ),
      # Training HParams.
      'l2_decay_factor': 1e-4
  })
  text_src_input_shape = (32, 64)  # batch_size, max_source_length
  text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
  model_cls = models.get_model('xformer_translate')
  rng = jax.random.PRNGKey(0)
  loss = 'cross_entropy'
  metrics = 'classification_metrics'
  model = model_cls(small_hps, {
      'shift_outputs': True,
      'causal': True
  }, loss, metrics)
  # Random source/target token ids in [1, vocab_size).
  xs = jnp.array(
      np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
  ys = jnp.array(
      np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
  rng, params_rng = jax.random.split(rng)
  rng, dropout_rng = jax.random.split(rng)
  # Legacy (pre-Linen) flax: build the module from the source and target
  # input shapes.
  with nn.stateful() as batch_stats:
    _, flax_module = model.flax_module_def.create_by_shape(
        params_rng, [(text_src_input_shape, jnp.float32),
                     (text_tgt_input_shape, jnp.float32)],
        train=False)

  # Test forward pass.
  @jax.jit
  def forward_pass(flax_module, xs, ys, dropout_rng):
    with batch_stats.mutate() as new_batch_stats:
      with nn.stochastic(dropout_rng):
        return flax_module(xs, ys, train=True), new_batch_stats

  logits, _ = forward_pass(flax_module, xs, ys, dropout_rng)
  # Testing only train mode
  # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
  self.assertEqual(
      logits.shape,
      (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
def test_trainer(self):
  """Test training for two epochs on MNIST with a small model."""
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  # Shrink the dataset and model so the test trains quickly.
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=-1,
          num_device_prefetches=0,
          num_tf_data_map_parallel_calls=-1,
      ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    """Replacement for tfds' as_dataset yielding constant fake examples."""
    del args
    del kwargs

    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return 100 fake samples.
  with tfds.testing.mock_data(
      as_dataset_fn=as_dataset, num_examples=num_examples):
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))

  # check that the additional checkpoints are saved.
  checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
  saved_steps = []
  for f in tf.io.gfile.listdir(checkpoint_dir):
    if f[:5] == 'ckpt_':
      saved_steps.append(int(f[5:]))

  self.assertEqual(set(saved_steps), set(checkpoint_steps))

  self.assertLen(epoch_reports, num_train_steps / eval_every)
  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    # No preemption happened during the first run.
    self.assertEqual(df['preemption_count'].values[-1], 0)
    self.assertLess(train_err, 0.9)
    self.assertEqual(set(df.columns.values), set(get_column_names()))

  model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                    metrics_name)

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = 100
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps_reload,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  # Only the steps after the restored checkpoint produce new reports.
  self.assertLen(
      epoch_reports, (num_train_steps_reload - num_train_steps) / eval_every)
  with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                      'measurements.csv')) as f:
    df = pandas.read_csv(f)
    train_err = df['train/error_rate'].values[-1]
    train_loss = df['train/ce_loss'].values[-1]
    self.assertLess(train_err, 0.35)
    self.assertLess(train_loss, 0.1)

    self.assertEqual(df['valid/num_examples'].values[-1],
                     eval_num_batches * eval_batch_size)
    # The restart from checkpoint counts as one preemption.
    self.assertEqual(df['preemption_count'].values[-1], 1)

    # Check that the correct learning rate was saved in the measurements
    # file.
    final_learning_rate = df['learning_rate'].values[-1]
    final_step = df['global_step'].values[-1]
    self.assertEqual(num_train_steps_reload, final_step)

    # final_step will be one larger than the last step used to calculate the
    # lr_decay, hence we plug in (final_step - 1) to the decay formula.
    # Note that there is a small numerical difference here with np vs jnp.
    decay_factor = (1 + np.cos(
        (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
    self.assertEqual(
        float(final_learning_rate),
        hps.lr_hparams['base_lr'] * decay_factor)

    self.assertEqual(set(df.columns.values), set(get_column_names()))
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Padding (in pixels) added around the central crop.
CROP_PADDING = 32
# Per-channel RGB statistics scaled from [0, 1] to [0, 255]
# (0.485/0.456/0.406 and 0.229/0.224/0.225 are the standard ImageNet stats).
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]

DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    input_shape=(224, 224, 3),
    output_shape=(1000,),
    train_size=1281167,
    valid_size=50000,
    use_inception_crop=False,
    use_mixup=False,
    mixup={'alpha': 0.5},
    use_randaug=False,
    randaug={
        'magnitude': 15,
        'num_layers': 2
    }))

# pylint:disable=raise-missing-from

METADATA = {
    'apply_one_hot_in_loss': False,
}


def distorted_bounding_box_crop(image_bytes,
                                bbox,
# Hparams for the fake model used in unit tests (MLPerf LARS ResNet-style
# settings).
FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        num_filters=16,
        num_layers=18,  # Must be one of [18, 34, 50, 101, 152, 200]
        layer_rescale_factors={},
        lr_hparams={
            'batch_size': 128,
            'base_lr': 10.0,
            'decay_end': -1,
            'end_lr': 1e-4,
            'power': 2.0,
            'schedule': 'mlperf_polynomial',
            'start_lr': 0.0,
            'steps_per_epoch': 10009.250000000002,
            'warmup_steps': 18,
        },
        optimizer='mlperf_lars_resnet',
        opt_hparams={
            'weight_decay': 2e-4,
            'beta': 0.9
        },
        batch_size=128,
        l2_decay_factor=None,
        l2_decay_rank_threshold=2,
        label_smoothing=.1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
        virtual_batch_size=64,
        data_format='NHWC',
    ))
def test_mlperf_schedule(self):
  """Test there are no changes to the MLPerf polynomial decay schedule."""
  # Golden values: 50 linear warmup steps from 0 to base_lr (step 0.2),
  # followed by the polynomial-decay tail. Only the first
  # max_training_steps entries are compared in the loop below.
  expected_lrs = [
      0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8,
      2.0, 2.2, 2.4, 2.6, 2.8, 3.0, 3.2, 3.4, 3.6, 3.8,
      4.0, 4.2, 4.4, 4.6, 4.8, 5.0, 5.2, 5.4, 5.6, 5.8,
      6.0, 6.2, 6.4, 6.6, 6.8, 7.0, 7.2, 7.4, 7.6, 7.8,
      8.0, 8.2, 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8,
      10.0, 9.802962, 9.607885, 9.414769, 9.223614, 9.034419, 8.847184,
      8.661909, 8.478596, 8.297242, 8.117851, 7.940418, 7.764947, 7.5914364,
      7.419886, 7.2502966, 7.082668, 6.917, 6.7532916, 6.591545, 6.4317584,
      6.273932, 6.1180663, 5.964162, 5.812217, 5.662234, 5.5142093, 5.368148,
      5.2240453, 5.0819044, 4.941723, 4.803503, 4.6672425, 4.532944,
      4.4006047, 4.2702274, 4.1418095, 4.0153522, 3.8908558, 3.7683203,
      3.647745, 3.5291305, 3.4124763, 3.297783, 3.18505, 3.0742776, 2.965466,
      2.858614, 2.753724, 2.6507936, 2.5498245, 2.4508152, 2.353767,
      2.2586792, 2.1655521, 2.0743854, 1.9851794, 1.8979341, 1.8126491,
      1.7293249, 1.6479613, 1.568558, 1.4911155, 1.4156334, 1.342112,
      1.2705511, 1.2009507, 1.133311, 1.0676318, 1.0039133, 0.94215524,
      0.8823574, 0.8245205, 0.7686443, 0.71472853, 0.66277343, 0.61277884,
      0.56474483, 0.5186714, 0.4745585, 0.43240622, 0.39221448, 0.35398334,
      0.31771275, 0.28340274, 0.2510533, 0.22066444, 0.19223614, 0.16576843,
      0.14126128, 0.11871469, 0.098128565, 0.079503134, 0.06283828, 0.048134,
      0.035390284, 0.02460714, 0.01578457, 0.00892257, 0.004021142,
  ]
  hps = config_dict.ConfigDict(
      dict(
          lr_hparams={
              'schedule': 'mlperf_polynomial',
              'base_lr': 10.0,
              'warmup_steps': 50,
              'decay_end': -1,
              'end_lr': 1e-4,
              'power': 2.0,
              'start_lr': 0.0,
              'warmup_power': 1.0,
          }))
  max_training_steps = 50
  lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)
  for step in range(max_training_steps):
    self.assertAlmostEqual(lr_fn(step), expected_lrs[step])
import itertools
from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import mlperf_input_pipeline
import jax
from ml_collections.config_dict import config_dict
import tensorflow.compat.v2 as tf

NUM_CLASSES = 1000

DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    input_shape=(224, 224, 3),
    output_shape=(NUM_CLASSES,),
    train_size=1281167,
    valid_size=50000))

METADATA = {
    'apply_one_hot_in_loss': False,
}


def transpose_and_normalize_image(image):
  """Normalize image with the MLPerf per-channel RGB mean and stddev.

  NOTE(review): despite the name, no transpose is performed in this body —
  confirm whether the transpose happens at a call site.
  """
  mean = tf.constant([[mlperf_input_pipeline.MEAN_RGB]], dtype=image.dtype)
  stddev = tf.constant(
      [[mlperf_input_pipeline.STDDEV_RGB]], dtype=image.dtype)
  image -= mean
  image /= stddev
  return image
    loss_fn=None,
    flax_module=None,
    params=None,
    hps=None,
    input_shape=None,
    output_shape=None,
    rng_key=None,
    metrics_logger=None,
):
  """No-op init."""
  # Returns the incoming params unchanged; used when no custom
  # re-initialization is wanted.
  return params
# pylint: enable=unused-argument

DEFAULT_HPARAMS = config_dict.ConfigDict()

# Maps initializer name -> (init_fn, default hparams ConfigDict).
_ALL_INITIALIZERS = {
    'noop': (noop, DEFAULT_HPARAMS),
    'meta_init': (meta_init.meta_init, meta_init.DEFAULT_HPARAMS),
    'sparse_init': (sparse_init.sparse_init, sparse_init.DEFAULT_HPARAMS),
}


def get_initializer(initializer_name):
  """Get the corresponding initializer function based on the initializer string.

  API of an initializer:
    init_fn, hparams = get_initializer(init)
    new_params, final_l = init_fn(loss, init_params, hps, num_outputs,
                                  input_shape)
from ml_collections.config_dict import config_dict # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict(dict( hid_sizes=[20, 10], kernel_scales=[1.0, 1.0, 1.0], lr_hparams={ 'initial_value': 0.1, 'schedule': 'constant' }, layer_rescale_factors={}, optimizer='momentum', opt_hparams={ 'momentum': 0.9, }, batch_size=128, activation_function='relu', l2_decay_factor=.0005, l2_decay_rank_threshold=2, label_smoothing=None, rng_seed=-1, use_shallue_label_smoothing=False, model_dtype='float32', )) class FullyConnected(nn.Module): """Defines a fully connected neural network.
from init2winit.model_lib import model_utils
import jax.numpy as jnp
from ml_collections.config_dict import config_dict

# Default hparams (the num_layers choices match the VGG family sizes).
DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        num_layers=11,  # Must be one of [11, 13, 16, 19]
        layer_rescale_factors={},
        lr_hparams={
            'schedule': 'constant',
            'initial_value': 0.2,
        },
        normalizer='none',
        optimizer='momentum',
        opt_hparams={
            'momentum': 0.9,
        },
        batch_size=128,
        l2_decay_factor=0.0001,
        l2_decay_rank_threshold=2,
        label_smoothing=None,
        rng_seed=-1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
    ))


def classifier(x, num_outputs, dropout_rate, deterministic):
  """Implements the classification portion of the network."""
def test_early_stopping(self):
  """Test training early stopping on MNIST with a small model."""
  rng = jax.random.PRNGKey(0)

  # Set the numpy seed to make the fake data deterministic. mocking.mock_data
  # ultimately calls numpy.random.
  np.random.seed(0)

  model_name = 'fully_connected'
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  initializer_name = 'noop'
  dataset_name = 'mnist'
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  # Shrink the dataset and model so the test trains quickly.
  hparam_overrides = {
      'lr_hparams': {
          'base_lr': 0.1,
          'schedule': 'cosine'
      },
      'batch_size': 8,
      'train_size': 160,
      'valid_size': 96,
      'test_size': 80,
  }
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=-1,
          num_device_prefetches=0,
          num_tf_data_map_parallel_calls=-1,
      ))
  hps = hyperparameters.build_hparams(
      model_name,
      initializer_name,
      dataset_name,
      hparam_file=None,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)

  eval_batch_size = 16
  num_examples = 256

  def as_dataset(self, *args, **kwargs):
    """Replacement for tfds' as_dataset yielding constant fake examples."""
    del args
    del kwargs

    # pylint: disable=g-long-lambda,g-complex-comprehension
    return tf.data.Dataset.from_generator(
        lambda: ({
            'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
            'label': 9,
        } for i in range(num_examples)),
        output_types=self.info.features.dtype,
        output_shapes=self.info.features.shape,
    )

  # This will override the tfds.load(mnist) call to return 100 fake samples.
  with tfds.testing.mock_data(
      as_dataset_fn=as_dataset, num_examples=num_examples):
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)

  model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                    loss_name, metrics_name)

  num_train_steps = 40
  # Stop as soon as test cross-entropy drops below the target value.
  early_stopping_target_name = 'test/ce_loss'
  early_stopping_target_value = 0.005
  early_stopping_mode = 'less'
  eval_num_batches = 5
  eval_every = 10
  checkpoint_steps = [1, 3, 15]
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
  epoch_reports = list(
      trainer.train(
          train_dir=self.test_dir,
          model=model,
          dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
          initializer=initializer,
          num_train_steps=num_train_steps,
          hps=hps,
          rng=rng,
          eval_batch_size=eval_batch_size,
          eval_num_batches=eval_num_batches,
          eval_train_num_batches=eval_num_batches,
          eval_frequency=eval_every,
          checkpoint_steps=checkpoint_steps,
          early_stopping_target_name=early_stopping_target_name,
          early_stopping_target_value=early_stopping_target_value,
          early_stopping_mode=early_stopping_mode,
          metrics_logger=metrics_logger,
          init_logger=init_logger))
  # Training must have stopped after the third eval: the second-to-last
  # report is still above the target, the last one is below it.
  self.assertLen(epoch_reports, 3)
  self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                     early_stopping_target_value)
  self.assertLess(epoch_reports[-1][early_stopping_target_name],
                  early_stopping_target_value)
# small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict(dict( num_filters=[64, 96, 128], kernel_sizes=[5, 3, 3], kernel_paddings=['VALID', 'VALID', 'SAME'], window_sizes=[3, 3, 3], window_paddings=['SAME', 'SAME', 'SAME'], strides=[2, 2, 2], num_dense_units=[512, 256], lr_hparams={ 'initial_value': 0.001, 'schedule': 'constant' }, layer_rescale_factors={}, optimizer='momentum', opt_hparams={ 'momentum': 0.9, }, batch_size=128, activation_fn='relu', normalizer='none', l2_decay_factor=.0005, l2_decay_rank_threshold=2, label_smoothing=None, rng_seed=-1, use_shallue_label_smoothing=False, model_dtype='float32', ))
def build_hparams(model_name,
                  initializer_name,
                  dataset_name,
                  hparam_file,
                  hparam_overrides,
                  input_pipeline_hps=None):
  """Build experiment hyperparameters.

  Args:
    model_name: the string model name.
    initializer_name: the string initializer name.
    dataset_name: the string dataset name.
    hparam_file: the string to the hyperparameter override file (possibly on
      CNS).
    hparam_overrides: a dict of hyperparameter override names/values, or a
      JSON string encoding of this hyperparameter override dict. Note that
      this is applied after the hyperparameter file overrides.
    input_pipeline_hps: an optional ConfigDict of input pipeline
      hyperparameters, merged alongside the model/initializer/dataset hparams.
      Defaults to None so existing callers are unaffected.

  Returns:
    A ConfigDict of experiment hyperparameters.

  Raises:
    ValueError: if any two hparam sources define the same key.
  """
  model_hps = models.get_model_hparams(model_name)
  initializer_hps = initializers.get_initializer_hparams(initializer_name)
  dataset_hps = datasets.get_dataset_hparams(dataset_name)

  merged_dict = {}

  hps_dicts = [
      hps.to_dict() for hps in [model_hps, initializer_hps, dataset_hps]
  ]
  # Include input pipeline hparams when provided (e.g. from _run, which passes
  # input_pipeline_hps=...); previously this keyword was not accepted and the
  # call raised a TypeError.
  if input_pipeline_hps is not None:
    hps_dicts.append(input_pipeline_hps.to_dict())

  total_hps = 0
  for hps_dict in hps_dicts:
    merged_dict.update(hps_dict)
    total_hps += len(hps_dict.keys())

  # Check that the provided hparam sources have no overlapping keys.
  if total_hps != len(merged_dict.keys()):
    raise ValueError('There is overlap in the provided hparams.')

  # Convert to the Shallue and Lee label smoothing style.
  if merged_dict.get('use_shallue_label_smoothing', False):
    num_classes = merged_dict['output_shape'][-1]
    merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1)

  merged = config_dict.ConfigDict(merged_dict)
  merged.lock()

  # Subconfigs "opt_hparams" and "lr_hparams" are allowed to add new fields.
  for key in ['opt_hparams', 'lr_hparams']:
    if key not in merged:
      with merged.unlocked():
        merged[key] = config_dict.ConfigDict()

  for key in ['opt_hparams', 'lr_hparams']:
    merged[key].unlock()

  if hparam_file:
    logging.info('Loading hparams from %s', hparam_file)
    with gfile.GFile(hparam_file, 'r') as f:
      hparam_dict = json.load(f)
      merged.update_from_flattened_dict(hparam_dict)

  if hparam_overrides:
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)

    # If the user is changing the learning rate schedule or optimizer, we must
    # wipe all of the keys from the old dictionary.
    if 'lr_hparams.schedule' in hparam_overrides and merged[
        'lr_hparams']['schedule'] != hparam_overrides[
            'lr_hparams.schedule']:
      merged['lr_hparams'] = {}
    if 'optimizer' in hparam_overrides and merged[
        'optimizer'] != hparam_overrides['optimizer']:
      merged['opt_hparams'] = {}
    merged.update_from_flattened_dict(hparam_overrides)
  return merged
# Default hparams for the transformer model used in unit tests.
DEFAULT_HPARAMS = config_dict.ConfigDict({
    'batch_size': 512,
    # Transformer architecture.
    'emb_dim': 128,
    'num_heads': 8,
    'num_layers': 6,
    'qkv_dim': 128,
    'mlp_dim': 512,
    'dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'optimizer': 'adam',
    'opt_hparams': {
        'beta1': 0.9,
        'beta2': 0.98,
        'epsilon': 1e-9,
        'weight_decay': 1e-1,
    },
    'layer_rescale_factors': {},
    'normalizer': 'layer_norm',
    'lr_hparams': {
        'initial_value': 0.05,
        'warmup_steps': 8000,
        'factors': 'constant * linear_warmup * rsqrt_decay',
        'schedule': 'compound',
    },
    'label_smoothing': None,
    'l2_decay_factor': None,
    'l2_decay_rank_threshold': 0,
    'rng_seed': -1,
    'use_shallue_label_smoothing': False,
    'model_dtype': 'float32',
})
from ml_collections.config_dict import config_dict DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_filters=16, num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] layer_rescale_factors={}, lr_hparams={ 'schedule': 'constant', 'initial_value': 0.2, }, optimizer='momentum', opt_hparams={ 'momentum': 0.9, }, batch_size=128, l2_decay_factor=0.0001, l2_decay_rank_threshold=2, label_smoothing=None, rng_seed=-1, use_shallue_label_smoothing=False, batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, # Make this a string to avoid having to import jnp into the configs. model_dtype='float32', virtual_batch_size=None, data_format='NHWC', )) class BasicResidualBlock(nn.Module):
def _run(train_fn, dataset_name, eval_batch_size, eval_num_batches,
         eval_train_num_batches, eval_frequency, checkpoint_steps,
         num_tf_data_prefetches, num_device_prefetches,
         num_tf_data_map_parallel_calls, early_stopping_target_name,
         early_stopping_target_value, early_stopping_mode, eval_steps,
         hparam_file, hparam_overrides, initializer_name, model_name,
         loss_name, metrics_name, num_train_steps, experiment_dir, worker_id,
         training_metrics_config, callback_configs, external_checkpoint_path):
  """Function that runs a Jax experiment. See flag definitions for args.

  Resolves the model/initializer/dataset by name, merges all hparam sources,
  seeds the RNG, sets up per-trial logging and metadata on host 0, then runs
  `train_fn` to completion, recording the trial status ('done', 'diverged',
  or 'incomplete') in meta_data.json.
  """
  # Resolve the named project components into concrete objects.
  model_cls = models.get_model(model_name)
  initializer = initializers.get_initializer(initializer_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
  # Bundle the input-pipeline flags into a ConfigDict so they can be merged
  # with the other hparam sources by build_hparams.
  input_pipeline_hps = config_dict.ConfigDict(
      dict(
          num_tf_data_prefetches=num_tf_data_prefetches,
          num_device_prefetches=num_device_prefetches,
          num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls,
      ))
  merged_hps = hyperparameters.build_hparams(
      model_name=model_name,
      initializer_name=initializer_name,
      dataset_name=dataset_name,
      hparam_file=hparam_file,
      hparam_overrides=hparam_overrides,
      input_pipeline_hps=input_pipeline_hps)
  # Note that one should never tune an RNG seed!!! The seed is only included in
  # the hparams for convenience of running hparam trials with multiple seeds per
  # point.
  rng_seed = merged_hps.rng_seed
  # A negative seed means "pick one for me": derive a seed shared across hosts.
  if merged_hps.rng_seed < 0:
    rng_seed = _create_synchronized_rng_seed()
  # Placeholders; xm_work_unit is forwarded to the loggers below on host 0.
  xm_experiment = None
  xm_work_unit = None
  if jax.process_index() == 0:
    logging.info('Running with seed %d', rng_seed)
  rng = jax.random.PRNGKey(rng_seed)

  # Build the loss_fn, metrics_bundle, and flax_module.
  model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
  # Each worker gets its own subdirectory of the experiment dir.
  trial_dir = os.path.join(experiment_dir, str(worker_id))
  meta_data_path = os.path.join(trial_dir, 'meta_data.json')
  # Written as 'incomplete' first so a crash mid-run is distinguishable.
  meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
  if jax.process_index() == 0:
    logging.info('rng: %s', rng)
    gfile.makedirs(trial_dir)
    # Set up the metric loggers for host 0.
    metrics_logger, init_logger = utils.set_up_loggers(
        trial_dir, xm_work_unit)
    # Persist the fully merged hparams next to the trial for reproducibility.
    hparams_fname = os.path.join(trial_dir, 'hparams.json')
    logging.info('saving hparams to %s', hparams_fname)
    with gfile.GFile(hparams_fname, 'w') as f:
      f.write(merged_hps.to_json())
    _write_trial_meta_data(meta_data_path, meta_data)
  else:
    # Non-zero hosts do not log; train_fn receives None loggers.
    metrics_logger = None
    init_logger = None
  try:
    # train_fn yields per-epoch reports; materialize them to run training.
    epoch_reports = list(
        train_fn(trial_dir,
                 model,
                 dataset_builder,
                 initializer,
                 num_train_steps,
                 merged_hps,
                 rng,
                 eval_batch_size,
                 eval_num_batches,
                 eval_train_num_batches,
                 eval_frequency,
                 checkpoint_steps,
                 early_stopping_target_name,
                 early_stopping_target_value,
                 early_stopping_mode,
                 eval_steps,
                 metrics_logger,
                 init_logger,
                 training_metrics_config=training_metrics_config,
                 callback_configs=callback_configs,
                 external_checkpoint_path=external_checkpoint_path))
    logging.info(epoch_reports)
    meta_data['status'] = 'done'
  except utils.TrainingDivergedError as err:
    # Record divergence in the trial metadata, then propagate the error.
    meta_data['status'] = 'diverged'
    raise err
  finally:
    # Always persist the final status (done/diverged/incomplete) from host 0.
    if jax.process_index() == 0:
      _write_trial_meta_data(meta_data_path, meta_data)
def test_text_model_trainer(self):
  """Test training of a small transformer model on fake data."""
  rng = jax.random.PRNGKey(42)
  # mocking.mock_data ultimately draws from numpy.random, so fixing the
  # numpy seed makes the fake data deterministic.
  np.random.seed(0)

  model_cls = models.get_model('transformer')
  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  hps = config_dict.ConfigDict({
      # Architecture Hparams.
      'batch_size': _TEXT_BATCH_SIZE,
      'emb_dim': 32,
      'num_heads': 2,
      'num_layers': 3,
      'qkv_dim': 32,
      'mlp_dim': 64,
      'max_target_length': 64,
      'max_eval_target_length': 64,
      'input_shape': (64,),
      'output_shape': (_VOCAB_SIZE,),
      'dropout_rate': 0.1,
      'attention_dropout_rate': 0.1,
      'layer_rescale_factors': {},
      'optimizer': 'momentum',
      'normalizer': 'layer_norm',
      'opt_hparams': {
          'momentum': 0.9,
      },
      'lr_hparams': {
          'base_lr': 0.005,
          'schedule': 'constant'
      },
      # Training HParams.
      'l2_decay_factor': 1e-4,
      'l2_decay_rank_threshold': 2,
      'train_size': _TEXT_TRAIN_SIZE,
      'gradient_clipping': 0.0,
      'model_dtype': 'float32',
      'decode': False,
      'num_device_prefetches': 0,
  })
  initializer = initializers.get_initializer('noop')
  eval_num_batches = 5
  dataset, dataset_meta_data = _get_fake_text_dataset(
      batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
  eval_batch_size = hps.batch_size
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
  eval_every = 10
  checkpoint_steps = []
  metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)

  def run_training(steps):
    # Materialize the generator so training actually runs to `steps`.
    _ = list(
        trainer.train(
            train_dir=self.test_dir,
            model=model,
            dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
            initializer=initializer,
            num_train_steps=steps,
            hps=hps,
            rng=rng,
            eval_batch_size=eval_batch_size,
            eval_num_batches=eval_num_batches,
            eval_train_num_batches=eval_num_batches,
            eval_frequency=eval_every,
            checkpoint_steps=checkpoint_steps,
            metrics_logger=metrics_logger,
            init_logger=init_logger))

  def read_measurements():
    # Load the measurements CSV written by the trainer.
    with tf.io.gfile.GFile(
        os.path.join(self.test_dir, 'measurements.csv')) as f:
      return pandas.read_csv(f)

  num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3
  run_training(num_train_steps)
  df = read_measurements()
  train_err = df['train/error_rate'].values[-1]
  # Note that upgrading to Linen made this fail at 0.6.
  self.assertLess(train_err, 0.7)
  self.assertEqual(set(df.columns.values), set(get_column_names()))
  prev_train_err = train_err

  # Test reload from the checkpoint by increasing num_train_steps.
  num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
  run_training(num_train_steps_reload)
  df = read_measurements()
  train_err = df['train/error_rate'].values[-1]
  train_loss = df['train/ce_loss'].values[-1]
  # Note that upgrading to Linen made this fail at 0.45.
  self.assertLess(train_err, 0.67)
  self.assertLess(train_err, prev_train_err)
  # Note that upgrading to Linen made this fail at 0.9.
  self.assertLess(train_loss, 1.35)
  self.assertEqual(df['valid/num_examples'].values[-1],
                   eval_num_batches * eval_batch_size * _MAX_LEN)
  # Check that the correct learning rate was saved in the measurements file.
  final_step = df['global_step'].values[-1]
  self.assertEqual(num_train_steps_reload, final_step)
  self.assertEqual(set(df.columns.values), set(get_column_names()))