Example #1
    def test_cifar10(self):
        """Test example generation in CIFAR10 is reproducible."""
        dataset = small_image_datasets.get_cifar10(
            random.PRNGKey(0), 1, 1,
            config_dict.ConfigDict(
                dict(flip_probability=0.5,
                     alpha=1.0,
                     crop_num_pixels=4,
                     use_mixup=True,
                     train_size=45000,
                     valid_size=5000,
                     test_size=10000,
                     include_example_keys=True,
                     input_shape=(32, 32, 3),
                     output_shape=(10, ))))

        examples = itertools.islice(dataset.valid_epoch(), 10)
        example_keys = [example['example_key'][0] for example in examples]
        self.assertEqual(example_keys, [
            b'cifar10-train.tfrecord-00000-of-00001__45000',
            b'cifar10-train.tfrecord-00000-of-00001__45001',
            b'cifar10-train.tfrecord-00000-of-00001__45002',
            b'cifar10-train.tfrecord-00000-of-00001__45003',
            b'cifar10-train.tfrecord-00000-of-00001__45004',
            b'cifar10-train.tfrecord-00000-of-00001__45005',
            b'cifar10-train.tfrecord-00000-of-00001__45006',
            b'cifar10-train.tfrecord-00000-of-00001__45007',
            b'cifar10-train.tfrecord-00000-of-00001__45008',
            b'cifar10-train.tfrecord-00000-of-00001__45009',
        ])
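
Note: all of the examples on this page build their hyperparameters with
ml_collections' ConfigDict. As a quick orientation, here is a minimal,
self-contained sketch (field names and values are made up) of the accessors
the tests rely on:

from ml_collections.config_dict import config_dict

hps = config_dict.ConfigDict(dict(batch_size=128, base_lr=0.1))
assert hps.batch_size == 128        # attribute access
assert hps['base_lr'] == 0.1        # key access
assert config_dict.ConfigDict(hps.to_dict()).base_lr == 0.1  # dict round trip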
Example #2
 def test_schedule_stretching(self):
     """Test that schedules can be properly stretched."""
     max_training_steps = 100
     lr_hparams = config_dict.ConfigDict({
         'schedule': 'mlperf_polynomial',
         'base_lr': 10.0,
         'warmup_steps': 10,
         'decay_end': -1,
         'end_lr': 1e-4,
         'power': 2.0,
         'start_lr': 0.0,
         'warmup_power': 1.0,
     })
     lr_fn = schedules.get_schedule_fn(lr_hparams, max_training_steps)
     stretch_factor = 3
     stretched_lr_fn = schedules.get_schedule_fn(
         lr_hparams, max_training_steps, stretch_factor=stretch_factor)
     lrs = [lr_fn(t) for t in range(max_training_steps)]
     stretched_lrs = [
         stretched_lr_fn(t)
         for t in range(stretch_factor * max_training_steps)
     ]
     self.assertEqual(lrs, stretched_lrs[::stretch_factor])
     self.assertEqual(lrs, stretched_lrs[1::stretch_factor])
     self.assertEqual(lrs, stretched_lrs[2::stretch_factor])
     # Assert that the stretched schedule has proper staircase behavior.
     for update_step in range(max_training_steps):
         start = update_step * stretch_factor
         end = (update_step + 1) * stretch_factor
         expected = [lrs[update_step]] * stretch_factor
         self.assertEqual(stretched_lrs[start:end], expected)
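
Note: the staircase assertion above means the stretched schedule holds each
original value for stretch_factor consecutive steps. A self-contained sketch
of that mapping (not the actual implementation inside schedules):

def stretch_schedule(lr_fn, stretch_factor):
  # Hold each lr_fn value for stretch_factor consecutive steps.
  return lambda t: lr_fn(t // stretch_factor)

base_fn = lambda t: 0.1 * t
stretched_fn = stretch_schedule(base_fn, 3)
assert [stretched_fn(t) for t in range(6)] == [base_fn(0)] * 3 + [base_fn(1)] * 3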
Example #3
    def test_raises(self):
        """Test that an exception is raised with extra hparams."""
        good_hps = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'mlperf_polynomial',
                    'base_lr': .1,
                    'warmup_steps': 200,
                    'decay_end': -1,
                    'end_lr': 1e-4,
                    'power': 2.0,
                    'start_lr': 0.0,
                    'warmup_power': 1.0,
                }))
        bad_hps = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'mlperf_polynomial',
                    'base_lr': .1,
                    'warmup_steps': 200,
                    'initial_value': .1,
                    'decay_end': -1,
                    'end_lr': 1e-4,
                    'power': 2.0,
                    'start_lr': 0.0,
                }))
        bad_hps2 = config_dict.ConfigDict(
            dict(
                lr_hparams={
                    'schedule': 'polynomial',
                    'power': 2.0,
                    'initial_value': .1,
                    'end_factor': .01,
                    'decay_steps': 200,
                    'decay_steps_factor': 0.5
                }))
        # This should pass.
        schedules.get_schedule_fn(good_hps.lr_hparams, 1)

        # This should raise an exception due to the extra hparam.
        with self.assertRaises(ValueError):
            schedules.get_schedule_fn(bad_hps.lr_hparams, 1)

        # This should raise an exception due to the mutually exclusive hparams.
        with self.assertRaises(ValueError):
            schedules.get_schedule_fn(bad_hps2.lr_hparams, 1)
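
Note: the extra-hparam check exercised here boils down to a set comparison
between the provided keys and the keys the schedule expects. A sketch of that
idea (the actual expected-key sets live inside schedules.get_schedule_fn):

def check_hparam_keys(lr_hparams, expected_keys):
  # Raise on extra keys, mirroring the ValueError the test expects.
  extra = set(lr_hparams.keys()) - set(expected_keys)
  if extra:
    raise ValueError(f'Unexpected lr_hparams: {sorted(extra)}')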
Example #4
def main(unused_argv):
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()

  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    hparam_overrides = FLAGS.hparam_overrides
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)
  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
Example #5
 def _get_dataset(self, hps, rng):
     """Sets ups dataset builders."""
     hparams_dict = hps.to_dict()
     hparams_dict.update(self.callback_config)
     hparams = config_dict.ConfigDict(hparams_dict)
     dataset_builder = datasets.get_dataset(
         self.callback_config['dataset_name'])
     dataset = dataset_builder(
         rng,
         hparams.batch_size,
         eval_batch_size=self.callback_config['eval_batch_size'],
         hps=hparams)
     return dataset
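
Note: the merge pattern in _get_dataset (base hparams converted to a dict,
updated with the callback overrides, then re-wrapped) can be seen in
isolation; the field names below are illustrative:

base_hps = config_dict.ConfigDict(dict(batch_size=128, eval_batch_size=64))
callback_config = {'dataset_name': 'mnist', 'batch_size': 32}

merged = base_hps.to_dict()
merged.update(callback_config)  # callback values win on conflicts
hparams = config_dict.ConfigDict(merged)
assert hparams.batch_size == 32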
Example #6
    def test_nqm(self):
        """Test the noisy quadratic model."""
        batch_size = 2
        dim = 10
        model_hps = config_dict.ConfigDict(
            dict(
                input_shape=(dim, ),
                output_shape=(1, ),
                rng_seed=-1,
                hessian_decay_power=1.0,
                noise_decay_power=1.0,
                nqm_mode='diagH_diagC',
                model_dtype='float32',
            ))

        model_cls = models.get_model('nqm')
        rng = jax.random.PRNGKey(0)
        model = model_cls(model_hps, {})
        noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
        rng, params_rng = jax.random.split(rng)
        _, flax_module = model.flax_module_def.create_by_shape(
            params_rng, [(batch_size, dim)])

        model_x = flax_module.params['x']

        def loss(model, inputs):
            return model(inputs)

        grad_loss = jax.grad(loss)

        hessian = np.diag(
            np.array([
                1.0 / np.power(i, model_hps.hessian_decay_power)
                for i in range(1, dim + 1)
            ]))
        noise_matrix = np.diag(
            np.array([
                1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
                for i in range(1, dim + 1)
            ]))

        noise = jnp.dot(noise_eps, noise_matrix)
        mean_noise = np.mean(noise, axis=0)

        # NQM gradient = Hx + eps   where eps ~ N(0, C / batch_size).
        expected_grad = np.dot(hessian, model_x) + mean_noise

        g = grad_loss(flax_module, noise_eps).params['x']

        grad_error = np.sum(np.abs(g - expected_grad))
        self.assertAlmostEqual(grad_error, 0.0, places=5)
Example #7
  def test_nqm(self):
    """Test the noisy quadratic model."""
    batch_size = 2
    dim = 10
    model_hps = config_dict.ConfigDict(
        dict(
            input_shape=(dim,),
            output_shape=(1,),
            rng_seed=-1,
            hessian_decay_power=1.0,
            noise_decay_power=1.0,
            nqm_mode='diagH_diagC',
            model_dtype='float32',
        ))

    model_cls = models.get_model('nqm')
    params_rng = jax.random.PRNGKey(0)
    model = model_cls(model_hps, {}, None, None)
    noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
    xs = np.zeros((batch_size, dim))
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    params = model_init_fn({'params': params_rng}, xs)['params']
    model_x = params['x']

    def loss(params, inputs):
      return model.training_cost(params, batch=inputs)

    grad_loss = jax.grad(loss, has_aux=True)

    hessian = np.diag(
        np.array([
            1.0 / np.power(i, model_hps.hessian_decay_power)
            for i in range(1, dim + 1)
        ]))
    noise_matrix = np.diag(
        np.array([
            1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
            for i in range(1, dim + 1)
        ]))

    noise = jnp.dot(noise_eps, noise_matrix)
    mean_noise = np.mean(noise, axis=0)

    # NQM gradient = Hx + eps   where eps ~ N(0, C / batch_size).
    expected_grad = np.dot(hessian, model_x) + mean_noise

    g = grad_loss(params, {'inputs': noise_eps})[0]['x']

    grad_error = np.sum(np.abs(g - expected_grad))
    self.assertAlmostEqual(grad_error, 0.0, places=5)
Example #8
def _get_dataset(shuffle_seed, additional_hps=None):
  """Loads the ogbg-molpcba dataset using mock data."""
  with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
    ds = 'ogbg_molpcba'
    dataset_builder = get_dataset(ds)
    hps_dict = get_dataset_hparams(ds).to_dict()
    if additional_hps is not None:
      hps_dict.update(additional_hps)
    hps = config_dict.ConfigDict(hps_dict)
    hps.train_size = 4
    hps.valid_size = 4
    hps.test_size = 4
    hps.max_nodes_multiplier = NODES_SIZE_MULTIPLIER
    hps.max_edges_multiplier = EDGES_SIZE_MULTIPLIER
    batch_size = BATCH_SIZE
    eval_batch_size = BATCH_SIZE
    dataset = dataset_builder(
        shuffle_rng=shuffle_seed,
        batch_size=batch_size,
        eval_batch_size=eval_batch_size,
        hps=hps)
    return dataset
Example #9
 def test_polynomial_decay_decay_steps(self):
     """Test polynomial schedule works correctly with decay_steps."""
     hps = config_dict.ConfigDict(
         dict(
             lr_hparams={
                 'schedule': 'polynomial',
                 'power': 2.0,
                 'initial_value': .1,
                 'end_factor': .01,
                 'decay_steps': 200,
             }))
     max_training_steps = 400
     lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)
     hps = hps.lr_hparams
     decay_steps = hps['decay_steps']
     for step in range(max_training_steps):
         expected_learning_rate = tf.train.polynomial_decay(
             hps['initial_value'],
             step,
             decay_steps,
             hps['end_factor'] * hps['initial_value'],
             power=hps['power'])().numpy()
         self.assertAlmostEqual(lr_fn(step), expected_learning_rate)
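
Note: the reference value computed by the TF helper follows the standard
polynomial-decay formula, lr(t) = (initial - end) * (1 - min(t, decay_steps)
/ decay_steps) ** power + end. A numpy sketch under that assumption:

import numpy as np

def polynomial_decay(initial_value, step, decay_steps, end_value, power):
  # Standard (non-cycling) polynomial decay, per the documented formula.
  frac = 1.0 - min(step, decay_steps) / decay_steps
  return (initial_value - end_value) * frac**power + end_value

assert np.isclose(polynomial_decay(0.1, 0, 200, 0.001, 2.0), 0.1)
assert np.isclose(polynomial_decay(0.1, 200, 200, 0.001, 2.0), 0.001)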
Example #10
  def test_text_models(self, model_str):
    """Test forward pass of the transformer model."""

    # TODO(gilmer): Find a clean way to handle small test hparams.
    vocab_size = 16

    small_hps = config_dict.ConfigDict({
        # Architecture Hparams.
        'batch_size': 16,
        'emb_dim': 32,
        'num_heads': 2,
        'num_layers': 3,
        'qkv_dim': 32,
        'label_smoothing': 0.1,
        'mlp_dim': 64,
        'max_target_length': 64,
        'max_eval_target_length': 64,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'momentum': 0.9,
        'normalizer': 'layer_norm',
        'lr_hparams': {
            'base_lr': 0.005,
            'schedule': 'constant'
        },
        'output_shape': (vocab_size,),
        'model_dtype': 'float32',
        # Training HParams.
        'l2_decay_factor': 1e-4,
        'decode': False,
    })

    text_input_shape = (32, 64)  # batch_size, max_target_length
    model_cls = models.get_model(model_str)
    rng = jax.random.PRNGKey(0)
    loss = 'cross_entropy'
    metrics = 'classification_metrics'
    model = model_cls(small_hps, {
        'max_len': 64,
        'shift_inputs': True,
        'causal': True
    }, loss, metrics)
    xs = jnp.array(
        np.random.randint(size=text_input_shape, low=1, high=vocab_size))
    dropout_rng, params_rng = jax.random.split(rng)

    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs)
    params = init_dict['params']
    batch_stats = init_dict.get('batch_stats', {})

    # Check that the forward pass works with mutated batch_stats.
    # Due to a bug in flax, this jit is required; otherwise the model errors.
    @jax.jit
    def forward_pass(params, xs, dropout_rng):
      outputs, new_batch_stats = model.flax_module.apply(
          {'params': params, 'batch_stats': batch_stats},
          xs,
          mutable=['batch_stats'],
          rngs={'dropout': dropout_rng},
          train=True)
      return outputs, new_batch_stats

    outputs, new_batch_stats = forward_pass(params, xs, dropout_rng)
    self.assertEqual(outputs.shape,
                     (text_input_shape[0], text_input_shape[1], vocab_size))

    # If it's a batch norm model check the batch stats changed.
    if batch_stats:
      bflat, _ = ravel_pytree(batch_stats)
      new_bflat, _ = ravel_pytree(new_batch_stats)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Test batch_norm in inference mode.
    outputs = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats}, xs, train=False)
    self.assertEqual(outputs.shape,
                     (text_input_shape[0], text_input_shape[1], vocab_size))
Example #11
  def test_translate_model(self):
    """Test forward pass of the translate model."""
    vocab_size = 16
    small_hps = config_dict.ConfigDict({
        # Architecture Hparams.
        'batch_size': 16,
        'share_embeddings': False,
        'logits_via_embedding': False,
        'emb_dim': 32,
        'num_heads': 2,
        'enc_num_layers': 2,
        'dec_num_layers': 2,
        'qkv_dim': 32,
        'label_smoothing': 0.1,
        'mlp_dim': 64,
        'max_target_length': 64,
        'max_eval_target_length': 64,
        'normalizer': 'pre_layer_norm',
        'max_predict_length': 64,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'momentum': 0.9,
        'lr_hparams': {
            'base_lr': 0.005,
            'schedule': 'constant'
        },
        'output_shape': (vocab_size,),
        # Training HParams.
        'l2_decay_factor': 1e-4,
        'enc_self_attn_kernel_init': 'xavier_uniform',
        'dec_self_attn_kernel_init': 'xavier_uniform',
        'dec_cross_attn_kernel_init': 'xavier_uniform',
        'decode': False,
    })
    text_src_input_shape = (32, 64)  # batch_size, max_source_length
    text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
    model_cls = models.get_model('xformer_translate')
    rng = jax.random.PRNGKey(0)
    loss = 'cross_entropy'
    metrics = 'classification_metrics'
    model = model_cls(small_hps, {
        'shift_outputs': True,
        'causal': True
    }, loss, metrics)
    xs = jnp.array(
        np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
    ys = jnp.array(
        np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
    dropout_rng, params_rng = jax.random.split(rng)
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs, ys)
    params = init_dict['params']

    # Test forward pass.
    @jax.jit
    def forward_pass(params, xs, ys, dropout_rng):
      outputs = model.flax_module.apply(
          {'params': params},
          xs,
          ys,
          rngs={'dropout': dropout_rng},
          train=True)
      return outputs

    logits = forward_pass(params, xs, ys, dropout_rng)
    # Testing only train mode
    # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
    self.assertEqual(
        logits.shape,
        (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
Example #12
import operator

from flax import optim as optimizers
from flax.core import unfreeze
from init2winit.model_lib import model_utils
import jax
from jax import jvp
import jax.numpy as jnp
from ml_collections.config_dict import config_dict
import numpy as np
import optax


# Small hparams for quicker tests.
DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    meta_learning_rate=.1,
    meta_steps=50,
    meta_batch_size=8,
    epsilon=1e-5,
    meta_momentum=0.5,
))


def _count_params(tree):
  return jax.tree_util.tree_reduce(operator.add,
                                   jax.tree_map(lambda x: x.size, tree))


def scale_params(params, scalars):
  return jax.tree_multimap(lambda w, scale: w * scale, params, scalars)


def meta_loss(params_to_loss, scalars, normalized_params, epsilon):
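
Note: as a usage example, _count_params above simply sums array sizes across
a pytree:

import jax.numpy as jnp

params = {'w': jnp.ones((3, 4)), 'b': jnp.ones((4,))}
assert _count_params(params) == 16  # 12 weights + 4 biases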
Example #13
    def test_text_models(self, model_str):
        """Test forward pass of the transformer model."""

        # TODO(gilmer): Find a clean way to handle small test hparams.
        vocab_size = 16

        small_hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': 16,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'label_smoothing': 0.1,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'momentum': 0.9,
            'normalizer': 'layer_norm',
            'lr_hparams': {
                'initial_value': 0.005,
                'schedule': 'constant'
            },
            'output_shape': (vocab_size, ),
            # Training HParams.
            'l2_decay_factor': 1e-4
        })

        text_input_shape = (32, 64)  # batch_size, max_target_length
        model_cls = models.get_model(model_str)
        rng = jax.random.PRNGKey(0)
        loss = 'cross_entropy'
        metrics = 'classification_metrics'
        model = model_cls(small_hps, {
            'max_len': 64,
            'shift_inputs': True,
            'causal': True
        }, loss, metrics)
        xs = jnp.array(
            np.random.randint(size=text_input_shape, low=1, high=vocab_size))
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)

        with nn.stateful() as batch_stats:
            _, flax_module = model.flax_module_def.create_by_shape(
                params_rng, [(text_input_shape, jnp.float32)], train=False)

        # Check that the forward pass works with mutated batch_stats.
        # Due to a bug in flax, this jit is required; otherwise the model errors.
        @jax.jit
        def forward_pass(flax_module, xs, dropout_rng):
            with batch_stats.mutate() as new_batch_stats:
                with nn.stochastic(dropout_rng):
                    return flax_module(xs, train=True), new_batch_stats

        outputs, new_batch_stats = forward_pass(flax_module, xs, dropout_rng)
        self.assertEqual(
            outputs.shape,
            (text_input_shape[0], text_input_shape[1], vocab_size))

        # If it's a batch norm model check the batch stats changed.
        if batch_stats.as_dict():
            bflat, _ = ravel_pytree(batch_stats)
            new_bflat, _ = ravel_pytree(new_batch_stats)
            self.assertFalse(jnp.array_equal(bflat, new_bflat))

        # Test batch_norm in inference mode.
        with nn.stateful(batch_stats, mutable=False):
            outputs = flax_module(xs, train=False)
        self.assertEqual(
            outputs.shape,
            (text_input_shape[0], text_input_shape[1], vocab_size))
Example #14
# small hparams used for unit tests
DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(blocks_per_group=3,
         channel_multiplier=2,
         lr_hparams={
             'base_lr': 0.001,
             'schedule': 'cosine'
         },
         normalizer='batch_norm',
         layer_rescale_factors={},
         conv_kernel_scale=1.0,
         dense_kernel_scale=1.0,
         dropout_rate=0.0,
         conv_kernel_init='lecun_normal',
         dense_kernel_init='lecun_normal',
         optimizer='momentum',
         opt_hparams={
             'momentum': 0.9,
         },
         batch_size=128,
         virtual_batch_size=None,
         total_accumulated_batch_size=None,
         l2_decay_factor=0.0001,
         l2_decay_rank_threshold=2,
         label_smoothing=None,
         rng_seed=-1,
         use_shallue_label_smoothing=False,
         model_dtype='float32',
         grad_clip=None,
         activation_function='relu',
         group_strides=[(1, 1), (2, 2), (2, 2)]))
Example #15
    def test_translate_model(self):
        """Test forward pass of the translate model."""
        vocab_size = 16
        small_hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': 16,
            'share_embeddings': False,
            'logits_via_embedding': False,
            'emb_dim': 32,
            'num_heads': 2,
            'enc_num_layers': 2,
            'dec_num_layers': 2,
            'qkv_dim': 32,
            'label_smoothing': 0.1,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'normalizer': 'pre_layer_norm',
            'max_predict_length': 64,
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'momentum': 0.9,
            'lr_hparams': {
                'initial_value': 0.005,
                'schedule': 'constant'
            },
            'output_shape': (vocab_size, ),
            # Training HParams.
            'l2_decay_factor': 1e-4
        })
        text_src_input_shape = (32, 64)  # batch_size, max_source_length
        text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
        model_cls = models.get_model('xformer_translate')
        rng = jax.random.PRNGKey(0)
        loss = 'cross_entropy'
        metrics = 'classification_metrics'
        model = model_cls(small_hps, {
            'shift_outputs': True,
            'causal': True
        }, loss, metrics)
        xs = jnp.array(
            np.random.randint(size=text_src_input_shape,
                              low=1,
                              high=vocab_size))
        ys = jnp.array(
            np.random.randint(size=text_tgt_input_shape,
                              low=1,
                              high=vocab_size))
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        with nn.stateful() as batch_stats:
            _, flax_module = model.flax_module_def.create_by_shape(
                params_rng, [(text_src_input_shape, jnp.float32),
                             (text_tgt_input_shape, jnp.float32)],
                train=False)

        # Test forward pass.
        @jax.jit
        def forward_pass(flax_module, xs, ys, dropout_rng):
            with batch_stats.mutate() as new_batch_stats:
                with nn.stochastic(dropout_rng):
                    return flax_module(xs, ys, train=True), new_batch_stats

        logits, _ = forward_pass(flax_module, xs, ys, dropout_rng)
        # Testing only train mode
        # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
        self.assertEqual(
            logits.shape,
            (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
Example #16
    def test_trainer(self):
        """Test training for two epochs on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic;
        # mocking.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return 256 fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        # Check that the additional checkpoints are saved.
        checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
        saved_steps = []
        for f in tf.io.gfile.listdir(checkpoint_dir):
            if f[:5] == 'ckpt_':
                saved_steps.append(int(f[5:]))

        self.assertEqual(set(saved_steps), set(checkpoint_steps))

        self.assertLen(epoch_reports, num_train_steps / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            self.assertEqual(df['preemption_count'].values[-1], 0)
            self.assertLess(train_err, 0.9)

        self.assertEqual(set(df.columns.values), set(get_column_names()))

        model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                          metrics_name)

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = 100
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports,
                       (num_train_steps_reload - num_train_steps) / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            self.assertLess(train_err, 0.35)
            self.assertLess(train_loss, 0.1)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size)
            self.assertEqual(df['preemption_count'].values[-1], 1)
            # Check that the correct learning rate was saved in the measurements file.
            final_learning_rate = df['learning_rate'].values[-1]
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

            # final_step will be one larger than the last step used to calculate
            # the lr decay, hence we plug (final_step - 1) into the decay formula.
            # Note that there is a small numerical difference here between np and
            # jnp.
            decay_factor = (1 + np.cos(
                (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
            self.assertEqual(float(final_learning_rate),
                             hps.lr_hparams['base_lr'] * decay_factor)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
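
Note: the final assertion checks the half-cosine decay used by the 'cosine'
schedule, lr(t) = base_lr * 0.5 * (1 + cos(pi * t / T)). A numpy sketch of
that formula:

import numpy as np

def cosine_lr(base_lr, step, total_steps):
  # Half-cosine decay from base_lr at step 0 towards 0 at total_steps.
  return base_lr * 0.5 * (1.0 + np.cos(np.pi * step / total_steps))

assert np.isclose(cosine_lr(0.1, 0, 100), 0.1)
assert np.isclose(cosine_lr(0.1, 100, 100), 0.0)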
Example #17
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds


CROP_PADDING = 32
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]


DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    input_shape=(224, 224, 3),
    output_shape=(1000,),
    train_size=1281167,
    valid_size=50000,
    use_inception_crop=False,
    use_mixup=False,
    mixup={'alpha': 0.5},
    use_randaug=False,
    randaug={
        'magnitude': 15,
        'num_layers': 2
    }))

METADATA = {
    'apply_one_hot_in_loss': False,
}


def distorted_bounding_box_crop(image_bytes,
                                bbox,
Example #18
FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        num_filters=16,
        num_layers=18,  # Must be one of [18, 34, 50, 101, 152, 200]
        layer_rescale_factors={},
        lr_hparams={
            'batch_size': 128,
            'base_lr': 10.0,
            'decay_end': -1,
            'end_lr': 1e-4,
            'power': 2.0,
            'schedule': 'mlperf_polynomial',
            'start_lr': 0.0,
            'steps_per_epoch': 10009.250000000002,
            'warmup_steps': 18,
        },
        optimizer='mlperf_lars_resnet',
        opt_hparams={
            'weight_decay': 2e-4,
            'beta': 0.9
        },
        batch_size=128,
        l2_decay_factor=None,
        l2_decay_rank_threshold=2,
        label_smoothing=.1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
        virtual_batch_size=64,
        data_format='NHWC',
    ))
Example #19
 def test_mlperf_schedule(self):
     """Test there are no changes to the MLPerf polynomial decay schedule."""
     expected_lrs = [
         0.0,
         0.2,
         0.4,
         0.6,
         0.8,
         1.0,
         1.2,
         1.4,
         1.6,
         1.8,
         2.0,
         2.2,
         2.4,
         2.6,
         2.8,
         3.0,
         3.2,
         3.4,
         3.6,
         3.8,
         4.0,
         4.2,
         4.4,
         4.6,
         4.8,
         5.0,
         5.2,
         5.4,
         5.6,
         5.8,
         6.0,
         6.2,
         6.4,
         6.6,
         6.8,
         7.0,
         7.2,
         7.4,
         7.6,
         7.8,
         8.0,
         8.2,
         8.4,
         8.6,
         8.8,
         9.0,
         9.2,
         9.4,
         9.6,
         9.8,
         10.0,
         9.802962,
         9.607885,
         9.414769,
         9.223614,
         9.034419,
         8.847184,
         8.661909,
         8.478596,
         8.297242,
         8.117851,
         7.940418,
         7.764947,
         7.5914364,
         7.419886,
         7.2502966,
         7.082668,
         6.917,
         6.7532916,
         6.591545,
         6.4317584,
         6.273932,
         6.1180663,
         5.964162,
         5.812217,
         5.662234,
         5.5142093,
         5.368148,
         5.2240453,
         5.0819044,
         4.941723,
         4.803503,
         4.6672425,
         4.532944,
         4.4006047,
         4.2702274,
         4.1418095,
         4.0153522,
         3.8908558,
         3.7683203,
         3.647745,
         3.5291305,
         3.4124763,
         3.297783,
         3.18505,
         3.0742776,
         2.965466,
         2.858614,
         2.753724,
         2.6507936,
         2.5498245,
         2.4508152,
         2.353767,
         2.2586792,
         2.1655521,
         2.0743854,
         1.9851794,
         1.8979341,
         1.8126491,
         1.7293249,
         1.6479613,
         1.568558,
         1.4911155,
         1.4156334,
         1.342112,
         1.2705511,
         1.2009507,
         1.133311,
         1.0676318,
         1.0039133,
         0.94215524,
         0.8823574,
         0.8245205,
         0.7686443,
         0.71472853,
         0.66277343,
         0.61277884,
         0.56474483,
         0.5186714,
         0.4745585,
         0.43240622,
         0.39221448,
         0.35398334,
         0.31771275,
         0.28340274,
         0.2510533,
         0.22066444,
         0.19223614,
         0.16576843,
         0.14126128,
         0.11871469,
         0.098128565,
         0.079503134,
         0.06283828,
         0.048134,
         0.035390284,
         0.02460714,
         0.01578457,
         0.00892257,
         0.004021142,
     ]
     hps = config_dict.ConfigDict(
         dict(
             lr_hparams={
                 'schedule': 'mlperf_polynomial',
                 'base_lr': 10.0,
                 'warmup_steps': 50,
                 'decay_end': -1,
                 'end_lr': 1e-4,
                 'power': 2.0,
                 'start_lr': 0.0,
                 'warmup_power': 1.0,
             }))
     max_training_steps = 150
     lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps)
     for step in range(max_training_steps):
         self.assertAlmostEqual(lr_fn(step), expected_lrs[step])
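
Note: the first 51 expected values above are the warmup phase; with
warmup_power=1.0 they trace a straight line from start_lr to base_lr. A quick
check of that relationship:

base_lr, warmup_steps = 10.0, 50
warmup = [base_lr * step / warmup_steps for step in range(warmup_steps + 1)]
assert warmup[:6] == [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
assert warmup[-1] == base_lr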
Example #20
import itertools

from init2winit.dataset_lib import data_utils
from init2winit.dataset_lib import mlperf_input_pipeline
import jax
from ml_collections.config_dict import config_dict
import tensorflow.compat.v2 as tf


NUM_CLASSES = 1000


DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    input_shape=(224, 224, 3),
    output_shape=(NUM_CLASSES,),
    train_size=1281167,
    valid_size=50000))

METADATA = {
    'apply_one_hot_in_loss': False,
}


def transpose_and_normalize_image(image):
  mean = tf.constant([[mlperf_input_pipeline.MEAN_RGB]], dtype=image.dtype)
  stddev = tf.constant(
      [[mlperf_input_pipeline.STDDEV_RGB]], dtype=image.dtype)
  image -= mean
  image /= stddev
  return image
Example #21
# pylint: disable=unused-argument
def noop(
    loss_fn=None,
    flax_module=None,
    params=None,
    hps=None,
    input_shape=None,
    output_shape=None,
    rng_key=None,
    metrics_logger=None,
):
    """No-op init."""
    return params


# pylint: enable=unused-argument

DEFAULT_HPARAMS = config_dict.ConfigDict()

_ALL_INITIALIZERS = {
    'noop': (noop, DEFAULT_HPARAMS),
    'meta_init': (meta_init.meta_init, meta_init.DEFAULT_HPARAMS),
    'sparse_init': (sparse_init.sparse_init, sparse_init.DEFAULT_HPARAMS),
}


def get_initializer(initializer_name):
    """Get the corresponding initializer function based on the initializer string.

  API of an initializer:
  init_fn, hparams = get_initializer(init)
  new_params, final_l = init_fn(loss, init_params, hps,
                                num_outputs, input_shape)
Example #22
from ml_collections.config_dict import config_dict


# small hparams used for unit tests
DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    hid_sizes=[20, 10],
    kernel_scales=[1.0, 1.0, 1.0],
    lr_hparams={
        'initial_value': 0.1,
        'schedule': 'constant'
    },
    layer_rescale_factors={},
    optimizer='momentum',
    opt_hparams={
        'momentum': 0.9,
    },
    batch_size=128,
    activation_function='relu',
    l2_decay_factor=.0005,
    l2_decay_rank_threshold=2,
    label_smoothing=None,
    rng_seed=-1,
    use_shallue_label_smoothing=False,
    model_dtype='float32',
))


class FullyConnected(nn.Module):
  """Defines a fully connected neural network.
Example #23
from init2winit.model_lib import model_utils
import jax.numpy as jnp

from ml_collections.config_dict import config_dict

DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        num_layers=11,  # Must be one of [11, 13, 16, 19]
        layer_rescale_factors={},
        lr_hparams={
            'schedule': 'constant',
            'initial_value': 0.2,
        },
        normalizer='none',
        optimizer='momentum',
        opt_hparams={
            'momentum': 0.9,
        },
        batch_size=128,
        l2_decay_factor=0.0001,
        l2_decay_rank_threshold=2,
        label_smoothing=None,
        rng_seed=-1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
    ))


def classifier(x, num_outputs, dropout_rate, deterministic):
    """Implements the classification portion of the network."""
Example #24
    def test_early_stopping(self):
        """Test training early stopping on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic;
        # mocking.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return 256 fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        early_stopping_target_name = 'test/ce_loss'
        early_stopping_target_value = 0.005
        early_stopping_mode = 'less'
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                early_stopping_target_name=early_stopping_target_name,
                early_stopping_target_value=early_stopping_target_value,
                early_stopping_mode=early_stopping_mode,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports, 3)
        self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                           early_stopping_target_value)
        self.assertLess(epoch_reports[-1][early_stopping_target_name],
                        early_stopping_target_value)
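
Note: the early-stopping behavior exercised here amounts to comparing the
target metric against the threshold after each evaluation. A sketch of that
condition (the trainer's actual logic may differ in detail):

def should_early_stop(report, target_name, target_value, mode):
  # 'less' stops once the metric drops below the target; 'greater' stops
  # once it exceeds the target.
  value = report[target_name]
  return value < target_value if mode == 'less' else value > target_value

assert should_early_stop({'test/ce_loss': 0.004}, 'test/ce_loss', 0.005, 'less')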
Example #25
# small hparams used for unit tests
DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
    num_filters=[64, 96, 128],
    kernel_sizes=[5, 3, 3],
    kernel_paddings=['VALID', 'VALID', 'SAME'],
    window_sizes=[3, 3, 3],
    window_paddings=['SAME', 'SAME', 'SAME'],
    strides=[2, 2, 2],
    num_dense_units=[512, 256],
    lr_hparams={
        'initial_value': 0.001,
        'schedule': 'constant'
    },
    layer_rescale_factors={},
    optimizer='momentum',
    opt_hparams={
        'momentum': 0.9,
    },
    batch_size=128,
    activation_fn='relu',
    normalizer='none',
    l2_decay_factor=.0005,
    l2_decay_rank_threshold=2,
    label_smoothing=None,
    rng_seed=-1,
    use_shallue_label_smoothing=False,
    model_dtype='float32',
))

Example #26
def build_hparams(model_name,
                  initializer_name,
                  dataset_name,
                  hparam_file,
                  hparam_overrides):
  """Build experiment hyperparameters.

  Args:
    model_name: the string model name.
    initializer_name: the string initializer name.
    dataset_name: the string dataset name.
    hparam_file: the string to the hyperparameter override file (possibly on
      CNS).
    hparam_overrides: a dict of hyperparameter override names/values, or a JSON
      string encoding of this hyperparameter override dict. Note that this is
      applied after the hyperparameter file overrides.

  Returns:
    A ConfigDict of experiment hyperparameters.
  """
  model_hps = models.get_model_hparams(model_name)
  initializer_hps = initializers.get_initializer_hparams(initializer_name)
  dataset_hps = datasets.get_dataset_hparams(dataset_name)

  merged_dict = {}

  hps_dicts = [
      hps.to_dict() for hps in [model_hps, initializer_hps, dataset_hps]
  ]

  total_hps = 0
  for hps_dict in hps_dicts:
    merged_dict.update(hps_dict)
    total_hps += len(hps_dict.keys())

  # Check that the provided hparam dicts have no overlap.
  if total_hps != len(merged_dict.keys()):
    raise ValueError('There is overlap in the provided hparams.')

  # Convert to the Shallue and Lee label smoothing style.
  if merged_dict.get('use_shallue_label_smoothing', False):
    num_classes = merged_dict['output_shape'][-1]
    merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1)

  merged = config_dict.ConfigDict(merged_dict)
  merged.lock()

  # Subconfig "opt_hparams" and "lr_hparams" are allowed to add new fields.
  for key in ['opt_hparams', 'lr_hparams']:
    if key not in merged:
      with merged.unlocked():
        merged[key] = config_dict.ConfigDict()

  for key in ['opt_hparams', 'lr_hparams']:
    merged[key].unlock()

  if hparam_file:
    logging.info('Loading hparams from %s', hparam_file)
    with gfile.GFile(hparam_file, 'r') as f:
      hparam_dict = json.load(f)
      merged.update_from_flattened_dict(hparam_dict)

  if hparam_overrides:
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)

    # If the user is changing the learning rate schedule or optimizer, we must
    # wipe all of the keys from the old dictionary.
    if 'lr_hparams.schedule' in hparam_overrides and merged[
        'lr_hparams']['schedule'] != hparam_overrides[
            'lr_hparams.schedule']:
      merged['lr_hparams'] = {}
    if 'optimizer' in hparam_overrides and merged[
        'optimizer'] != hparam_overrides['optimizer']:
      merged['opt_hparams'] = {}
    merged.update_from_flattened_dict(hparam_overrides)

  return merged
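
Note: build_hparams applies overrides through
ConfigDict.update_from_flattened_dict, which addresses nested fields with
dotted keys. A small sketch of that mechanism:

hps = config_dict.ConfigDict({
    'batch_size': 128,
    'lr_hparams': {'schedule': 'constant', 'base_lr': 0.1},
})
hps.update_from_flattened_dict({'batch_size': 64, 'lr_hparams.base_lr': 0.2})
assert hps.batch_size == 64
assert hps.lr_hparams.base_lr == 0.2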
Example #27
DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        batch_size=512,
        emb_dim=128,
        num_heads=8,
        num_layers=6,
        qkv_dim=128,
        mlp_dim=512,
        dropout_rate=0.1,
        attention_dropout_rate=0.1,
        optimizer='adam',
        opt_hparams={
            'beta1': .9,
            'beta2': .98,
            'epsilon': 1e-9,
            'weight_decay': 1e-1
        },
        layer_rescale_factors={},
        normalizer='layer_norm',
        lr_hparams={
            'initial_value': 0.05,
            'warmup_steps': 8000,
            'factors': 'constant * linear_warmup * rsqrt_decay',
            'schedule': 'compound'
        },
        label_smoothing=None,
        l2_decay_factor=None,
        l2_decay_rank_threshold=0,
        rng_seed=-1,
        use_shallue_label_smoothing=False,
        model_dtype='float32',
    ))
Example #28
from ml_collections.config_dict import config_dict

DEFAULT_HPARAMS = config_dict.ConfigDict(
    dict(
        num_filters=16,
        num_layers=18,  # Must be one of [18, 34, 50, 101, 152, 200]
        layer_rescale_factors={},
        lr_hparams={
            'schedule': 'constant',
            'initial_value': 0.2,
        },
        optimizer='momentum',
        opt_hparams={
            'momentum': 0.9,
        },
        batch_size=128,
        l2_decay_factor=0.0001,
        l2_decay_rank_threshold=2,
        label_smoothing=None,
        rng_seed=-1,
        use_shallue_label_smoothing=False,
        batch_norm_momentum=0.9,
        batch_norm_epsilon=1e-5,
        # Make this a string to avoid having to import jnp into the configs.
        model_dtype='float32',
        virtual_batch_size=None,
        data_format='NHWC',
    ))


class BasicResidualBlock(nn.Module):
Example #29
def _run(train_fn, dataset_name, eval_batch_size, eval_num_batches,
         eval_train_num_batches, eval_frequency, checkpoint_steps,
         num_tf_data_prefetches, num_device_prefetches,
         num_tf_data_map_parallel_calls, early_stopping_target_name,
         early_stopping_target_value, early_stopping_mode, eval_steps,
         hparam_file, hparam_overrides, initializer_name, model_name,
         loss_name, metrics_name, num_train_steps, experiment_dir, worker_id,
         training_metrics_config, callback_configs, external_checkpoint_path):
    """Function that runs a Jax experiment. See flag definitions for args."""
    model_cls = models.get_model(model_name)
    initializer = initializers.get_initializer(initializer_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
    input_pipeline_hps = config_dict.ConfigDict(
        dict(
            num_tf_data_prefetches=num_tf_data_prefetches,
            num_device_prefetches=num_device_prefetches,
            num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls,
        ))

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=initializer_name,
        dataset_name=dataset_name,
        hparam_file=hparam_file,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)

    # Note that one should never tune an RNG seed!!! The seed is only included in
    # the hparams for convenience of running hparam trials with multiple seeds per
    # point.
    rng_seed = merged_hps.rng_seed
    if merged_hps.rng_seed < 0:
        rng_seed = _create_synchronized_rng_seed()
    xm_experiment = None
    xm_work_unit = None
    if jax.process_index() == 0:
        logging.info('Running with seed %d', rng_seed)
    rng = jax.random.PRNGKey(rng_seed)

    # Build the loss_fn, metrics_bundle, and flax_module.
    model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
    trial_dir = os.path.join(experiment_dir, str(worker_id))
    meta_data_path = os.path.join(trial_dir, 'meta_data.json')
    meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
    if jax.process_index() == 0:
        logging.info('rng: %s', rng)
        gfile.makedirs(trial_dir)
        # Set up the metric loggers for host 0.
        metrics_logger, init_logger = utils.set_up_loggers(
            trial_dir, xm_work_unit)
        hparams_fname = os.path.join(trial_dir, 'hparams.json')
        logging.info('saving hparams to %s', hparams_fname)
        with gfile.GFile(hparams_fname, 'w') as f:
            f.write(merged_hps.to_json())
        _write_trial_meta_data(meta_data_path, meta_data)
    else:
        metrics_logger = None
        init_logger = None
    try:
        epoch_reports = list(
            train_fn(trial_dir,
                     model,
                     dataset_builder,
                     initializer,
                     num_train_steps,
                     merged_hps,
                     rng,
                     eval_batch_size,
                     eval_num_batches,
                     eval_train_num_batches,
                     eval_frequency,
                     checkpoint_steps,
                     early_stopping_target_name,
                     early_stopping_target_value,
                     early_stopping_mode,
                     eval_steps,
                     metrics_logger,
                     init_logger,
                     training_metrics_config=training_metrics_config,
                     callback_configs=callback_configs,
                     external_checkpoint_path=external_checkpoint_path))
        logging.info(epoch_reports)
        meta_data['status'] = 'done'
    except utils.TrainingDivergedError as err:
        meta_data['status'] = 'diverged'
        raise err
    finally:
        if jax.process_index() == 0:
            _write_trial_meta_data(meta_data_path, meta_data)
Example #30
    def test_text_model_trainer(self):
        """Test training of a small transformer model on fake data."""
        rng = jax.random.PRNGKey(42)

        # Set the numpy seed to make the fake data deterministic;
        # mocking.mock_data ultimately calls numpy.random.
        np.random.seed(0)

        model_cls = models.get_model('transformer')
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': _TEXT_BATCH_SIZE,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'input_shape': (64, ),
            'output_shape': (_VOCAB_SIZE, ),
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'layer_rescale_factors': {},
            'optimizer': 'momentum',
            'normalizer': 'layer_norm',
            'opt_hparams': {
                'momentum': 0.9,
            },
            'lr_hparams': {
                'base_lr': 0.005,
                'schedule': 'constant'
            },
            # Training HParams.
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'train_size': _TEXT_TRAIN_SIZE,
            'gradient_clipping': 0.0,
            'model_dtype': 'float32',
            'decode': False,
            'num_device_prefetches': 0,
        })
        initializer = initializers.get_initializer('noop')
        eval_num_batches = 5
        dataset, dataset_meta_data = _get_fake_text_dataset(
            batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
        eval_batch_size = hps.batch_size

        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

        eval_every = 10
        checkpoint_steps = []
        num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            # Note that upgrading to Linen made this fail at 0.6.
            self.assertLess(train_err, 0.7)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
        prev_train_err = train_err

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            # Note that upgrading to Linen made this fail at 0.45.
            self.assertLess(train_err, 0.67)
            self.assertLess(train_err, prev_train_err)
            # Note that upgrading to Linen made this fail at 0.9.
            self.assertLess(train_loss, 1.35)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size * _MAX_LEN)
            # Check that the correct learning rate was saved in the measurements file.
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

        self.assertEqual(set(df.columns.values), set(get_column_names()))