Example #1
    def test_initializers(self, init):
        """Test that each initializer runs, and the output is a valid pytree."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, model_hps = _load_model(
            'fully_connected')
        _, init_rng = jax.random.split(rng)
        initializer = initializers.get_initializer(init)
        init_hps = initializers.get_initializer_hparams(init)
        init_hps.update(model_hps)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        new_params = initializer(loss_fn=loss_fn,
                                 flax_module=flax_module,
                                 params=params,
                                 hps=init_hps,
                                 input_shape=input_shape[1:],
                                 output_shape=OUTPUT_SHAPE,
                                 rng_key=init_rng)

        # Check that the new params are still a valid params pytree.
        outputs = flax_module.apply({'params': new_params},
                                    jnp.ones(input_shape),
                                    train=True)
        utils.log_pytree_shape_and_statistics(new_params)
        self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1]))
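For reference, the initializer interface exercised above takes the loss function, the Flax module, the freshly drawn params, the hparams, shape information, and an rng key, and returns a new params pytree. Below is a minimal sketch of a conforming 'noop'-style initializer; the keyword names come from the call above, everything else is an assumption:

def noop_initializer(loss_fn, flax_module, params, hps, input_shape,
                     output_shape, rng_key):
  # A no-op initializer returns the params pytree unchanged; a real
  # initializer (e.g. meta-init) would rescale or re-draw the weights
  # using loss_fn, the shapes, and rng_key before returning them.
  del loss_fn, flax_module, hps, input_shape, output_shape, rng_key
  return params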
Example #2
  def test_initialize_rescale(self):
    """Test rescaling a single layer of a model."""
    input_shape = (28, 28, 1)
    output_shape = (10,)
    model_str = 'fully_connected'
    model_cls = models.get_model(model_str)
    model_hps = models.get_model_hparams(model_str)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    hps = copy.copy(model_hps)
    hps.update({'output_shape': output_shape})
    rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss_name, metrics_name)
    initializer = initializers.get_initializer('noop')

    rng, init_rng = jax.random.split(rng)

    # First initialize with no rescale.
    flax_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    utils.log_pytree_shape_and_statistics(flax_module.params)
    # Now rescale a layer by 100.
    rescale_factor = 100
    hps.layer_rescale_factors = {
        '/Dense_1/kernel': rescale_factor,
    }

    rescaled_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    # Check the right variable is rescaled
    v1 = flax_module.params['Dense_1']['kernel']
    v2 = rescaled_module.params['Dense_1']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) * rescale_factor - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)

    # Check that other variables are the same
    v1 = flax_module.params['Dense_2']['kernel']
    v2 = rescaled_module.params['Dense_2']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)
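The assertions above rely on trainer.initialize applying hps.layer_rescale_factors on top of the base initialization. A plausible sketch of that rescaling step, assuming plain nested-dict params and the '/'-joined path convention used by the test (the helper name is hypothetical):

from flax import traverse_util

def apply_layer_rescale_factors(params, layer_rescale_factors):
  # Flatten params to {('Dense_1', 'kernel'): array, ...}, scale every
  # leaf whose '/'-joined path has a rescale factor, then rebuild the
  # nested structure.
  flat = traverse_util.flatten_dict(params)
  for path, factor in layer_rescale_factors.items():
    key = tuple(path.strip('/').split('/'))
    flat[key] = flat[key] * factor
  return traverse_util.unflatten_dict(flat)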
Example #3
File: main.py Project: google/init2winit
def _run(train_fn, dataset_name, eval_batch_size, eval_num_batches,
         eval_train_num_batches, eval_frequency, checkpoint_steps,
         num_tf_data_prefetches, num_device_prefetches,
         num_tf_data_map_parallel_calls, early_stopping_target_name,
         early_stopping_target_value, early_stopping_mode, eval_steps,
         hparam_file, hparam_overrides, initializer_name, model_name,
         loss_name, metrics_name, num_train_steps, experiment_dir, worker_id,
         training_metrics_config, callback_configs, external_checkpoint_path):
    """Function that runs a Jax experiment. See flag definitions for args."""
    model_cls = models.get_model(model_name)
    initializer = initializers.get_initializer(initializer_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
    input_pipeline_hps = config_dict.ConfigDict(
        dict(
            num_tf_data_prefetches=num_tf_data_prefetches,
            num_device_prefetches=num_device_prefetches,
            num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls,
        ))

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=initializer_name,
        dataset_name=dataset_name,
        hparam_file=hparam_file,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)

    # Note that one should never tune an RNG seed! The seed is only included
    # in the hparams for the convenience of running hparam trials with
    # multiple seeds per point.
    rng_seed = merged_hps.rng_seed
    if merged_hps.rng_seed < 0:
        rng_seed = _create_synchronized_rng_seed()
    xm_experiment = None
    xm_work_unit = None
    if jax.process_index() == 0:
        logging.info('Running with seed %d', rng_seed)
    rng = jax.random.PRNGKey(rng_seed)

    # Build the loss_fn, metrics_bundle, and flax_module.
    model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
    trial_dir = os.path.join(experiment_dir, str(worker_id))
    meta_data_path = os.path.join(trial_dir, 'meta_data.json')
    meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
    if jax.process_index() == 0:
        logging.info('rng: %s', rng)
        gfile.makedirs(trial_dir)
        # Set up the metric loggers for host 0.
        metrics_logger, init_logger = utils.set_up_loggers(
            trial_dir, xm_work_unit)
        hparams_fname = os.path.join(trial_dir, 'hparams.json')
        logging.info('saving hparams to %s', hparams_fname)
        with gfile.GFile(hparams_fname, 'w') as f:
            f.write(merged_hps.to_json())
        _write_trial_meta_data(meta_data_path, meta_data)
    else:
        metrics_logger = None
        init_logger = None
    try:
        epoch_reports = list(
            train_fn(trial_dir,
                     model,
                     dataset_builder,
                     initializer,
                     num_train_steps,
                     merged_hps,
                     rng,
                     eval_batch_size,
                     eval_num_batches,
                     eval_train_num_batches,
                     eval_frequency,
                     checkpoint_steps,
                     early_stopping_target_name,
                     early_stopping_target_value,
                     early_stopping_mode,
                     eval_steps,
                     metrics_logger,
                     init_logger,
                     training_metrics_config=training_metrics_config,
                     callback_configs=callback_configs,
                     external_checkpoint_path=external_checkpoint_path))
        logging.info(epoch_reports)
        meta_data['status'] = 'done'
    except utils.TrainingDivergedError as err:
        meta_data['status'] = 'diverged'
        raise err
    finally:
        if jax.process_index() == 0:
            _write_trial_meta_data(meta_data_path, meta_data)
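_write_trial_meta_data is called above but not shown. A minimal sketch consistent with its call sites (a path plus a small status dict), assuming it simply serializes the dict as JSON via gfile:

import json

from tensorflow.io import gfile

def _write_trial_meta_data(meta_data_path, meta_data):
  # Persist the trial status (e.g. 'incomplete', 'done', 'diverged') so a
  # sweep controller can inspect the outcome of each worker.
  with gfile.GFile(meta_data_path, 'w') as f:
    f.write(json.dumps(meta_data))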
Example #4
    def test_early_stopping(self):
        """Test training early stopping on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic.
        # tfds.testing.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return
        # num_examples fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        early_stopping_target_name = 'test/ce_loss'
        early_stopping_target_value = 0.005
        early_stopping_mode = 'less'
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                early_stopping_target_name=early_stopping_target_name,
                early_stopping_target_value=early_stopping_target_value,
                early_stopping_mode=early_stopping_mode,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports, 3)
        self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                           early_stopping_target_value)
        self.assertLess(epoch_reports[-1][early_stopping_target_name],
                        early_stopping_target_value)
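The early-stopping check itself reduces to comparing the monitored metric against the target with the given mode. Here is a standalone sketch of that predicate, consistent with the 'less' mode and the assertions above (the function name is assumed):

def should_stop_early(report, target_name, target_value, mode):
  # mode='less' stops once the metric drops below target_value;
  # mode='greater' stops once it rises above it.
  value = report[target_name]
  if mode == 'less':
    return value < target_value
  return value > target_value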
Example #5
    def test_trainer(self):
        """Test training for two epochs on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic.
        # tfds.testing.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return
        # num_examples fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        # check that the additional checkpoints are saved.
        checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
        saved_steps = []
        for f in tf.io.gfile.listdir(checkpoint_dir):
            if f[:5] == 'ckpt_':
                saved_steps.append(int(f[5:]))

        self.assertEqual(set(saved_steps), set(checkpoint_steps))

        self.assertLen(epoch_reports, num_train_steps / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            self.assertEqual(df['preemption_count'].values[-1], 0)
            self.assertLess(train_err, 0.9)

        self.assertEqual(set(df.columns.values), set(get_column_names()))

        model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                          metrics_name)

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = 100
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports,
                       (num_train_steps_reload - num_train_steps) / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            self.assertLess(train_err, 0.35)
            self.assertLess(train_loss, 0.1)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size)
            self.assertEqual(df['preemption_count'].values[-1], 1)
            # Check that the correct learning rate was saved in the measurements file.
            final_learning_rate = df['learning_rate'].values[-1]
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

            # final_step will be one larger than the last step used to calculate
            # the lr decay, hence we plug (final_step - 1) into the decay formula.
            # Note that there is a small numerical difference here between np and
            # jnp.
            decay_factor = (1 + np.cos(
                (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
            self.assertEqual(float(final_learning_rate),
                             hps.lr_hparams['base_lr'] * decay_factor)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
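The final assertions encode the cosine schedule in closed form. For clarity, the same schedule as a standalone function matching the decay_factor computation the test checks (the function name is assumed):

import numpy as np

def cosine_lr(step, base_lr, num_train_steps):
  # base_lr * 0.5 * (1 + cos(pi * step / num_train_steps)): equals
  # base_lr at step 0 and decays smoothly to 0 at num_train_steps.
  return base_lr * 0.5 * (1 + np.cos(step / num_train_steps * np.pi))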
Example #6
    def test_text_model_trainer(self):
        """Test training of a small transformer model on fake data."""
        rng = jax.random.PRNGKey(42)

        # Set the numpy seed to make the fake data deterministic.
        # tfds.testing.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_cls = models.get_model('transformer')
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': _TEXT_BATCH_SIZE,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'input_shape': (64, ),
            'output_shape': (_VOCAB_SIZE, ),
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'layer_rescale_factors': {},
            'optimizer': 'momentum',
            'normalizer': 'layer_norm',
            'opt_hparams': {
                'momentum': 0.9,
            },
            'lr_hparams': {
                'base_lr': 0.005,
                'schedule': 'constant'
            },
            # Training HParams.
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'train_size': _TEXT_TRAIN_SIZE,
            'gradient_clipping': 0.0,
            'model_dtype': 'float32',
            'decode': False,
            'num_device_prefetches': 0,
        })
        initializer = initializers.get_initializer('noop')
        eval_num_batches = 5
        dataset, dataset_meta_data = _get_fake_text_dataset(
            batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
        eval_batch_size = hps.batch_size

        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

        eval_every = 10
        checkpoint_steps = []
        num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            # Note that upgrading to Linen made this fail at 0.6.
            self.assertLess(train_err, 0.7)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
        prev_train_err = train_err

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            # Note that upgrading to Linen made this fail at 0.45.
            self.assertLess(train_err, 0.67)
            self.assertLess(train_err, prev_train_err)
            # Note that upgrading to Linen made this fail at 0.9.
            self.assertLess(train_loss, 1.35)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size * _MAX_LEN)
            # Check that the correct learning rate was saved in the measurements file.
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
Example #7
    def test_dlrm_model_trainer(self):
        """Tests that dlrm model training decreases loss."""
        rng = jax.random.PRNGKey(1337)
        model_str = 'dlrm'
        dataset_str = 'criteo1tb'
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)
        dataset_hps = datasets.get_dataset_hparams(dataset_str)
        dataset_hps.update({
            'batch_size': model_hps.batch_size,
            'num_dense_features': model_hps.num_dense_features,
            'vocab_sizes': model_hps.vocab_sizes,
        })
        eval_num_batches = 5
        eval_batch_size = dataset_hps.batch_size
        loss_name = 'sigmoid_binary_cross_entropy'
        metrics_name = 'binary_classification_metrics'
        dataset, dataset_meta_data = _get_fake_dlrm_dataset(
            dataset_hps.batch_size, eval_num_batches, dataset_hps)
        hps = copy.copy(model_hps)
        hps.update({
            'train_size': 15,
            'valid_size': 10,
            'test_size': 10,
            'input_shape':
                (model_hps.num_dense_features + len(model_hps.vocab_sizes),),
            'output_shape': (1,),
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'num_device_prefetches': 0,
        })
        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
        initializer = initializers.get_initializer('noop')

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=10,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=2,
                checkpoint_steps=[],
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_loss = df['train/ce_loss'].values
            self.assertLess(train_loss[-1], train_loss[0])
Example #8
    def test_graph_model_trainer(self):
        """Tests that graph model training decreases loss."""
        rng = jax.random.PRNGKey(1337)
        model_str = 'gnn'
        model_cls = models.get_model(model_str)
        hps = models.get_model_hparams(model_str)
        hps.update({
            'batch_size': 2,
            'input_edge_shape': (7, ),
            'input_node_shape': (3, ),
            'input_shape': (7, 3),
            'output_shape': (5, ),
            'model_dtype': 'float32',
            'train_size': 15,
            'valid_size': 10,
            'test_size': 10,
            'num_message_passing_steps': 1,
            'normalizer': 'none',
            'dropout_rate': 0.0,
            'lr_hparams': {
                'base_lr': 0.001,
                'schedule': 'constant'
            },
            'num_device_prefetches': 0,
        })
        eval_num_batches = 5
        eval_batch_size = hps.batch_size
        loss_name = 'sigmoid_binary_cross_entropy'
        metrics_name = 'binary_classification_metrics_ogbg_map'
        dataset, dataset_meta_data = _get_fake_graph_dataset(
            batch_size=hps.batch_size,
            eval_num_batches=eval_num_batches,
            hps=hps)
        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
        initializer = initializers.get_initializer('noop')

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=10,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                # Note that, for some reason, moving from the deprecated Flax
                # model API to Linen made training less stable, so we need to
                # eval more frequently in order to get a `train_loss[0]` that
                # is earlier in training.
                eval_frequency=2,
                checkpoint_steps=[],
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_loss = df['train/ce_loss'].values
            self.assertLess(train_loss[-1], train_loss[0])
Example #9
def eval_checkpoints(
    checkpoint_dir,
    hps,
    rng,
    eval_num_batches,
    model_cls,
    dataset_builder,
    dataset_meta_data,
    hessian_eval_config,
    min_global_step=None,
    max_global_step=None,
):
  """Evaluate the Hessian of the given checkpoints.

  Iterates over all checkpoints in the specified directory, loading each
  checkpoint and evaluating the Hessian on it. A list of dicts will be saved
  to CNS at checkpoint_dir/hessian_eval_config['name'].

  Args:
    checkpoint_dir: Directory of checkpoints to load.
    hps: (tf.HParams) Model, initialization and training hparams.
    rng: (jax.random.PRNGKey) Rng seed used in model initialization and data
      shuffling.
    eval_num_batches: (int) The number of batches used when evaluating on the
      validation and test sets. Set to None to evaluate on the whole set.
    model_cls: One of the model classes (not an instance) defined in model_lib.
    dataset_builder: dataset builder returned by datasets.get_dataset.
    dataset_meta_data: dict of meta_data about the dataset.
    hessian_eval_config: a dict specifying the configuration of the Hessian
      eval.
    min_global_step: Lower bound on the steps of checkpoints to evaluate. Set
      to None to evaluate all checkpoints in the directory.
    max_global_step: Upper bound on the steps of checkpoints to evaluate.
  """
  rng, init_rng = jax.random.split(rng)
  rng = jax.random.fold_in(rng, jax.process_index())
  rng, data_rng = jax.random.split(rng)

  initializer = initializers.get_initializer('noop')

  loss_name = 'cross_entropy'
  metrics_name = 'classification_metrics'
  model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

  # Maybe run the initializer.
  unreplicated_params, unreplicated_batch_stats = init_utils.initialize(
      model.flax_module,
      initializer, model.loss_fn,
      hps.input_shape,
      hps.output_shape, hps, init_rng,
      None)

  # Fold the unreplicated batch_stats and the rng into the loss used by the
  # Hessian eval.
  def batch_loss(params, batch_rng):
    batch, rng = batch_rng
    return model.training_cost(
        params, batch, batch_stats=unreplicated_batch_stats, dropout_rng=rng)[0]
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)

  if jax.process_index() == 0:
    utils.log_pytree_shape_and_statistics(unreplicated_params)
    logging.info('train_size: %d,', hps.train_size)
    logging.info(hps)
    # Save the hessian computation hps to the experiment directory
    exp_dir = os.path.join(checkpoint_dir, hessian_eval_config['name'])
    if not gfile.exists(exp_dir):
      gfile.mkdir(exp_dir)
    if min_global_step == 0:
      hparams_fname = os.path.join(exp_dir, 'hparams.json')
      with gfile.GFile(hparams_fname, 'w') as f:
        f.write(hps.to_json())
      config_fname = os.path.join(exp_dir, 'hconfig.json')
      with gfile.GFile(config_fname, 'w') as f:
        f.write(json.dumps(hessian_eval_config))

  optimizer_init_fn, optimizer_update_fn = optimizers.get_optimizer(hps)
  unreplicated_optimizer_state = optimizer_init_fn(unreplicated_params)
  # Note that we do not use the learning rate.
  # The optimizer state is a list of all the optax transformation states, and
  # we inject the learning rate into all states that will accept it.
  for state in unreplicated_optimizer_state:
    if (isinstance(state, optax.InjectHyperparamsState) and
        'learning_rate' in state.hyperparams):
      state.hyperparams['learning_rate'] = jax_utils.replicate(1.0)
  optimizer_state = jax_utils.replicate(unreplicated_optimizer_state)
  params = jax_utils.replicate(unreplicated_params)
  data_rng = jax.random.fold_in(data_rng, 0)

  assert hps.batch_size % (jax.device_count()) == 0
  dataset = dataset_builder(
      data_rng,
      hps.batch_size,
      eval_batch_size=hps.batch_size,  # eval iterators not used.
      hps=hps,
  )

  # pmap functions for the training loop
  evaluate_batch_pmapped = jax.pmap(model.evaluate_batch, axis_name='batch')

  if jax.process_index() == 0:
    logging.info('Starting eval!')
    logging.info('Number of hosts: %d', jax.process_count())

  hessian_evaluator = hessian_eval.CurvatureEvaluator(
      params,
      hessian_eval_config,
      dataset=dataset,
      loss=batch_loss)
  if min_global_step is None:
    suffix = ''
  else:
    suffix = '{}_{}'.format(min_global_step, max_global_step)
  pytree_path = os.path.join(checkpoint_dir, hessian_eval_config['name'],
                             suffix)
  logger = utils.MetricLogger(pytree_path=pytree_path)
  for checkpoint_path, step in iterate_checkpoints(checkpoint_dir,
                                                   min_global_step,
                                                   max_global_step):
    unreplicated_checkpoint_state = dict(
        params=unreplicated_params,
        optimizer_state=unreplicated_optimizer_state,
        batch_stats=unreplicated_batch_stats,
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    ckpt = checkpoint.load_checkpoint(
        checkpoint_path,
        target=unreplicated_checkpoint_state)
    results, _ = checkpoint.replicate_checkpoint(
        ckpt,
        pytree_keys=['params', 'optimizer_state', 'batch_stats'])
    params = results['params']
    optimizer_state = results['optimizer_state']
    batch_stats = results['batch_stats']
    # pylint: disable=protected-access
    batch_stats = trainer_utils.maybe_sync_batchnorm_stats(batch_stats)
    # pylint: enable=protected-access
    report, _ = trainer.eval_metrics(params, batch_stats, dataset,
                                     eval_num_batches, eval_num_batches,
                                     evaluate_batch_pmapped)
    if jax.process_index() == 0:
      logging.info('Global Step: %d', step)
      logging.info(report)
    row = {}
    grads, updates = [], []
    stats, hess_evecs, cov_evecs = hessian_evaluator.evaluate_spectrum(
        params, step)
    row.update(stats)
    if (hessian_eval_config['compute_stats'] or
        hessian_eval_config['compute_interps']):
      grads, updates = hessian_evaluator.compute_dirs(
          params, optimizer_state, optimizer_update_fn)
    row.update(hessian_evaluator.evaluate_stats(params, grads,
                                                updates, hess_evecs,
                                                cov_evecs, step))
    row.update(hessian_evaluator.compute_interpolations(params, grads,
                                                        updates, hess_evecs,
                                                        cov_evecs, step))
    if jax.process_index() == 0:
      logger.append_pytree(row)
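iterate_checkpoints is used above but not shown. A plausible sketch, assuming the 'ckpt_<step>' file naming seen in the trainer tests; the exact bound semantics (inclusive vs. exclusive) are an assumption of this sketch:

import os

from tensorflow.io import gfile

def iterate_checkpoints(checkpoint_dir, min_global_step, max_global_step):
  # Yield (path, step) for every 'ckpt_<step>' file, filtered to
  # min_global_step <= step <= max_global_step when bounds are given.
  for fname in sorted(gfile.listdir(checkpoint_dir)):
    if not fname.startswith('ckpt_'):
      continue
    step = int(fname[len('ckpt_'):])
    if min_global_step is not None and step < min_global_step:
      continue
    if max_global_step is not None and step > max_global_step:
      continue
    yield os.path.join(checkpoint_dir, fname), step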
Example #10
  def test_run_lanczos(self):
    """Test training for two epochs on MNIST with a small model."""
    rng = jax.random.PRNGKey(0)

    # Set the numpy seed to make the fake data deterministic.
    # tfds.testing.mock_data ultimately calls numpy.random.
    np.random.seed(0)

    model_name = 'fully_connected'
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    initializer_name = 'noop'
    dataset_name = 'mnist'
    model_cls = models.get_model(model_name)
    initializer = initializers.get_initializer(initializer_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    hparam_overrides = {
        'lr_hparams': {
            'base_lr': 0.1,
            'schedule': 'cosine'
        },
        'batch_size': 8,
        'train_size': 160,
        'valid_size': 96,
        'test_size': 80,
    }
    input_pipeline_hps = config_dict.ConfigDict(dict(
        num_tf_data_prefetches=-1,
        num_device_prefetches=0,
        num_tf_data_map_parallel_calls=-1,
    ))
    hps = hyperparameters.build_hparams(
        model_name,
        initializer_name,
        dataset_name,
        hparam_file=None,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)
    model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                      loss_name, metrics_name)

    eval_batch_size = 16
    num_examples = 256

    def as_dataset(self, *args, **kwargs):
      del args
      del kwargs

      # pylint: disable=g-long-lambda,g-complex-comprehension
      return tf.data.Dataset.from_generator(
          lambda: ({
              'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
              'label': 9,
          } for i in range(num_examples)),
          output_types=self.info.features.dtype,
          output_shapes=self.info.features.shape,
      )

    # This will override the tfds.load(mnist) call to return num_examples
    # fake samples.
    with tfds.testing.mock_data(
        as_dataset_fn=as_dataset, num_examples=num_examples):
      dataset = dataset_builder(
          shuffle_rng=jax.random.PRNGKey(0),
          batch_size=hps.batch_size,
          eval_batch_size=eval_batch_size,
          hps=hps)

    num_train_steps = 41
    eval_num_batches = 5
    eval_every = 10
    checkpoint_steps = [10, 30, 40]
    metrics_logger, init_logger = None, None
    _ = list(
        trainer.train(
            train_dir=self.test_dir,
            model=model,
            dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
            initializer=initializer,
            num_train_steps=num_train_steps,
            hps=hps,
            rng=rng,
            eval_batch_size=eval_batch_size,
            eval_num_batches=eval_num_batches,
            eval_train_num_batches=eval_num_batches,
            eval_frequency=eval_every,
            checkpoint_steps=checkpoint_steps,
            metrics_logger=metrics_logger,
            init_logger=init_logger))

    checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
    rng = jax.random.PRNGKey(0)

    run_lanczos.eval_checkpoints(
        checkpoint_dir,
        hps,
        rng,
        eval_num_batches,
        model_cls=model_cls,
        dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
        dataset_meta_data=datasets.get_dataset_meta_data(dataset_name),
        hessian_eval_config=hessian_eval.DEFAULT_EVAL_CONFIG,
    )

    # Load the saved file.
    hessian_dir = os.path.join(checkpoint_dir, 'hessian', 'training_metrics')
    pytree_list = checkpoint.load_pytree(hessian_dir)

    # Convert to a regular list (the checkpointer will have converted the
    # saved list to a dict with keys '0', '1', ...).
    pytree_list = [pytree_list[str(i)] for i in range(len(pytree_list))]
    # Test that the logged steps are correct.
    saved_steps = [row['step'] for row in pytree_list]
    self.assertEqual(saved_steps, checkpoint_steps)
Example #11
    def test_text_model(self):
        """Test gradient accumulator training of a small transformer."""
        rng = jax.random.PRNGKey(42)

        # Set the numpy seed to make the fake data deterministic.
        # tfds.testing.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_cls = models.get_model('transformer')
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        batch_size = 16
        train_size = 20 * batch_size
        hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': batch_size,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'input_shape': (64, ),
            'output_shape': (4, ),
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'layer_rescale_factors': {},
            'optimizer': 'momentum',
            'normalizer': 'layer_norm',
            'opt_hparams': {
                'momentum': 0.9,
            },
            'lr_hparams': {
                'base_lr': 0.005,
                'schedule': 'constant'
            },
            # Training HParams.
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'train_size': train_size,
            'gradient_clipping': 0.0,
            'model_dtype': 'float32',
            'decode': False,
        })
        initializer = initializers.get_initializer('noop')
        eval_num_batches = 5
        dataset, dataset_meta_data = _get_fake_text_dataset(
            batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
        eval_batch_size = hps.batch_size

        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

        eval_every = 10
        checkpoint_steps = []
        num_train_steps = train_size // batch_size * 3

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            # Note that upgrading to Linen made this fail at 0.6.
            self.assertLess(train_err, 0.7)
Example #12
  def test_shampoo_wrn(self):
    """Test distributed shampoo on fake dataset."""
    model_name = 'simple_cnn'
    model_cls = models.get_model(model_name)
    hparam_overrides = {
        'optimizer': 'distributed_shampoo',
        'batch_size': 1,
        'train_size': 10,
        'valid_size': 10,
        'input_shape': (32, 32, 3),
        'output_shape': (10,),
        'opt_hparams': {
            'block_size': 32,
            'beta1': 0.9,
            'beta2': 0.999,
            'diagonal_epsilon': 1e-10,
            'matrix_epsilon': 1e-6,
            'weight_decay': 0.0,
            'start_preconditioning_step': 5,
            'preconditioning_compute_steps': 1,
            'statistics_compute_steps': 1,
            'best_effort_shape_interpretation': True,
            'graft_type': distributed_shampoo.GraftingType.SGD,
            'nesterov': True,
            'exponent_override': 0,
            'batch_axis_name': 'batch',
            'num_devices_for_pjit': None,
            'shard_optimizer_states': False,
            'inverse_failure_threshold': 0.1,
            'clip_by_scaled_gradient_norm': None,
            'precision': lax.Precision.HIGHEST,
            'moving_average_for_momentum': False,
            'skip_preconditioning_dim_size_gt': 4096,
            'best_effort_memory_usage_reduction': False,
        },
    }
    input_pipeline_hps = config_dict.ConfigDict(dict(
        num_tf_data_prefetches=-1,
        num_device_prefetches=0,
        num_tf_data_map_parallel_calls=-1,
    ))
    hps = hyperparameters.build_hparams(
        model_name,
        initializer_name='noop',
        dataset_name='fake',
        hparam_file=None,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)
    initializer = initializers.get_initializer('noop')
    dataset_builder = datasets.get_dataset('fake')
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=hps.batch_size,
        hps=hps)

    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    dataset_meta_data = datasets.get_dataset_meta_data('fake')
    model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

    metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
    _ = list(
        trainer.train(
            train_dir=self.test_dir,
            model=model,
            dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
            initializer=initializer,
            num_train_steps=1,
            hps=hps,
            rng=jax.random.PRNGKey(42),
            eval_batch_size=hps.batch_size,
            eval_num_batches=None,
            eval_train_num_batches=None,
            eval_frequency=10,
            checkpoint_steps=[],
            metrics_logger=metrics_logger,
            init_logger=init_logger))

    with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                        'measurements.csv')) as f:
      df = pandas.read_csv(f)
      valid_ce_loss = df['valid/ce_loss'].values[-1]
      self.assertLess(valid_ce_loss, 1e-3)