Example #1
  def test_initialize_rescale(self):
    """Test rescaling a single layer of a model."""
    input_shape = (28, 28, 1)
    output_shape = (10,)
    model_str = 'fully_connected'
    model_cls = models.get_model(model_str)
    model_hps = models.get_model_hparams(model_str)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    hps = copy.copy(model_hps)
    hps.update({'output_shape': output_shape})
    rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss_name, metrics_name)
    initializer = initializers.get_initializer('noop')

    rng, init_rng = jax.random.split(rng)

    # First initialize with no rescale.
    flax_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    utils.log_pytree_shape_and_statistics(flax_module.params)
    # Now rescale a layer by 100.
    rescale_factor = 100
    hps.layer_rescale_factors = {
        '/Dense_1/kernel': rescale_factor,
    }

    rescaled_module, _ = trainer.initialize(
        model.flax_module_def,
        initializer,
        model.loss_fn,
        input_shape,
        output_shape,
        hps,
        init_rng,
        metrics_logger=None)

    # Check the right variable is rescaled
    v1 = flax_module.params['Dense_1']['kernel']
    v2 = rescaled_module.params['Dense_1']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) * rescale_factor - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)

    # Check that other variables are the same
    v1 = flax_module.params['Dense_2']['kernel']
    v2 = rescaled_module.params['Dense_2']['kernel']
    diff = np.linalg.norm(v1.reshape(-1) - v2.reshape(-1))
    self.assertAlmostEqual(diff, 0.0)
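
Note that layer_rescale_factors keys parameters by their '/'-joined path in the param pytree. A minimal sketch of what such a rescale pass could look like (rescale_params is a hypothetical helper written here for illustration, not the library's implementation):

import flax

def rescale_params(params, layer_rescale_factors):
  """Multiply selected leaves of a nested param dict by their factor."""
  flat = flax.traverse_util.flatten_dict(params, sep='/')  # 'Dense_1/kernel'
  flat = {
      k: v * layer_rescale_factors.get('/' + k, 1.0)
      for k, v in flat.items()
  }
  return flax.traverse_util.unflatten_dict(flat, sep='/')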
Example #2
def main(unused_argv):
    # Necessary to use the tfds loader.
    tf.enable_v2_behavior()

    if jax.process_count() > 1:
        # TODO(ankugarg): Add support for multihost inference.
        raise NotImplementedError(
            'BLEU eval does not support multihost inference.')

    rng = jax.random.PRNGKey(FLAGS.seed)

    mt_eval_config = json.loads(FLAGS.mt_eval_config)

    if FLAGS.experiment_config_filename:
        with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
            experiment_config = json.load(f)
        if jax.process_index() == 0:
            logging.info('experiment_config: %r', experiment_config)
        dataset_name = experiment_config['dataset']
        model_name = experiment_config['model']
    else:
        assert FLAGS.dataset and FLAGS.model
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model

    if jax.process_index() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        logging.info('num_hosts : %d', jax.process_count())
        logging.info('host_id : %d', jax.process_index())

    model_class = models.get_model(model_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

    hparam_overrides = None
    if FLAGS.hparam_overrides:
        hparam_overrides = FLAGS.hparam_overrides
        # The flag may arrive as a JSON string or as an already-parsed dict.
        if isinstance(hparam_overrides, str):
            hparam_overrides = json.loads(hparam_overrides)

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=experiment_config['initializer'],
        dataset_name=dataset_name,
        hparam_file=FLAGS.trial_hparams_filename,
        hparam_overrides=hparam_overrides)

    if jax.process_index() == 0:
        logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

    evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                             rng, model_class, dataset_builder,
                                             dataset_meta_data, mt_eval_config)
    evaluator.translate_and_calculate_bleu()
Example #3
def main(unused_argv):
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()


  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    hparam_overrides = FLAGS.hparam_overrides
    # The flag may arrive as a JSON string or as an already-parsed dict.
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)
  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
Example #4
    def test_nqm(self):
        """Test the noisy quadratic model."""
        batch_size = 2
        dim = 10
        model_hps = config_dict.ConfigDict(
            dict(
                input_shape=(dim, ),
                output_shape=(1, ),
                rng_seed=-1,
                hessian_decay_power=1.0,
                noise_decay_power=1.0,
                nqm_mode='diagH_diagC',
                model_dtype='float32',
            ))

        model_cls = models.get_model('nqm')
        rng = jax.random.PRNGKey(0)
        model = model_cls(model_hps, {})
        noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
        rng, params_rng = jax.random.split(rng)
        _, flax_module = model.flax_module_def.create_by_shape(
            params_rng, [(batch_size, dim)])

        model_x = flax_module.params['x']

        def loss(model, inputs):
            return model(inputs)

        grad_loss = jax.grad(loss)

        hessian = np.diag(
            np.array([
                1.0 / np.power(i, model_hps.hessian_decay_power)
                for i in range(1, dim + 1)
            ]))
        noise_matrix = np.diag(
            np.array([
                1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
                for i in range(1, dim + 1)
            ]))

        noise = jnp.dot(noise_eps, noise_matrix)
        mean_noise = np.mean(noise, axis=0)

        # NQM gradient = Hx + eps   where eps ~ N(0, C / batch_size).
        expected_grad = np.dot(hessian, model_x) + mean_noise

        g = grad_loss(flax_module, noise_eps).params['x']

        grad_error = np.sum(np.abs(g - expected_grad))
        self.assertAlmostEqual(grad_error, 0.0, places=5)
Example #5
  def test_nqm(self):
    """Test the noisy quadratic model."""
    batch_size = 2
    dim = 10
    model_hps = config_dict.ConfigDict(
        dict(
            input_shape=(dim,),
            output_shape=(1,),
            rng_seed=-1,
            hessian_decay_power=1.0,
            noise_decay_power=1.0,
            nqm_mode='diagH_diagC',
            model_dtype='float32',
        ))

    model_cls = models.get_model('nqm')
    params_rng = jax.random.PRNGKey(0)
    model = model_cls(model_hps, {}, None, None)
    noise_eps = jnp.array(np.random.normal(size=(batch_size, dim)))
    xs = np.zeros((batch_size, dim))
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    params = model_init_fn({'params': params_rng}, xs)['params']
    model_x = params['x']

    def loss(params, inputs):
      return model.training_cost(params, batch=inputs)

    grad_loss = jax.grad(loss, has_aux=True)

    hessian = np.diag(
        np.array([
            1.0 / np.power(i, model_hps.hessian_decay_power)
            for i in range(1, dim + 1)
        ]))
    noise_matrix = np.diag(
        np.array([
            1.0 / np.power(i, model_hps.noise_decay_power / 2.0)
            for i in range(1, dim + 1)
        ]))

    noise = jnp.dot(noise_eps, noise_matrix)
    mean_noise = np.mean(noise, axis=0)

    # NQM gradient = Hx + eps   where eps ~ N(0, C / batch_size).
    expected_grad = np.dot(hessian, model_x) + mean_noise

    g = grad_loss(params, {'inputs': noise_eps})[0]['x']

    grad_error = np.sum(np.abs(g - expected_grad))
    self.assertAlmostEqual(grad_error, 0.0, places=5)
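
For reference, both variants of this test assume the standard noisy quadratic model (the formulation below is the usual one; the snippet itself only states the gradient in a comment):

    L(x) = \tfrac{1}{2} x^\top H x,
    \qquad
    \hat{g}(x) = H x + \bar\varepsilon,
    \qquad
    \bar\varepsilon = \frac{1}{B} \sum_{b=1}^{B} C^{1/2} \epsilon_b,
    \quad \epsilon_b \sim \mathcal{N}(0, I),

so the minibatch gradient is distributed as N(Hx, C/B) for batch size B. With H = diag(i^-hessian_decay_power) and C = diag(i^-noise_decay_power), the noise_matrix above is C^(1/2) and mean_noise is the averaged noise term, which is why expected_grad is computed as Hx + mean_noise.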
Example #6
  def setUp(self):
    super(CheckpointTest, self).setUp()
    self.test_dir = tempfile.mkdtemp()
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_cls = models.get_model('fully_connected')
    model_hps = models.get_model_hparams('fully_connected')
    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss_name, metrics_name)
    xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
    rng, params_rng = jax.random.split(rng)
    _, self.flax_module = model.flax_module_def.create(params_rng, xs)
Example #7
  def setUp(self):
    super(CheckpointTest, self).setUp()
    self.test_dir = tempfile.mkdtemp()
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_cls = models.get_model('fully_connected')
    model_hps = models.get_model_hparams('fully_connected')
    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss_name, metrics_name)
    xs = jnp.array(np.random.normal(size=INPUT_SHAPE))
    rng, params_rng = jax.random.split(rng)
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs)
    self.params = init_dict['params']
Example #8
def _load_model(model_name):
    """Load a test model."""
    rng = jax.random.PRNGKey(0)
    model_cls = models.get_model(model_name)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_hps = models.get_model_hparams(model_name)

    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    model = model_cls(hps, {}, loss_name, metrics_name)

    input_shape = (BATCH_SIZE, ) + MODEL_TO_INPUT_SHAPE[model_name]
    _, flax_module = model.flax_module_def.create_by_shape(rng, [input_shape],
                                                           train=True)
    utils.log_pytree_shape_and_statistics(flax_module.params)
    return flax_module, input_shape
Example #9
  def test_graph_model(self):
    """Test forward pass of the GNN model."""
    edge_input_shape = (5,)
    node_input_shape = (5,)
    output_shape = (5,)
    model_str = 'gnn'
    model_hps = models.get_model_hparams(model_str)
    model_hps.update({'output_shape': output_shape,
                      'latent_dim': 10,
                      'hidden_dims': (10,),
                      'batch_size': 5,
                      'normalizer': 'batch_norm'})
    model_cls = models.get_model(model_str)
    rng = jax.random.PRNGKey(0)
    dropout_rng, params_rng = jax.random.split(rng)
    loss = 'sigmoid_binary_cross_entropy'
    metrics = 'binary_classification_metrics'
    model = model_cls(model_hps, {}, loss, metrics)

    num_graphs = 5
    node_per_graph = 3
    edge_per_graph = 9
    inputs = jraph.get_fully_connected_graph(
        n_node_per_graph=node_per_graph,
        n_graph=num_graphs,
        node_features=np.ones((num_graphs * node_per_graph,) +
                              node_input_shape),
    )
    inputs = inputs._replace(
        edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape))
    padded_inputs = jraph.pad_with_graphs(inputs, 20, 50, 7)
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, padded_inputs)
    params = init_dict['params']
    batch_stats = init_dict['batch_stats']

    # Check that the forward pass works with mutated batch_stats.
    outputs, _ = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        padded_inputs,
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)
    self.assertEqual(outputs.shape, (7,) + output_shape)
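
A note on the shape assertion: jraph.pad_with_graphs(inputs, 20, 50, 7) pads the 5 real graphs (15 nodes and 45 edges in total) with dummy nodes, edges, and graphs up to those totals, so the per-graph output has leading dimension 7. If the padding needs to be masked out afterwards, jraph exposes a helper for that (a small follow-up sketch, not part of the original test):

    real = jraph.get_graph_padding_mask(padded_inputs)  # True for the 5 real graphs
    real_outputs = outputs[real]  # drops the 2 padding graphs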
Example #10
def _load_model(model_name):
    """Load a test model."""
    rng = jax.random.PRNGKey(0)
    model_cls = models.get_model(model_name)
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    model_hps = models.get_model_hparams(model_name)

    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE})
    model = model_cls(hps, {}, loss_name, metrics_name)

    input_shape = (BATCH_SIZE, ) + MODEL_TO_INPUT_SHAPE[model_name]
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=True))
    init_dict = model_init_fn({'params': rng}, jnp.zeros(input_shape))
    # Trainable model parameters.
    params = init_dict['params']
    utils.log_pytree_shape_and_statistics(params)
    return model.flax_module, params, input_shape, hps
Example #11
  def test_classification_model(self, model_str):
    """Test forward pass of the image models."""

    model_cls = models.get_model(model_str)
    model_hps = models.get_model_hparams(model_str)
    loss = 'cross_entropy'
    metrics = 'classification_metrics'
    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE['classification']})
    rng = jax.random.PRNGKey(0)
    dropout_rng, params_rng = jax.random.split(rng)
    model = model_cls(hps, {}, loss, metrics)
    xs = jnp.array(np.random.normal(size=INPUT_SHAPE['classification']))
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs)
    params = init_dict['params']
    batch_stats = init_dict.get('batch_stats', {})

    # Check that the forward pass works with mutated batch_stats.
    outputs, new_batch_stats = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        xs,
        mutable=['batch_stats'],
        rngs={'dropout': dropout_rng},
        train=True)
    self.assertEqual(outputs.shape, (INPUT_SHAPE['classification'][0],
                                     OUTPUT_SHAPE['classification'][-1]))

    # If it's a batch norm model check the batch stats changed.
    if batch_stats:
      bflat, _ = ravel_pytree(batch_stats)
      new_bflat, _ = ravel_pytree(new_batch_stats)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Test batch_norm in inference mode.
    outputs = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats}, xs, train=False)
    self.assertEqual(
        outputs.shape,
        (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1]))
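
The init/apply pattern these Linen tests rely on generalizes beyond this codebase. A self-contained sketch (TinyNet is a hypothetical stand-in for the image models above):

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyNet(nn.Module):

  @nn.compact
  def __call__(self, x, train=True):
    x = nn.Dense(8)(x)
    # BatchNorm stores running statistics in the 'batch_stats' collection.
    x = nn.BatchNorm(use_running_average=not train)(x)
    return nn.Dense(2)(x)


rng = jax.random.PRNGKey(0)
xs = jnp.zeros((4, 16))
variables = TinyNet().init({'params': rng}, xs, train=False)
params, batch_stats = variables['params'], variables['batch_stats']
# train=True updates the running statistics, so the collection must be
# declared mutable; apply then returns (outputs, mutated_collections).
outputs, mutated = TinyNet().apply(
    {'params': params, 'batch_stats': batch_stats}, xs,
    train=True, mutable=['batch_stats'])
new_batch_stats = mutated['batch_stats']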
Example #12
  def test_autoencoder_model(self, model_str):
    """Test forward pass of the autoencoder models."""

    model_cls = models.get_model(model_str)
    model_hps = models.get_model_hparams(model_str)
    loss = 'sigmoid_binary_cross_entropy'
    metrics = 'binary_autoencoder_metrics'
    hps = copy.copy(model_hps)
    hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
    params_rng = jax.random.PRNGKey(0)
    model = model_cls(hps, {}, loss, metrics)
    xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs)
    params = init_dict['params']
    batch_stats = init_dict.get('batch_stats', {})

    # Check that the forward pass works with mutated batch_stats.
    outputs, new_batch_stats = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats},
        xs,
        mutable=['batch_stats'],
        train=True)
    self.assertEqual(
        outputs.shape,
        tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))

    # If it's a batch norm model check the batch stats changed.
    if batch_stats:
      bflat, _ = ravel_pytree(batch_stats)
      new_bflat, _ = ravel_pytree(new_batch_stats)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Test batch_norm in inference mode.
    outputs = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats}, xs, train=False)
    self.assertEqual(
        outputs.shape,
        tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))
Example #13
    def test_autoencoder_model(self, model_str):
        """Test forward pass of the autoencoder models."""

        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)
        loss = 'sigmoid_binary_cross_entropy'
        metrics = 'binary_autoencoder_metrics'
        hps = copy.copy(model_hps)
        hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
        rng = jax.random.PRNGKey(0)
        model = model_cls(hps, {}, loss, metrics)
        xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
        rng, params_rng = jax.random.split(rng)
        with nn.stateful() as batch_stats:
            with nn.stochastic(params_rng):
                _, flax_module = model.flax_module_def.create(params_rng, xs)

        # Check that the forward pass works with mutated batch_stats.
        with nn.stateful(batch_stats) as new_batch_stats:
            with nn.stochastic(params_rng):
                outputs = flax_module(xs)
                self.assertEqual(
                    outputs.shape,
                    tuple([INPUT_SHAPE[model_str][0]] +
                          list(OUTPUT_SHAPE[model_str])))

        # If it's a batch norm model check the batch stats changed.
        if batch_stats.as_dict():
            bflat, _ = ravel_pytree(batch_stats)
            new_bflat, _ = ravel_pytree(new_batch_stats)
            self.assertFalse(jnp.array_equal(bflat, new_bflat))

        # Test batch_norm in inference mode.
        with nn.stateful(batch_stats, mutable=False):
            outputs = flax_module(xs, train=False)
        self.assertEqual(
            outputs.shape,
            tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))
Example #14
    def test_text_model_trainer(self):
        """Test training of a small transformer model on fake data."""
        rng = jax.random.PRNGKey(42)

        # Set the numpy seed to make the fake data deterministic. mocking.mock_data
        # ultimately calls numpy.random.
        np.random.seed(0)

        model_cls = models.get_model('transformer')
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': _TEXT_BATCH_SIZE,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'input_shape': (64, ),
            'output_shape': (_VOCAB_SIZE, ),
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'layer_rescale_factors': {},
            'optimizer': 'momentum',
            'normalizer': 'layer_norm',
            'opt_hparams': {
                'momentum': 0.9,
            },
            'lr_hparams': {
                'base_lr': 0.005,
                'schedule': 'constant'
            },
            # Training HParams.
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'train_size': _TEXT_TRAIN_SIZE,
            'gradient_clipping': 0.0,
            'model_dtype': 'float32',
            'decode': False,
            'num_device_prefetches': 0,
        })
        initializer = initializers.get_initializer('noop')
        eval_num_batches = 5
        dataset, dataset_meta_data = _get_fake_text_dataset(
            batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
        eval_batch_size = hps.batch_size

        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

        eval_every = 10
        checkpoint_steps = []
        num_train_steps = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 3

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            # Note that upgrading to Linen made this fail at 0.6.
            self.assertLess(train_err, 0.7)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
        prev_train_err = train_err

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = _TEXT_TRAIN_SIZE // _TEXT_BATCH_SIZE * 6
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            # Note that upgrading to Linen made this fail at 0.45.
            self.assertLess(train_err, 0.67)
            self.assertLess(train_err, prev_train_err)
            # Note that upgrading to Linen made this fail at 0.9.
            self.assertLess(train_loss, 1.35)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size * _MAX_LEN)
            # Check that the correct learning rate was saved in the measurements file.
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
Example #15
  def test_shampoo_wrn(self):
    """Test distributed shampoo on fake dataset."""
    model_name = 'simple_cnn'
    model_cls = models.get_model(model_name)
    hparam_overrides = {
        'optimizer': 'distributed_shampoo',
        'batch_size': 1,
        'train_size': 10,
        'valid_size': 10,
        'input_shape': (32, 32, 3),
        'output_shape': (10,),
        'opt_hparams': {
            'block_size': 32,
            'beta1': 0.9,
            'beta2': 0.999,
            'diagonal_epsilon': 1e-10,
            'matrix_epsilon': 1e-6,
            'weight_decay': 0.0,
            'start_preconditioning_step': 5,
            'preconditioning_compute_steps': 1,
            'statistics_compute_steps': 1,
            'best_effort_shape_interpretation': True,
            'graft_type': distributed_shampoo.GraftingType.SGD,
            'nesterov': True,
            'exponent_override': 0,
            'batch_axis_name': 'batch',
            'num_devices_for_pjit': None,
            'shard_optimizer_states': False,
            'inverse_failure_threshold': 0.1,
            'clip_by_scaled_gradient_norm': None,
            'precision': lax.Precision.HIGHEST,
            'moving_average_for_momentum': False,
            'skip_preconditioning_dim_size_gt': 4096,
            'best_effort_memory_usage_reduction': False,
        },
    }
    input_pipeline_hps = config_dict.ConfigDict(dict(
        num_tf_data_prefetches=-1,
        num_device_prefetches=0,
        num_tf_data_map_parallel_calls=-1,
    ))
    hps = hyperparameters.build_hparams(
        model_name,
        initializer_name='noop',
        dataset_name='fake',
        hparam_file=None,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)
    initializer = initializers.get_initializer('noop')
    dataset_builder = datasets.get_dataset('fake')
    dataset = dataset_builder(
        shuffle_rng=jax.random.PRNGKey(0),
        batch_size=hps.batch_size,
        eval_batch_size=hps.batch_size,
        hps=hps)

    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    dataset_meta_data = datasets.get_dataset_meta_data('fake')
    model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

    metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
    _ = list(
        trainer.train(
            train_dir=self.test_dir,
            model=model,
            dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
            initializer=initializer,
            num_train_steps=1,
            hps=hps,
            rng=jax.random.PRNGKey(42),
            eval_batch_size=hps.batch_size,
            eval_num_batches=None,
            eval_train_num_batches=None,
            eval_frequency=10,
            checkpoint_steps=[],
            metrics_logger=metrics_logger,
            init_logger=init_logger))

    with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                        'measurements.csv')) as f:
      df = pandas.read_csv(f)
      valid_ce_loss = df['valid/ce_loss'].values[-1]
      self.assertLess(valid_ce_loss, 1e-3)
Example #16
    def test_cg_backtracking(self):
        """Tests CG backtracking."""

        model_str = 'autoencoder'
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)

        loss = 'sigmoid_binary_cross_entropy'
        metrics = 'binary_autoencoder_metrics'

        input_shape = (2, 2, 1)
        output_shape = (4, )

        hps = copy.copy(model_hps)
        hps.update({
            'optimizer': 'hessian_free',
            'opt_hparams': {
                'weight_decay': 0.0,
            },
            'hid_sizes': [2],
            'activation_function': ['id'],
            'input_shape': input_shape,
            'output_shape': output_shape
        })

        model = model_cls(hps, {}, loss, metrics)

        inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
        targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
        batch = {'inputs': inputs, 'targets': targets}

        def forward_fn(variables, inputs):
            return model.flax_module.apply(variables, inputs, train=False)

        def opt_cost(params):
            return model.loss_fn(forward_fn(params, inputs), targets)

        params = {
            'Dense_0': {
                'kernel': jnp.array(
                    [[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
                'bias': jnp.array([0., 0.])
            },
            'Dense_1': {
                'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
                'bias': jnp.array([0., 0., 0., 0.])
            }
        }
        unravel_fn = ravel_pytree(params)[1]

        p1 = np.array([
            0.5, 0.2, 0.1, -0.4, -0.6, 0.4, 0.6, -0.7, 0.0, 0.5, -0.7, 0.2,
            0.1, -0.2, 0.4, -0.6, -0.8, 0.7, 0.2, 0.9, -0.1, 0.5
        ])
        p2 = np.array([
            0.3, -0.1, -0.5, 0.2, -0.4, 0.8, -0.2, 0.0, 0.2, -0.4, 0.6, -0.2,
            -0.4, 0.2, 0.3, 0.2, -0.2, -0.4, -0.5, 0.2, 0.2, -0.4
        ])

        p_arr = jnp.array([p1, p2])
        p_arr_idx = 1

        partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
        partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

        def obj_fn(variables):
            return partial_loss_fn(partial_forward_fn(variables))

        flattened_p, obj_val = cg_backtracking(p_arr, p_arr_idx, obj_fn,
                                               {'params': params}, unravel_fn)

        # Test the backtracking function.
        self.assertSameElements(flattened_p, p1)
        updated_params = apply_updates(params, unravel_fn(p1))
        self.assertEqual(opt_cost({'params': updated_params}), obj_val)
Example #17
  def test_translate_model(self):
    """Test forward pass of the translate model."""
    vocab_size = 16
    small_hps = config_dict.ConfigDict({
        # Architecture Hparams.
        'batch_size': 16,
        'share_embeddings': False,
        'logits_via_embedding': False,
        'emb_dim': 32,
        'num_heads': 2,
        'enc_num_layers': 2,
        'dec_num_layers': 2,
        'qkv_dim': 32,
        'label_smoothing': 0.1,
        'mlp_dim': 64,
        'max_target_length': 64,
        'max_eval_target_length': 64,
        'normalizer': 'pre_layer_norm',
        'max_predict_length': 64,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'momentum': 0.9,
        'lr_hparams': {
            'base_lr': 0.005,
            'schedule': 'constant'
        },
        'output_shape': (vocab_size,),
        # Training HParams.
        'l2_decay_factor': 1e-4,
        'enc_self_attn_kernel_init': 'xavier_uniform',
        'dec_self_attn_kernel_init': 'xavier_uniform',
        'dec_cross_attn_kernel_init': 'xavier_uniform',
        'decode': False,
    })
    text_src_input_shape = (32, 64)  # batch_size, max_source_length
    text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
    model_cls = models.get_model('xformer_translate')
    rng = jax.random.PRNGKey(0)
    loss = 'cross_entropy'
    metrics = 'classification_metrics'
    model = model_cls(small_hps, {
        'shift_outputs': True,
        'causal': True
    }, loss, metrics)
    xs = jnp.array(
        np.random.randint(size=text_src_input_shape, low=1, high=vocab_size))
    ys = jnp.array(
        np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size))
    dropout_rng, params_rng = jax.random.split(rng)
    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs, ys)
    params = init_dict['params']

    # Test forward pass.
    @jax.jit
    def forward_pass(params, xs, ys, dropout_rng):
      outputs = model.flax_module.apply(
          {'params': params},
          xs,
          ys,
          rngs={'dropout': dropout_rng},
          train=True)
      return outputs

    logits = forward_pass(params, xs, ys, dropout_rng)
    # Testing only train mode
    # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
    self.assertEqual(
        logits.shape,
        (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
Example #18
  def test_text_models(self, model_str):
    """Test forward pass of the transformer model."""

    # TODO(gilmer): Find a clean way to handle small test hparams.
    vocab_size = 16

    small_hps = config_dict.ConfigDict({
        # Architecture Hparams.
        'batch_size': 16,
        'emb_dim': 32,
        'num_heads': 2,
        'num_layers': 3,
        'qkv_dim': 32,
        'label_smoothing': 0.1,
        'mlp_dim': 64,
        'max_target_length': 64,
        'max_eval_target_length': 64,
        'dropout_rate': 0.1,
        'attention_dropout_rate': 0.1,
        'momentum': 0.9,
        'normalizer': 'layer_norm',
        'lr_hparams': {
            'base_lr': 0.005,
            'schedule': 'constant'
        },
        'output_shape': (vocab_size,),
        'model_dtype': 'float32',
        # Training HParams.
        'l2_decay_factor': 1e-4,
        'decode': False,
    })

    text_input_shape = (32, 64)  # batch_size, max_target_length
    model_cls = models.get_model(model_str)
    rng = jax.random.PRNGKey(0)
    loss = 'cross_entropy'
    metrics = 'classification_metrics'
    model = model_cls(small_hps, {
        'max_len': 64,
        'shift_inputs': True,
        'causal': True
    }, loss, metrics)
    xs = jnp.array(
        np.random.randint(size=text_input_shape, low=1, high=vocab_size))
    dropout_rng, params_rng = jax.random.split(rng)

    model_init_fn = jax.jit(
        functools.partial(model.flax_module.init, train=False))
    init_dict = model_init_fn({'params': params_rng}, xs)
    params = init_dict['params']
    batch_stats = init_dict.get('batch_stats', {})

    # Check that the forward pass works with mutated batch_stats.
    # Due to a bug in flax, this jit is required, otherwise the model errors.
    @jax.jit
    def forward_pass(params, xs, dropout_rng):
      outputs, new_batch_stats = model.flax_module.apply(
          {'params': params, 'batch_stats': batch_stats},
          xs,
          mutable=['batch_stats'],
          rngs={'dropout': dropout_rng},
          train=True)
      return outputs, new_batch_stats

    outputs, new_batch_stats = forward_pass(params, xs, dropout_rng)
    self.assertEqual(outputs.shape,
                     (text_input_shape[0], text_input_shape[1], vocab_size))

    # If it's a batch norm model check the batch stats changed.
    if batch_stats:
      bflat, _ = ravel_pytree(batch_stats)
      new_bflat, _ = ravel_pytree(new_batch_stats)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Test batch_norm in inference mode.
    outputs = model.flax_module.apply(
        {'params': params, 'batch_stats': batch_stats}, xs, train=False)
    self.assertEqual(outputs.shape,
                     (text_input_shape[0], text_input_shape[1], vocab_size))
Example #19
    def test_text_models(self, model_str):
        """Test forward pass of the transformer model."""

        # TODO(gilmer): Find a clean way to handle small test hparams.
        vocab_size = 16

        small_hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': 16,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'label_smoothing': 0.1,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'momentum': 0.9,
            'normalizer': 'layer_norm',
            'lr_hparams': {
                'initial_value': 0.005,
                'schedule': 'constant'
            },
            'output_shape': (vocab_size, ),
            # Training HParams.
            'l2_decay_factor': 1e-4
        })

        text_input_shape = (32, 64)  # batch_size, max_target_length
        model_cls = models.get_model(model_str)
        rng = jax.random.PRNGKey(0)
        loss = 'cross_entropy'
        metrics = 'classification_metrics'
        model = model_cls(small_hps, {
            'max_len': 64,
            'shift_inputs': True,
            'causal': True
        }, loss, metrics)
        xs = jnp.array(
            np.random.randint(size=text_input_shape, low=1, high=vocab_size))
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)

        with nn.stateful() as batch_stats:
            _, flax_module = model.flax_module_def.create_by_shape(
                params_rng, [(text_input_shape, jnp.float32)], train=False)

        # Check that the forward pass works with mutated batch_stats.
        # Due to a bug in flax, this jit is required, otherwise the model errors.
        @jax.jit
        def forward_pass(flax_module, xs, dropout_rng):
            with batch_stats.mutate() as new_batch_stats:
                with nn.stochastic(dropout_rng):
                    return flax_module(xs, train=True), new_batch_stats

        outputs, new_batch_stats = forward_pass(flax_module, xs, dropout_rng)
        self.assertEqual(
            outputs.shape,
            (text_input_shape[0], text_input_shape[1], vocab_size))

        # If it's a batch norm model check the batch stats changed.
        if batch_stats.as_dict():
            bflat, _ = ravel_pytree(batch_stats)
            new_bflat, _ = ravel_pytree(new_batch_stats)
            self.assertFalse(jnp.array_equal(bflat, new_bflat))

        # Test batch_norm in inference mode.
        with nn.stateful(batch_stats, mutable=False):
            outputs = flax_module(xs, train=False)
        self.assertEqual(
            outputs.shape,
            (text_input_shape[0], text_input_shape[1], vocab_size))
Example #20
    def test_hessian_free_optimizer(self):
        """Tests the Hessian-free optimizer."""

        model_str = 'autoencoder'
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)

        loss = 'sigmoid_binary_cross_entropy'
        metrics = 'binary_autoencoder_metrics'

        input_shape = (2, 2, 1)
        output_shape = (4, )

        hps = copy.copy(model_hps)
        hps.update({
            'optimizer': 'hessian_free',
            'opt_hparams': {
                'weight_decay': 0.0,
            },
            'hid_sizes': [2],
            'activation_function': ['id'],
            'input_shape': input_shape,
            'output_shape': output_shape
        })

        model = model_cls(hps, {}, loss, metrics)

        inputs = jnp.array([[[1, 0], [1, 1]], [[1, 0], [0, 1]]])
        targets = inputs.reshape(tuple([inputs.shape[0]] + list(output_shape)))
        batch = {'inputs': inputs, 'targets': targets}

        def forward_fn(variables, inputs):
            logits = model.flax_module.apply(variables, inputs, train=True)
            return logits

        def opt_cost(variables):
            return model.loss_fn(forward_fn(variables, inputs), targets)

        init_fn, update_fn = optimizers.get_optimizer(hps, model)

        params = {
            'Dense_0': {
                'kernel': jnp.array(
                    [[-1., 2.], [2., 0.], [-1., 3.], [-2., 2.]]),
                'bias': jnp.array([0., 0.])
            },
            'Dense_1': {
                'kernel': jnp.array([[4., 2., -2., 4.], [-3., 1., 2., -4.]]),
                'bias': jnp.array([0., 0., 0., 0.])
            }
        }
        variables = {'params': params}

        grad_fn = jax.grad(opt_cost)
        grads = grad_fn(variables)['params']

        outputs = forward_fn(variables, batch['inputs'])

        n = inputs.shape[0]
        m = outputs.shape[-1]
        d = ravel_pytree(params)[0].shape[0]

        v = np.ones(d)

        state = init_fn(params)

        partial_forward_fn = partial(forward_fn, inputs=batch['inputs'])
        partial_loss_fn = partial(model.loss_fn, targets=batch['targets'])

        matmul_fn = partial(gvp, variables, outputs, state.inner_state.damping,
                            partial_forward_fn, partial_loss_fn)

        jacobian = jax.jacfwd(partial_forward_fn)(variables)['params']
        jacobian_tensor = np.concatenate(
            (jacobian['Dense_0']['bias'].reshape(n, m, -1),
             jacobian['Dense_0']['kernel'].reshape(n, m, -1),
             jacobian['Dense_1']['bias'].reshape(n, m, -1),
             jacobian['Dense_1']['kernel'].reshape(n, m, -1)),
            axis=2)

        ggn_matrix = 0
        for i in range(n):
            jacobian_matrix = jacobian_tensor[i]
            hessian = jax.hessian(partial_loss_fn)(outputs[i, None])[0, :, 0, :]
            ggn_matrix += np.transpose(
                jacobian_matrix) @ hessian @ jacobian_matrix
        ggn_matrix /= n
        ggn_matrix += state.inner_state.damping * np.identity(d)

        expected = ggn_matrix @ v

        # Test the gvp function
        self.assertAlmostEqual(jnp.linalg.norm(matmul_fn(v) - expected),
                               0,
                               places=4)

        update_pmapped = jax.pmap(update_fn,
                                  axis_name='batch',
                                  in_axes=(None, None, None, 0, None))

        batch_shard = data_utils.shard(batch)

        state.hyperparams['learning_rate'] = 1.0

        p, state = update_pmapped(grads, state, params, batch_shard, None)

        # Test the damping parameter update
        self.assertEqual(state.inner_state.damping, 3 / 2)

        # Test the search direction
        self.assertAlmostEqual(
            jnp.linalg.norm(
                ravel_pytree(p)[0] +
                jnp.linalg.inv(ggn_matrix) @ ravel_pytree(grads)[0]),
            0,
            places=4)
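
The matrix assembled by hand here is the damped Gauss-Newton matrix, and the check relies on the standard identity (assumed from the usual Gauss-Newton formulation; the snippet never writes it out):

    G v = \frac{1}{n} \sum_{i=1}^{n} J_i^\top H_i J_i v + \lambda v,

where J_i is the Jacobian of the network outputs with respect to the parameters on example i, H_i is the Hessian of the loss with respect to those outputs, and \lambda is state.inner_state.damping. The gvp function presumably computes this product matrix-free (the usual JVP through the network, loss Hessian-vector product, then VJP back), which is what the dense ggn_matrix @ v comparison validates.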
Example #21
    def test_accumulation(self):
        """Test simple gradient accumulation."""
        num_steps = 3
        per_step_batch_size = 16
        total_batch_size = 48
        virtual_batch_size = 8
        model_str = 'wide_resnet'  # Pick a model with batch norm.
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)
        dataset_name = 'cifar10'
        dataset_builder = datasets.get_dataset(dataset_name)
        hps = copy.copy(model_hps)
        hps.update(datasets.get_dataset_hparams(dataset_name))

        # Compute updates using gradient accumulation.
        hps.update({
            'batch_size': per_step_batch_size,
            'virtual_batch_size': virtual_batch_size,
            'normalizer': 'virtual_batch_norm',
            'total_accumulated_batch_size': total_batch_size,
        })
        grad_acc_params, grad_acc_batch_stats, grad_acc_training_cost = _init_model(
            model_cls, hps)
        total_dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(1),
                                        batch_size=total_batch_size,
                                        eval_batch_size=10,
                                        hps=hps)
        # Ensure we see the same exact batches.
        train_iter = total_dataset.train_iterator_fn()
        train_iter = itertools.islice(train_iter, 0, num_steps)
        train_iter = itertools.cycle(train_iter)

        def grad_acc_train_iter():
            for _ in range(num_steps):
                total_batch = next(train_iter)
                # Split each total batch into sub batches.
                num_sub_batches = total_batch_size // per_step_batch_size
                start_index = 0
                end_index = int(total_batch_size / num_sub_batches)
                for bi in range(num_sub_batches):
                    yield jax.tree_map(lambda x: x[start_index:end_index],
                                       total_batch)  # pylint: disable=cell-var-from-loop
                    start_index = end_index
                    end_index = int(total_batch_size * (bi + 2) /
                                    num_sub_batches)

        lrs = jnp.array([1.0, 0.1, 1e-2])
        sgd_opt_init, sgd_opt_update = optax.sgd(
            learning_rate=lambda t: lrs.at[t].get())
        opt_init, opt_update = gradient_accumulator.accumulate_gradients(
            per_step_batch_size=per_step_batch_size,
            total_batch_size=total_batch_size,
            virtual_batch_size=virtual_batch_size,
            base_opt_init_fn=sgd_opt_init,
            base_opt_update_fn=sgd_opt_update)
        grad_acc_params, grad_acc_batch_stats = _optimize(
            # Run for 3x the number of steps to see the same number of examples.
            num_steps=3 * num_steps,
            params=grad_acc_params,
            batch_stats=grad_acc_batch_stats,
            training_cost=grad_acc_training_cost,
            train_iter=grad_acc_train_iter(),
            opt_init=opt_init,
            opt_update=opt_update)

        # Compute the same updates, but without gradient accumulation.
        hps.update({
            'batch_size': total_batch_size,
            'total_accumulated_batch_size': None,
        })
        params, batch_stats, training_cost = _init_model(model_cls, hps)
        params, batch_stats = _optimize(num_steps=num_steps,
                                        params=params,
                                        batch_stats=batch_stats,
                                        training_cost=training_cost,
                                        train_iter=train_iter,
                                        opt_init=sgd_opt_init,
                                        opt_update=sgd_opt_update)

        diffs_params = jax.tree_multimap(lambda a, b: jnp.mean(jnp.abs(a - b)),
                                         grad_acc_params, params)

        def batch_stats_reduce(a, b):
            if len(a.shape) > 0:  # pylint: disable=g-explicit-length-test
                return jnp.mean(
                    jnp.abs(jnp.mean(a, axis=0) - jnp.mean(b, axis=0)))
            # The gradient accumulator counters are scalars.
            return a - b

        diffs_batch_stats = jax.tree_multimap(batch_stats_reduce,
                                              grad_acc_batch_stats,
                                              batch_stats)
        # We sometimes get small floating point errors in the gradients, so we
        # cannot test for the values being exactly the same.
        acceptable_params_diff = 1e-4
        acceptable_batch_stats_diff = 5e-3

        def check_closeness(root_name, d, max_diff):
            not_close_dict = {}
            for name, dd in d.items():
                new_name = root_name + '/' + name if root_name else name
                if isinstance(dd, (dict, core.FrozenDict)):
                    not_close_dict.update(
                        check_closeness(new_name, dd, max_diff))
                else:
                    if dd > max_diff:
                        not_close_dict[new_name] = dd
            return not_close_dict

        not_close_params = check_closeness('', diffs_params,
                                           acceptable_params_diff)
        self.assertEmpty(not_close_params)
        not_close_batch_stats = check_closeness('', diffs_batch_stats,
                                                acceptable_batch_stats_diff)
        # Note that for the variance variables in the batch stats collection, they
        # sometimes can start to diverge slightly over time (with a higher number of
        # training steps), likely due to numerical issues.
        self.assertEmpty(not_close_batch_stats)
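
For comparison, optax ships a generic wrapper with a similar contract to the gradient_accumulator tested here: optax.MultiSteps accumulates the (by default averaged) gradients of k sub-batches and applies the base update once every k calls. A minimal sketch under that assumption (it does not replicate the virtual batch norm handling above):

import optax

base_opt = optax.sgd(learning_rate=0.1)
# 48 / 16 = 3 sub-batches per logical step, matching the test above.
accum_opt = optax.MultiSteps(base_opt, every_k_schedule=3)
# Given `params` and `grads` pytrees as in the surrounding test:
opt_state = accum_opt.init(params)
updates, opt_state = accum_opt.update(grads, opt_state, params)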
Example #22
    def test_translate_model(self):
        """Test forward pass of the translate model."""
        vocab_size = 16
        small_hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': 16,
            'share_embeddings': False,
            'logits_via_embedding': False,
            'emb_dim': 32,
            'num_heads': 2,
            'enc_num_layers': 2,
            'dec_num_layers': 2,
            'qkv_dim': 32,
            'label_smoothing': 0.1,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'normalizer': 'pre_layer_norm',
            'max_predict_length': 64,
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'momentum': 0.9,
            'lr_hparams': {
                'initial_value': 0.005,
                'schedule': 'constant'
            },
            'output_shape': (vocab_size, ),
            # Training HParams.
            'l2_decay_factor': 1e-4
        })
        text_src_input_shape = (32, 64)  # batch_size, max_source_length
        text_tgt_input_shape = (32, 40)  # batch_size, max_target_length
        model_cls = models.get_model('xformer_translate')
        rng = jax.random.PRNGKey(0)
        loss = 'cross_entropy'
        metrics = 'classification_metrics'
        model = model_cls(small_hps, {
            'shift_outputs': True,
            'causal': True
        }, loss, metrics)
        xs = jnp.array(
            np.random.randint(size=text_src_input_shape,
                              low=1,
                              high=vocab_size))
        ys = jnp.array(
            np.random.randint(size=text_tgt_input_shape,
                              low=1,
                              high=vocab_size))
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        with nn.stateful() as batch_stats:
            _, flax_module = model.flax_module_def.create_by_shape(
                params_rng, [(text_src_input_shape, jnp.float32),
                             (text_tgt_input_shape, jnp.float32)],
                train=False)

        # Test forward pass.
        @jax.jit
        def forward_pass(flax_module, xs, ys, dropout_rng):
            with batch_stats.mutate() as new_batch_stats:
                with nn.stochastic(dropout_rng):
                    return flax_module(xs, ys, train=True), new_batch_stats

        logits, _ = forward_pass(flax_module, xs, ys, dropout_rng)
        # Testing only train mode
        # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode).
        self.assertEqual(
            logits.shape,
            (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size))
Example #23
    def test_early_stopping(self):
        """Test training early stopping on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic. mocking.mock_data
        # ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return num_examples fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        early_stopping_target_name = 'test/ce_loss'
        early_stopping_target_value = 0.005
        early_stopping_mode = 'less'
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                early_stopping_target_name=early_stopping_target_name,
                early_stopping_target_value=early_stopping_target_value,
                early_stopping_mode=early_stopping_mode,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports, 3)
        self.assertGreater(epoch_reports[-2][early_stopping_target_name],
                           early_stopping_target_value)
        self.assertLess(epoch_reports[-1][early_stopping_target_name],
                        early_stopping_target_value)
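
The three epoch reports line up with the schedule above: evaluations run every 10 steps, and the monitored test/ce_loss first drops below the 0.005 target on the final report, at which point training halts. A hypothetical sketch of the stopping rule these assertions imply (not the trainer's actual code):

def should_stop(report, target_name, target_value, mode):
  """Stop once the monitored metric crosses the target in the given direction."""
  value = report[target_name]
  return value < target_value if mode == 'less' else value > target_value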
Example #24
    def test_text_model(self):
        """Test gradient accumulator training of a small transformer."""
        rng = jax.random.PRNGKey(42)

        # Set the numpy seed to make the fake data deterministic. mocking.mock_data
        # ultimately calls numpy.random.
        np.random.seed(0)

        model_cls = models.get_model('transformer')
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        batch_size = 16
        train_size = 20 * batch_size
        hps = config_dict.ConfigDict({
            # Architecture Hparams.
            'batch_size': batch_size,
            'emb_dim': 32,
            'num_heads': 2,
            'num_layers': 3,
            'qkv_dim': 32,
            'mlp_dim': 64,
            'max_target_length': 64,
            'max_eval_target_length': 64,
            'input_shape': (64, ),
            'output_shape': (4, ),
            'dropout_rate': 0.1,
            'attention_dropout_rate': 0.1,
            'layer_rescale_factors': {},
            'optimizer': 'momentum',
            'normalizer': 'layer_norm',
            'opt_hparams': {
                'momentum': 0.9,
            },
            'lr_hparams': {
                'base_lr': 0.005,
                'schedule': 'constant'
            },
            # Training HParams.
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'train_size': train_size,
            'gradient_clipping': 0.0,
            'model_dtype': 'float32',
            'decode': False,
        })
        initializer = initializers.get_initializer('noop')
        eval_num_batches = 5
        dataset, dataset_meta_data = _get_fake_text_dataset(
            batch_size=hps.batch_size, eval_num_batches=eval_num_batches)
        eval_batch_size = hps.batch_size

        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)

        eval_every = 10
        checkpoint_steps = []
        num_train_steps = train_size // batch_size * 3

        metrics_logger, init_logger = trainer.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            # Note that upgrading to Linen made this fail at 0.6.
            self.assertLess(train_err, 0.7)
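
The early-stopping assertions in the first test above imply a simple threshold
predicate. As a rough sketch (hypothetical helper, not the trainer's actual
internals), the check for the configured mode amounts to:

def early_stopping_triggered(report, target_name, target_value, mode):
    # Stop once the monitored metric crosses the target in the given direction.
    value = report[target_name]
    return value < target_value if mode == 'less' else value > target_value

With target_name='test/ce_loss', target_value=0.005, and mode='less', training
halts at the first eval whose loss dips below the threshold, which is why the
second-to-last report is still above 0.005 and the last one is below it.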
Example #25
  def test_run_lanczos(self):
    """Test training for two epochs on MNIST with a small model."""
    rng = jax.random.PRNGKey(0)

    # Set the numpy seed to make the fake data deterministic. mocking.mock_data
    # ultimately calls numpy.random.
    np.random.seed(0)

    model_name = 'fully_connected'
    loss_name = 'cross_entropy'
    metrics_name = 'classification_metrics'
    initializer_name = 'noop'
    dataset_name = 'mnist'
    model_cls = models.get_model(model_name)
    initializer = initializers.get_initializer(initializer_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    hparam_overrides = {
        'lr_hparams': {
            'base_lr': 0.1,
            'schedule': 'cosine'
        },
        'batch_size': 8,
        'train_size': 160,
        'valid_size': 96,
        'test_size': 80,
    }
    input_pipeline_hps = config_dict.ConfigDict(dict(
        num_tf_data_prefetches=-1,
        num_device_prefetches=0,
        num_tf_data_map_parallel_calls=-1,
    ))
    hps = hyperparameters.build_hparams(
        model_name,
        initializer_name,
        dataset_name,
        hparam_file=None,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)
    model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                      loss_name, metrics_name)

    eval_batch_size = 16
    num_examples = 256

    def as_dataset(self, *args, **kwargs):
      del args
      del kwargs

      # pylint: disable=g-long-lambda,g-complex-comprehension
      return tf.data.Dataset.from_generator(
          lambda: ({
              'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
              'label': 9,
          } for i in range(num_examples)),
          output_types=self.info.features.dtype,
          output_shapes=self.info.features.shape,
      )

    # This will override the tfds.load(mnist) call to return num_examples fake
    # samples.
    with tfds.testing.mock_data(
        as_dataset_fn=as_dataset, num_examples=num_examples):
      dataset = dataset_builder(
          shuffle_rng=jax.random.PRNGKey(0),
          batch_size=hps.batch_size,
          eval_batch_size=eval_batch_size,
          hps=hps)

    num_train_steps = 41
    eval_num_batches = 5
    eval_every = 10
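    # All of these steps fall below num_train_steps, so every checkpoint is
    # written and later picked up by run_lanczos.eval_checkpoints.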
    checkpoint_steps = [10, 30, 40]
    metrics_logger, init_logger = None, None
    _ = list(
        trainer.train(
            train_dir=self.test_dir,
            model=model,
            dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
            initializer=initializer,
            num_train_steps=num_train_steps,
            hps=hps,
            rng=rng,
            eval_batch_size=eval_batch_size,
            eval_num_batches=eval_num_batches,
            eval_train_num_batches=eval_num_batches,
            eval_frequency=eval_every,
            checkpoint_steps=checkpoint_steps,
            metrics_logger=metrics_logger,
            init_logger=init_logger))

    checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
    rng = jax.random.PRNGKey(0)

    run_lanczos.eval_checkpoints(
        checkpoint_dir,
        hps,
        rng,
        eval_num_batches,
        model_cls=model_cls,
        dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
        dataset_meta_data=datasets.get_dataset_meta_data(dataset_name),
        hessian_eval_config=hessian_eval.DEFAULT_EVAL_CONFIG,
    )

    # Load the saved file.
    hessian_dir = os.path.join(checkpoint_dir, 'hessian', 'training_metrics')
    pytree_list = checkpoint.load_pytree(hessian_dir)

    # Convert back to a regular list (the checkpointer will have converted the
    # saved list to a dict with keys '0', '1', ...).
    pytree_list = [pytree_list[str(i)] for i in range(len(pytree_list))]
    # Test that the logged steps are correct.
    saved_steps = [row['step'] for row in pytree_list]
    self.assertEqual(saved_steps, checkpoint_steps)
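
The list restoration above assumes the saved dict has contiguous integer keys
'0'..'N-1'. A slightly more defensive variant (a sketch, not part of the
library) sorts the keys numerically instead:

def restore_list(pytree_dict):
    # '10' sorts before '2' as a string, so compare the keys as integers.
    return [pytree_dict[k] for k in sorted(pytree_dict, key=int)]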
Example #26
    def test_dlrm_model_trainer(self):
        """Tests that dlrm model training decreases loss."""
        rng = jax.random.PRNGKey(1337)
        model_str = 'dlrm'
        dataset_str = 'criteo1tb'
        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)
        dataset_hps = datasets.get_dataset_hparams(dataset_str)
        dataset_hps.update({
            'batch_size': model_hps.batch_size,
            'num_dense_features': model_hps.num_dense_features,
            'vocab_sizes': model_hps.vocab_sizes,
        })
        eval_num_batches = 5
        eval_batch_size = dataset_hps.batch_size
        loss_name = 'sigmoid_binary_cross_entropy'
        metrics_name = 'binary_classification_metrics'
        dataset, dataset_meta_data = _get_fake_dlrm_dataset(
            dataset_hps.batch_size, eval_num_batches, dataset_hps)
        hps = copy.copy(model_hps)
        hps.update({
            'train_size': 15,
            'valid_size': 10,
            'test_size': 10,
            'input_shape':
                (model_hps.num_dense_features + len(model_hps.vocab_sizes), ),
            'output_shape': (1, ),
            'l2_decay_factor': 1e-4,
            'l2_decay_rank_threshold': 2,
            'num_device_prefetches': 0,
        })
        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
        initializer = initializers.get_initializer('noop')

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=10,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=2,
                checkpoint_steps=[],
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_loss = df['train/ce_loss'].values
            self.assertLess(train_loss[-1], train_loss[0])
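
For reference, 'sigmoid_binary_cross_entropy' is the standard numerically
stable binary cross-entropy with logits. A minimal per-example sketch in
jax.numpy (the library's own implementation may differ in weighting and
reduction):

import jax.numpy as jnp

def sigmoid_binary_cross_entropy(logits, labels):
    # Stable form of -z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x)).
    return (jnp.maximum(logits, 0) - logits * labels
            + jnp.log1p(jnp.exp(-jnp.abs(logits))))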
Example #27
    def test_graph_model_trainer(self):
        """Tests that graph model training decreases loss."""
        rng = jax.random.PRNGKey(1337)
        model_str = 'gnn'
        model_cls = models.get_model(model_str)
        hps = models.get_model_hparams(model_str)
        hps.update({
            'batch_size': 2,
            'input_edge_shape': (7, ),
            'input_node_shape': (3, ),
            'input_shape': (7, 3),
            'output_shape': (5, ),
            'model_dtype': 'float32',
            'train_size': 15,
            'valid_size': 10,
            'test_size': 10,
            'num_message_passing_steps': 1,
            'normalizer': 'none',
            'dropout_rate': 0.0,
            'lr_hparams': {
                'base_lr': 0.001,
                'schedule': 'constant'
            },
            'num_device_prefetches': 0,
        })
        eval_num_batches = 5
        eval_batch_size = hps.batch_size
        loss_name = 'sigmoid_binary_cross_entropy'
        metrics_name = 'binary_classification_metrics_ogbg_map'
        dataset, dataset_meta_data = _get_fake_graph_dataset(
            batch_size=hps.batch_size,
            eval_num_batches=eval_num_batches,
            hps=hps)
        model = model_cls(hps, dataset_meta_data, loss_name, metrics_name)
        initializer = initializers.get_initializer('noop')

        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        _ = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=10,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                # Note that moving from the deprecated Flax model API to Linen
                # made training less stable, so we need to eval more frequently
                # in order to get a `train_loss[0]` from earlier in training.
                eval_frequency=2,
                checkpoint_steps=[],
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_loss = df['train/ce_loss'].values
            self.assertLess(train_loss[-1], train_loss[0])
Example #28
def _run(train_fn, dataset_name, eval_batch_size, eval_num_batches,
         eval_train_num_batches, eval_frequency, checkpoint_steps,
         num_tf_data_prefetches, num_device_prefetches,
         num_tf_data_map_parallel_calls, early_stopping_target_name,
         early_stopping_target_value, early_stopping_mode, eval_steps,
         hparam_file, hparam_overrides, initializer_name, model_name,
         loss_name, metrics_name, num_train_steps, experiment_dir, worker_id,
         training_metrics_config, callback_configs, external_checkpoint_path):
    """Function that runs a Jax experiment. See flag definitions for args."""
    model_cls = models.get_model(model_name)
    initializer = initializers.get_initializer(initializer_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)
    input_pipeline_hps = config_dict.ConfigDict(
        dict(
            num_tf_data_prefetches=num_tf_data_prefetches,
            num_device_prefetches=num_device_prefetches,
            num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls,
        ))

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=initializer_name,
        dataset_name=dataset_name,
        hparam_file=hparam_file,
        hparam_overrides=hparam_overrides,
        input_pipeline_hps=input_pipeline_hps)

    # Note that one should never tune an RNG seed! The seed is only included
    # in the hparams for the convenience of running hparam trials with
    # multiple seeds per point.
    rng_seed = merged_hps.rng_seed
    if merged_hps.rng_seed < 0:
        rng_seed = _create_synchronized_rng_seed()
    xm_experiment = None
    xm_work_unit = None
    if jax.process_index() == 0:
        logging.info('Running with seed %d', rng_seed)
    rng = jax.random.PRNGKey(rng_seed)

    # Build the loss_fn, metrics_bundle, and flax_module.
    model = model_cls(merged_hps, dataset_meta_data, loss_name, metrics_name)
    trial_dir = os.path.join(experiment_dir, str(worker_id))
    meta_data_path = os.path.join(trial_dir, 'meta_data.json')
    meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
    if jax.process_index() == 0:
        logging.info('rng: %s', rng)
        gfile.makedirs(trial_dir)
        # Set up the metric loggers for host 0.
        metrics_logger, init_logger = utils.set_up_loggers(
            trial_dir, xm_work_unit)
        hparams_fname = os.path.join(trial_dir, 'hparams.json')
        logging.info('saving hparams to %s', hparams_fname)
        with gfile.GFile(hparams_fname, 'w') as f:
            f.write(merged_hps.to_json())
        _write_trial_meta_data(meta_data_path, meta_data)
    else:
        metrics_logger = None
        init_logger = None
    try:
        epoch_reports = list(
            train_fn(trial_dir,
                     model,
                     dataset_builder,
                     initializer,
                     num_train_steps,
                     merged_hps,
                     rng,
                     eval_batch_size,
                     eval_num_batches,
                     eval_train_num_batches,
                     eval_frequency,
                     checkpoint_steps,
                     early_stopping_target_name,
                     early_stopping_target_value,
                     early_stopping_mode,
                     eval_steps,
                     metrics_logger,
                     init_logger,
                     training_metrics_config=training_metrics_config,
                     callback_configs=callback_configs,
                     external_checkpoint_path=external_checkpoint_path))
        logging.info(epoch_reports)
        meta_data['status'] = 'done'
    except utils.TrainingDivergedError as err:
        meta_data['status'] = 'diverged'
        raise err
    finally:
        if jax.process_index() == 0:
            _write_trial_meta_data(meta_data_path, meta_data)
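
_write_trial_meta_data is not shown in this example. A plausible minimal
sketch, assuming it simply serializes the dict to meta_data.json via gfile
(the real helper may differ, e.g. by writing atomically):

import json

from tensorflow.io import gfile

def _write_trial_meta_data(meta_data_path, meta_data):
    # Overwrite the trial's meta_data.json with the current status.
    with gfile.GFile(meta_data_path, 'w') as f:
        f.write(json.dumps(meta_data))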
Example #29
    def test_trainer(self):
        """Test training for two epochs on MNIST with a small model."""
        rng = jax.random.PRNGKey(0)

        # Set the numpy seed to make the fake data deterministic. mocking.mock_data
        # ultimately calls numpy.random.
        np.random.seed(0)

        model_name = 'fully_connected'
        loss_name = 'cross_entropy'
        metrics_name = 'classification_metrics'
        initializer_name = 'noop'
        dataset_name = 'mnist'
        model_cls = models.get_model(model_name)
        initializer = initializers.get_initializer(initializer_name)
        dataset_builder = datasets.get_dataset(dataset_name)
        hparam_overrides = {
            'lr_hparams': {
                'base_lr': 0.1,
                'schedule': 'cosine'
            },
            'batch_size': 8,
            'train_size': 160,
            'valid_size': 96,
            'test_size': 80,
        }
        input_pipeline_hps = config_dict.ConfigDict(
            dict(
                num_tf_data_prefetches=-1,
                num_device_prefetches=0,
                num_tf_data_map_parallel_calls=-1,
            ))
        hps = hyperparameters.build_hparams(
            model_name,
            initializer_name,
            dataset_name,
            hparam_file=None,
            hparam_overrides=hparam_overrides,
            input_pipeline_hps=input_pipeline_hps)

        eval_batch_size = 16
        num_examples = 256

        def as_dataset(self, *args, **kwargs):
            del args
            del kwargs

            # pylint: disable=g-long-lambda,g-complex-comprehension
            return tf.data.Dataset.from_generator(
                lambda: ({
                    'image': np.ones(shape=(28, 28, 1), dtype=np.uint8),
                    'label': 9,
                } for i in range(num_examples)),
                output_types=self.info.features.dtype,
                output_shapes=self.info.features.shape,
            )

        # This will override the tfds.load(mnist) call to return num_examples
        # fake samples.
        with tfds.testing.mock_data(as_dataset_fn=as_dataset,
                                    num_examples=num_examples):
            dataset = dataset_builder(shuffle_rng=jax.random.PRNGKey(0),
                                      batch_size=hps.batch_size,
                                      eval_batch_size=eval_batch_size,
                                      hps=hps)

        model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name),
                          loss_name, metrics_name)

        num_train_steps = 40
        eval_num_batches = 5
        eval_every = 10
        checkpoint_steps = [1, 3, 15]
        metrics_logger, init_logger = utils.set_up_loggers(self.test_dir)
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))

        # Check that the additional checkpoints were saved.
        checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
        saved_steps = []
        for f in tf.io.gfile.listdir(checkpoint_dir):
            if f[:5] == 'ckpt_':
                saved_steps.append(int(f[5:]))

        self.assertEqual(set(saved_steps), set(checkpoint_steps))

        self.assertLen(epoch_reports, num_train_steps / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            self.assertEqual(df['preemption_count'].values[-1], 0)
            self.assertLess(train_err, 0.9)

        self.assertEqual(set(df.columns.values), set(get_column_names()))

        model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name,
                          metrics_name)

        # Test reload from the checkpoint by increasing num_train_steps.
        num_train_steps_reload = 100
        epoch_reports = list(
            trainer.train(
                train_dir=self.test_dir,
                model=model,
                dataset_builder=lambda *unused_args, **unused_kwargs: dataset,
                initializer=initializer,
                num_train_steps=num_train_steps_reload,
                hps=hps,
                rng=rng,
                eval_batch_size=eval_batch_size,
                eval_num_batches=eval_num_batches,
                eval_train_num_batches=eval_num_batches,
                eval_frequency=eval_every,
                checkpoint_steps=checkpoint_steps,
                metrics_logger=metrics_logger,
                init_logger=init_logger))
        self.assertLen(epoch_reports,
                       (num_train_steps_reload - num_train_steps) / eval_every)
        with tf.io.gfile.GFile(os.path.join(self.test_dir,
                                            'measurements.csv')) as f:
            df = pandas.read_csv(f)
            train_err = df['train/error_rate'].values[-1]
            train_loss = df['train/ce_loss'].values[-1]
            self.assertLess(train_err, 0.35)
            self.assertLess(train_loss, 0.1)

            self.assertEqual(df['valid/num_examples'].values[-1],
                             eval_num_batches * eval_batch_size)
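            # Resuming from the saved checkpoint is recorded as one preemption
            # (the first run above asserted a preemption_count of 0).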
            self.assertEqual(df['preemption_count'].values[-1], 1)
            # Check that the correct learning rate was saved in the measurements file.
            final_learning_rate = df['learning_rate'].values[-1]
            final_step = df['global_step'].values[-1]
            self.assertEqual(num_train_steps_reload, final_step)

            # final_step will be one larger than the last step used to calculate
            # the lr decay, hence we plug (final_step - 1) into the decay formula.
            # Note that there is a small numerical difference here between np
            # and jnp.
            decay_factor = (1 + np.cos(
                (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5
            self.assertEqual(float(final_learning_rate),
                             hps.lr_hparams['base_lr'] * decay_factor)

        self.assertEqual(set(df.columns.values), set(get_column_names()))
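
As a sanity check on the assertion above, the full cosine schedule the test
reconstructs is (a sketch matching the test's arithmetic, not necessarily the
library's schedule code):

import numpy as np

def cosine_lr(step, base_lr, total_steps):
    # Cosine decay from base_lr at step 0 down to 0 at total_steps.
    return base_lr * 0.5 * (1 + np.cos(step / total_steps * np.pi))

With base_lr=0.1 and total_steps=num_train_steps_reload=100,
cosine_lr(final_step - 1, 0.1, 100) reproduces the final learning rate logged
in measurements.csv.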