def testEquals(self):
    """Tests that __eq__() respects hidden mutability."""
    fcd = _test_frozenconfigdict()

    # Comparison with unrelated types (and with a plain ConfigDict) is never
    # equal, even when the contents match.
    self.assertNotEqual(fcd, (1, 2))
    self.assertNotEqual(fcd, fcd.as_configdict())

    # Variant whose 'list' field was already a tuple before freezing.
    tuple_variant = _test_dict_deepcopy()
    tuple_variant['list'] = tuple(tuple_variant['list'])
    fcd_tuple_variant = ml_collections.FrozenConfigDict(tuple_variant)

    # Variant whose 'set' field was already a frozenset before freezing.
    frozenset_variant = _test_dict_deepcopy()
    frozenset_variant['set'] = frozenset(frozenset_variant['set'])
    fcd_frozenset_variant = ml_collections.FrozenConfigDict(frozenset_variant)

    self.assertNotEqual(fcd, fcd_tuple_variant)
    # Because set == frozenset in Python:
    self.assertEqual(fcd, fcd_frozenset_variant)

    # Items are not affected by hidden mutability
    self.assertCountEqual(fcd.items(), fcd_tuple_variant.items())
    self.assertCountEqual(fcd.items(), fcd_frozenset_variant.items())
def testEqualsAsConfigDict(self):
    """Tests that eq_as_configdict respects hidden mutability but not type."""
    fcd = _test_frozenconfigdict()

    # eq_as_configdict() is True for an equal ConfigDict but False for
    # unrelated types.
    self.assertFalse(fcd.eq_as_configdict([1, 2]))
    self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict()))
    empty_fcd = ml_collections.FrozenConfigDict()
    self.assertTrue(empty_fcd.eq_as_configdict(ml_collections.ConfigDict()))

    # It must detect hidden mutability exactly like __eq__() does.
    tuple_variant = _test_dict_deepcopy()
    tuple_variant['list'] = tuple(tuple_variant['list'])
    fcd_tuple_variant = ml_collections.FrozenConfigDict(tuple_variant)

    frozenset_variant = _test_dict_deepcopy()
    frozenset_variant['set'] = frozenset(frozenset_variant['set'])
    fcd_frozenset_variant = ml_collections.FrozenConfigDict(frozenset_variant)

    self.assertFalse(fcd.eq_as_configdict(fcd_tuple_variant))
    # Because set == frozenset in Python:
    self.assertTrue(fcd.eq_as_configdict(fcd_frozenset_variant))
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()

    # Data / experiment parameters.
    config.rng_seed = 0
    config.num_trajectories = 1
    config.single_step_predictions = True
    config.num_samples = 1000
    config.split_on = 'times'
    config.train_split_proportion = 80 / 1000
    config.time_delta = 1.
    config.train_time_jump_range = (1, 10)
    config.test_time_jumps = (1, 2, 5, 10, 20, 50)
    config.num_train_steps = 5000

    # Model parameters.
    config.latent_size = 100
    config.activation = 'relu'
    config.model = 'euler-update-network'
    config.encoder_decoder_type = 'mlp'
    config.scaler = 'identity'

    # Optimization parameters.
    config.learning_rate = 1e-3
    config.batch_size = 100
    config.eval_cadence = 50

    # Simulation parameters. FrozenConfigDicts keep these sub-configs
    # immutable (and hashable).
    config.simulation = 'shm'
    config.regularizations = ml_collections.FrozenConfigDict()
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
def testAsConfigDict(self):
    """Tests that converting FrozenConfigDict to ConfigDict works correctly.

    In particular, ensures that FrozenConfigDict does the inverse of ConfigDict
    regarding type_safe, lock, and attribute mutability.
    """
    # Round-tripping an empty FrozenConfigDict yields an empty ConfigDict.
    self.assertEqual(
        ml_collections.ConfigDict(ml_collections.FrozenConfigDict()),
        ml_collections.ConfigDict())

    cd = _test_configdict()
    round_tripped = ml_collections.ConfigDict(
        ml_collections.FrozenConfigDict(cd))
    self.assertEqual(cd, round_tripped)

    # Make sure locking is respected
    cd.lock()
    self.assertEqual(
        cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd)))

    # Make sure type_safe is respected
    cd = ml_collections.ConfigDict(_TEST_DICT, type_safe=False)
    self.assertEqual(
        cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd)))
def testBasicEquality(self):
    """Tests basic equality with different types of initialization."""
    frozen = _test_frozenconfigdict()
    # Freezing a ConfigDict or re-freezing a FrozenConfigDict must produce
    # an equal object.
    frozen_from_cd = ml_collections.FrozenConfigDict(_test_configdict())
    frozen_from_frozen = ml_collections.FrozenConfigDict(frozen)
    self.assertEqual(frozen, frozen_from_cd)
    self.assertEqual(frozen, frozen_from_frozen)
def testInitInvalidAttributeName(self):
    """Ensure initialization fails on attributes with invalid names."""
    # Keys containing '.' cannot become attribute names.
    with self.assertRaises(ValueError):
        ml_collections.FrozenConfigDict({'dot.name': None})
    # Keys shadowing immutable built-in attributes (e.g. __hash__) are
    # rejected.
    with self.assertRaises(AttributeError):
        ml_collections.FrozenConfigDict({'__hash__': None})
def test_loss_fn(self, text_length, n_mentions, n_linked_mentions,
                 no_entity_attention):
    """Test loss function runs and produces expected values."""
    # Freeze a per-test copy of the model config with the parameterized
    # no_entity_attention flag applied.
    model_config = copy.deepcopy(self.model_config)
    model_config['encoder_config']['no_entity_attention'] = no_entity_attention
    model_config = ml_collections.FrozenConfigDict(model_config)
    config = ml_collections.FrozenConfigDict(self.config)
    max_length = model_config.encoder_config.max_length

    # Task-level data pipeline: preprocess -> collate -> postprocess.
    preprocess_fn = eae_task.EaETask.make_preprocess_fn(config)
    collater_fn = eae_task.EaETask.make_collater_fn(config)
    postprocess_fn = eae_task.EaETask.make_output_postprocess_fn(config)

    model = eae_task.EaETask.build_model(model_config)
    dummy_input = eae_task.EaETask.dummy_input(config)
    init_rng = jax.random.PRNGKey(0)
    init_parameters = model.init(init_rng, dummy_input, True)

    # Build a batch by tiling one synthetic example across the batch axis.
    raw_example = test_utils.gen_mention_pretraining_sample(
        text_length, n_mentions, n_linked_mentions, max_length=max_length)
    processed_example = preprocess_fn(raw_example)
    batch = {
        key: np.tile(value, (config.per_device_batch_size, 1))
        for key, value in processed_example.items()
    }
    batch = collater_fn(batch)
    batch = jax.tree_map(np.asarray, batch)

    loss_fn = eae_task.EaETask.make_loss_fn(config)
    _, metrics, auxiliary_output = loss_fn(
        model_config=model_config,
        model_params=init_parameters['params'],
        model_vars={},
        batch=batch,
        deterministic=True,
    )

    # Metric denominators must match the corresponding weight sums, and the
    # entity-linking losses must be finite whenever any target is present.
    self.assertEqual(metrics['mlm']['denominator'],
                     batch['mlm_target_weights'].sum())
    self.assertEqual(metrics['el_intermediate']['denominator'],
                     batch['mention_target_weights'].sum())
    if batch['mention_target_weights'].sum() > 0:
        self.assertFalse(np.isnan(metrics['el_intermediate']['loss']))
    self.assertEqual(metrics['el_final']['denominator'],
                     batch['mention_target_weights'].sum())
    if batch['mention_target_weights'].sum() > 0:
        self.assertFalse(np.isnan(metrics['el_final']['loss']))

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
        self.assertArrayEqual(np.array(features[key]), batch[key])
def testHash(self):
    """Ensures __hash__() respects hidden mutability."""
    tuple_variant = _test_dict_deepcopy()
    tuple_variant['list'] = tuple(tuple_variant['list'])

    # Equal contents hash equal; a hidden list-vs-tuple difference must
    # change the hash.
    self.assertEqual(
        hash(_test_frozenconfigdict()),
        hash(ml_collections.FrozenConfigDict(_test_dict_deepcopy())))
    self.assertNotEqual(
        hash(_test_frozenconfigdict()),
        hash(ml_collections.FrozenConfigDict(tuple_variant)))

    # Ensure Python realizes FrozenConfigDict is hashable
    self.assertIsInstance(_test_frozenconfigdict(), collections_abc.Hashable)
def test_parameterized_mixing(self, model_arch, mixing_layer_name):
    """Checks param keys and output shapes for parameterized mixing layers."""
    config = dummy_config(model_arch=model_arch)
    with config.unlocked():
        config.use_fft = True
    frozen_config = ml_collections.FrozenConfigDict(config)

    encoder = models.EncoderModel(config=frozen_config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config=frozen_config)
    params = init_model_params(rng, encoder, init_batch)

    # Parameterized mixing layers contribute their own parameter sub-trees.
    expected_keys = {
        "embedder", "encoder_0", "encoder_1", "feed_forward_0",
        "feed_forward_1", f"{mixing_layer_name}_0", f"{mixing_layer_name}_1",
        "pooler"
    }
    self.assertEqual(params.keys(), expected_keys)

    inputs = dummy_inputs(rng, config=frozen_config)
    encoder_output = encoder.apply({"params": params},
                                   rngs={"dropout": rng},
                                   **inputs)

    self.assertEqual(
        encoder_output.sequence_output.shape,
        (config.train_batch_size, config.max_seq_length, config.d_model))
    self.assertEqual(encoder_output.pooled_output.shape,
                     (config.train_batch_size, config.d_model))
def render_and_save():
    """Renders the image according to the configuration and saves it to disk."""
    rendering_config = ml_collections.FrozenConfigDict(
        configuration.get_config())

    # Derive image dimensions from the configured aspect ratio and height.
    aspect_ratio = rendering_config.aspect_ratio
    height = rendering_config.height
    width = int(aspect_ratio * height)

    scene_camera = build_camera(rendering_config, aspect_ratio)
    world = build_world(rendering_config)

    # Render.
    logging.info("Tracing rays...")
    render_image_fn = jax.jit(
        render.generate_image,
        static_argnames=["height", "width", "config"])
    image = render_image_fn(height, width, scene_camera, world,
                            rendering_config)
    image = render.correct_gamma(image, gamma=rendering_config.gamma_correction)

    logging.info("Saving to file...")
    output.export_as_ppm(image, rendering_config.output_file)
    return image
def test_unparametrized_mixing_encoder(self, model_arch):
    """Checks that unparameterized mixing layers contribute no parameters."""
    config = dummy_config(model_arch=model_arch)
    frozen_config = ml_collections.FrozenConfigDict(config)

    encoder = models.EncoderModel(config=frozen_config)
    rng = jax.random.PRNGKey(0)
    params = init_model_params(rng, encoder, init_encoder_batch(config))

    # Unparameterized mixing encoders do not have any parameters in their
    # mixing layers, so their mixing layer names do not show up in params.
    self.assertEqual(
        params.keys(), {
            "embedder", "encoder_0", "encoder_1", "feed_forward_0",
            "feed_forward_1", "pooler"
        })

    inputs = dummy_inputs(rng, config)
    hidden_states, pooled_output = encoder.apply({"params": params},
                                                 rngs={"dropout": rng},
                                                 **inputs)

    self.assertEqual(
        hidden_states.shape,
        (config.train_batch_size, config.max_seq_length, config.d_model))
    self.assertEqual(pooled_output.shape,
                     (config.train_batch_size, config.d_model))
def test_hybrid_encoder(self, attention_layout, num_attention_layers,
                        expected_attention_layers):
    """Checks attention layer placement in hybrid F-Net encoders."""
    config = dummy_config(model_arch=ModelArchitecture.F_NET)
    with config.unlocked():
        config.num_layers = 4
        config.attention_layout = attention_layout
        config.num_attention_layers = num_attention_layers
    frozen_config = ml_collections.FrozenConfigDict(config)

    encoder = models.EncoderModel(config=frozen_config)
    rng = jax.random.PRNGKey(0)
    params = init_model_params(rng, encoder, init_encoder_batch(config))

    expected_keys = {
        "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
        "feed_forward_0", "feed_forward_1", "feed_forward_2",
        "feed_forward_3", "pooler"
    }
    # Self-attention parameters must appear only in the configured layers.
    expected_keys.update(
        f"self_attention_{layer}" for layer in expected_attention_layers)
    self.assertEqual(params.keys(), expected_keys)

    inputs = dummy_inputs(rng, config)
    hidden_states, pooled_output = encoder.apply({"params": params},
                                                 rngs={"dropout": rng},
                                                 **inputs)

    self.assertEqual(
        hidden_states.shape,
        (config.train_batch_size, config.max_seq_length, config.d_model))
    self.assertEqual(pooled_output.shape,
                     (config.train_batch_size, config.d_model))
def test_classification_model(self):
    """Checks classification model outputs in eval and train modes."""
    n_classes = 2
    config = dummy_config(model_arch=ModelArchitecture.BERT)
    with config.unlocked():
        config.dataset_name = "dummy/classification_dataset"
    frozen_config = ml_collections.FrozenConfigDict(config)

    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=n_classes)
    rng = jax.random.PRNGKey(0)
    params = init_model_params(rng, model, init_encoder_batch(config))
    self.assertEqual(params.keys(), {"encoder", "classification"})

    # Logits for eval/prediction (no labels supplied).
    eval_inputs = dummy_inputs(rng, config)
    eval_inputs["deterministic"] = True
    logits = model.apply({"params": params}, **eval_inputs)
    self.assertEqual(jnp.shape(logits), (config.train_batch_size, n_classes))

    # Metrics for training (labels supplied).
    train_inputs = dummy_inputs(rng, config)
    train_inputs["labels"] = jnp.ones(config.train_batch_size, jnp.int32)
    metrics = model.apply({"params": params},
                          rngs={"dropout": rng},
                          **train_inputs)
    self.assertEqual(metrics.keys(),
                     {"loss", "correct_predictions", "num_labels"})
def test_regression_model(self):
    """Checks regression model outputs in eval and train modes."""
    n_classes = 1  # Only one label for regression
    config = dummy_config(model_arch=ModelArchitecture.F_NET)
    with config.unlocked():
        config.dataset_name = "glue/stsb"  # regression task dataset
    frozen_config = ml_collections.FrozenConfigDict(config)

    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=n_classes)
    rng = jax.random.PRNGKey(0)
    params = init_model_params(rng, model, init_encoder_batch(config))
    self.assertEqual(params.keys(), {"encoder", "classification"})

    # Logits for eval/prediction (no labels supplied).
    eval_inputs = dummy_inputs(rng, config)
    eval_inputs["deterministic"] = True
    logits = model.apply({"params": params}, **eval_inputs)
    self.assertEqual(jnp.shape(logits), (config.train_batch_size, n_classes))

    # Metrics for training (labels supplied).
    train_inputs = dummy_inputs(rng, config)
    _, label_key = jax.random.split(rng)
    train_inputs["labels"] = jax.random.uniform(
        label_key, (config.train_batch_size,), minval=0., maxval=1.)
    metrics = model.apply({"params": params},
                          rngs={"dropout": rng},
                          **train_inputs)
    self.assertEqual(metrics.keys(), {"loss", "num_labels"})
def test_regression_model(self): n_classes = 1 # Only one label for regression config = dummy_config(model_arch=ModelArchitecture.F_NET) with config.unlocked(): config.dataset_name = "glue/stsb" # regression task dataset config.num_moe_layers = 1 # Add moe layer to verify expert metrics num_tokens = config.train_batch_size * config.max_seq_length config.max_group_size = num_tokens frozen_config = ml_collections.FrozenConfigDict(config) model = models.SequenceClassificationModel(config=frozen_config, n_classes=n_classes) rng = jax.random.PRNGKey(0) init_batch = init_encoder_batch(config=frozen_config) params = init_model_params(rng, model, init_batch) self.assertEqual(params.keys(), {"encoder", "classification"}) # Logits for eval/prediction (no labels supplied). eval_inputs = dummy_inputs(rng, config=frozen_config) eval_inputs["deterministic"] = True logits = model.apply({"params": params}, **eval_inputs) expected_logits_shape = (config.train_batch_size, n_classes) self.assertEqual(jnp.shape(logits), expected_logits_shape) # Metrics for training (labels supplied). train_inputs = dummy_inputs(rng, config=frozen_config) label_key, dropout_key, jitter_key = jax.random.split(rng, num=3) train_inputs["labels"] = jax.random.uniform( label_key, (config.train_batch_size, ), minval=0., maxval=1.) metrics, state = model.apply({"params": params}, rngs={ "dropout": dropout_key, "jitter": jitter_key }, mutable=["intermediates"], **train_inputs) self.assertAlmostEqual(metrics.batch_loss, 1.0100806, places=6) self.assertEqual(metrics.num_labels, 3) self.assertIn("intermediates", state) jax.tree_util.tree_map( functools.partial(np.testing.assert_allclose, rtol=1e-6), state["intermediates"], FrozenDict({ "encoder": { "moe_1": { "diversity_metrics": models.DiversityMetrics( auxiliary_loss=1.0073242, router_z_loss=0.45751953, fraction_tokens_left_behind=0.25, expert_usage=0.75, router_confidence=0.51171875) } } }))
def testToDict(self):
    """Ensure to_dict() does not care about hidden mutability."""
    tuple_variant = _test_dict_deepcopy()
    tuple_variant['list'] = tuple(tuple_variant['list'])
    # A hidden list-vs-tuple difference must disappear after to_dict().
    self.assertEqual(
        _test_frozenconfigdict().to_dict(),
        ml_collections.FrozenConfigDict(tuple_variant).to_dict())
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()

    # Data / experiment parameters.
    config.rng_seed = 0
    config.num_trajectories = 1
    config.single_step_predictions = True
    config.num_samples = 1000
    config.split_on = 'times'
    config.train_split_proportion = 80 / 1000
    config.time_delta = 1.
    config.train_time_jump_range = (1, 10)
    config.test_time_jumps = (1, 2, 5, 10, 20, 50)
    config.num_train_steps = 5000

    # Model parameters.
    config.latent_size = 100
    config.activation = 'sigmoid'
    config.model = 'action-angle-network'
    config.encoder_decoder_type = 'flow'
    config.flow_type = 'shear'
    config.num_flow_layers = 10
    config.num_coordinates = 2
    # Spline settings only apply to masked-coupling flows.
    if config.flow_type == 'masked_coupling':
        config.flow_spline_range_min = -3
        config.flow_spline_range_max = 3
        config.flow_spline_bins = 100
    config.polar_action_angles = True
    config.scaler = 'identity'

    # Optimization parameters.
    config.learning_rate = 1e-3
    config.batch_size = 100
    config.eval_cadence = 50

    # Simulation parameters. FrozenConfigDicts keep these sub-configs
    # immutable (and hashable).
    config.simulation = 'shm'
    config.regularizations = ml_collections.FrozenConfigDict({
        'actions': 1.,
        'angular_velocities': 0.,
        'encoded_decoded_differences': 0.,
    })
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
def testFieldReferenceResolved(self):
    """Tests that FieldReferences are resolved."""
    cfg = ml_collections.ConfigDict({'fr': ml_collections.FieldReference(1)})
    frozen_cfg = ml_collections.FrozenConfigDict(cfg)
    # Freezing must replace the FieldReference with its resolved value.
    self.assertNotIsInstance(frozen_cfg._fields['fr'],
                             ml_collections.FieldReference)
    hash(frozen_cfg)  # with FieldReference resolved, frozen_cfg is hashable
def testFieldReferenceCycle(self):
    """Tests that FieldReferences may not contain reference cycles."""
    # References to acyclic containers (mutable or not) are allowed.
    frozenset_fr = {'frozenset': frozenset({1, 2})}
    frozenset_fr['fr'] = ml_collections.FieldReference(
        frozenset_fr['frozenset'])
    list_fr = {'list': [1, 2]}
    list_fr['fr'] = ml_collections.FieldReference(list_fr['list'])

    # These two dicts each contain a cycle through a FieldReference.
    cyclic_fr = {'a': 1}
    cyclic_fr['fr'] = ml_collections.FieldReference(cyclic_fr)
    cyclic_fr_parent = {'dict': {}}
    cyclic_fr_parent['dict']['fr'] = ml_collections.FieldReference(
        cyclic_fr_parent)

    # FieldReference is allowed to point to non-cyclic objects:
    _ = ml_collections.FrozenConfigDict(frozenset_fr)
    _ = ml_collections.FrozenConfigDict(list_fr)
    # But not cycles:
    self.assertFrozenRaisesValueError([cyclic_fr, cyclic_fr_parent])
def testPickle(self):
    """Make sure FrozenConfigDict can be dumped and loaded with pickle."""
    # Both unlocked and locked FrozenConfigDicts must round-trip.
    for fcd in (_test_frozenconfigdict(),
                ml_collections.FrozenConfigDict(_test_configdict().lock())):
        self.assertEqual(fcd, pickle.loads(pickle.dumps(fcd)))
def testInitConfigDict(self):
    """Tests that ConfigDict initialization handles FrozenConfigDict.

    Initializing a ConfigDict on a dictionary with FrozenConfigDict values
    should unfreeze these values.
    """
    plain_dict = _test_dict_deepcopy()
    plain_dict.pop('ref')
    # Same dictionary, but with its 'dict' value pre-frozen.
    frozen_node_dict = copy.deepcopy(plain_dict)
    frozen_node_dict['dict'] = ml_collections.FrozenConfigDict(
        frozen_node_dict['dict'])

    # Conversion results must not depend on whether a node was pre-frozen.
    self.assertEqual(
        ml_collections.ConfigDict(plain_dict),
        ml_collections.ConfigDict(frozen_node_dict))
    self.assertEqual(
        ml_collections.FrozenConfigDict(plain_dict),
        ml_collections.FrozenConfigDict(frozen_node_dict))
def dummy_frozen_config():
    """Creates a dummy model config that can be used by all tests."""
    config = default_config.get_config()
    # Tiny model dimensions keep test runtime low.
    config.model_arch = default_config.ModelArchitecture.FF_ONLY
    config.d_emb = 4
    config.d_model = 4
    config.d_ff = 4
    config.max_seq_length = 8
    config.num_layers = 1
    config.vocab_size = 1000
    config.train_batch_size = 2
    # Freeze so tests cannot accidentally mutate the shared config.
    return ml_collections.FrozenConfigDict(config)
def main(_):
    """Demonstrates FrozenConfigDict conversion, immutability and equality."""
    print_section('Attribute Types.')
    cfg = ml_collections.ConfigDict()
    cfg.int = 1
    cfg.list = [1, 2, 3]
    cfg.tuple = (1, 2, 3)
    cfg.set = {1, 2, 3}
    cfg.frozenset = frozenset({1, 2, 3})
    cfg.dict = {
        'nested_int': 4,
        'nested_list': [4, 5, 6],
        'nested_tuple': ([4], 5, 6),
    }
    print('Types of cfg fields:')
    print('list: ', type(cfg.list))  # List
    print('set: ', type(cfg.set))  # Set
    print('nested_list: ', type(cfg.dict.nested_list))  # List
    print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0]))  # List

    # Freezing converts mutable containers to immutable equivalents.
    frozen_cfg = ml_collections.FrozenConfigDict(cfg)
    print('\nTypes of FrozenConfigDict(cfg) fields:')
    print('list: ', type(frozen_cfg.list))  # Tuple
    print('set: ', type(frozen_cfg.set))  # Frozenset
    print('nested_list: ', type(frozen_cfg.dict.nested_list))  # Tuple
    print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0]))  # Tuple

    # Converting back to ConfigDict restores the mutable container types.
    cfg_from_frozen = ml_collections.ConfigDict(frozen_cfg)
    print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:')
    print('list: ', type(cfg_from_frozen.list))  # List
    print('set: ', type(cfg_from_frozen.set))  # Set
    print('nested_list: ', type(cfg_from_frozen.dict.nested_list))  # List
    print('nested_tuple[0]: ', type(cfg_from_frozen.dict.nested_tuple[0]))  # List
    print('\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
    print(cfg_from_frozen == frozen_cfg.as_configdict())  # True

    print_section('Immutability.')
    try:
        frozen_cfg.new_field = 1  # Raises AttributeError because of immutability.
    except AttributeError as e:
        print(e)

    print_section('"==" and eq_as_configdict().')
    # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict
    print(frozen_cfg == cfg)  # False
    # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to
    # ConfigDict
    print(frozen_cfg.eq_as_configdict(cfg))  # True
    # .eq_as_congfigdict() is also a method of ConfigDict
    print(cfg.eq_as_configdict(frozen_cfg))  # True
def test_moe(self, moe_layout, num_moe_layers, expected_ff_layer_keys,
             expected_moe_layer_keys, expected_moe_keys):
    """Checks MoE layer placement, output shapes and diversity metrics."""
    config = dummy_config(model_arch=ModelArchitecture.T_NET)
    with config.unlocked():
        config.num_layers = 4  # More layers so we can test different layouts
        config.moe_layout = moe_layout
        config.num_moe_layers = num_moe_layers
    frozen_config = ml_collections.FrozenConfigDict(config)

    encoder = models.EncoderModel(config=frozen_config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config=frozen_config)
    params = init_model_params(rng, encoder, init_batch)

    # Feed-forward vs MoE sublayer keys depend on the parameterized layout.
    expected_keys = {
        "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
        "pooler", "toeplitz_transform_0", "toeplitz_transform_1",
        "toeplitz_transform_2", "toeplitz_transform_3"
    }
    expected_keys.update(expected_ff_layer_keys)
    expected_keys.update(expected_moe_keys)
    self.assertEqual(params.keys(), expected_keys)

    # MoE layers need separate "dropout" and "jitter" RNG streams.
    rng, dropout_key, jitter_key = jax.random.split(rng, num=3)
    inputs = dummy_inputs(rng, config=frozen_config)
    encoder_output, state = encoder.apply({"params": params},
                                          rngs={
                                              "dropout": dropout_key,
                                              "jitter": jitter_key
                                          },
                                          mutable=["intermediates"],
                                          **inputs)

    expected_sequence_output_shape = (config.train_batch_size,
                                      config.max_seq_length, config.d_model)
    self.assertEqual(encoder_output.sequence_output.shape,
                     expected_sequence_output_shape)
    expected_pooled_output_shape = (config.train_batch_size, config.d_model)
    self.assertEqual(encoder_output.pooled_output.shape,
                     expected_pooled_output_shape)

    if num_moe_layers > 0:
        # Each MoE layer sows diversity metrics into "intermediates".
        self.assertIn("intermediates", state)
        for layer in expected_moe_layer_keys:
            self.assertIn(layer, state["intermediates"])
            self.assertIn("diversity_metrics", state["intermediates"][layer])
    else:
        self.assertNotIn("intermediates", state)
def test_embedding_layer(self):
    """Checks embedding layer parameter keys and output shape."""
    config = ml_collections.ConfigDict({
        "batch_size": 3,
        "vocab_size": 1000,
        "d_emb": 32,
        "max_seq_length": 64,
        "type_vocab_size": 2,
        "d_model": 4,
        "dropout_rate": 0.1,
        "dtype": jnp.float32
    })
    frozen_config = ml_collections.FrozenConfigDict(config)
    rng = jax.random.PRNGKey(100)

    embedding_layer = layers.EmbeddingLayer(config=frozen_config)
    init_batch = {
        "input_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32),
        "type_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32)
    }
    params = init_layer_variables(rng, embedding_layer, init_batch)["params"]

    self.assertEqual(
        params.keys(),
        {"word", "position", "type", "layer_norm", "hidden_mapping_in"})

    rng, init_rng = jax.random.split(rng)
    batch_shape = (frozen_config.batch_size, frozen_config.max_seq_length)
    inputs = {
        "input_ids":
            jax.random.randint(init_rng, batch_shape, minval=0, maxval=13),
        "type_ids":
            jax.random.randint(init_rng, batch_shape, minval=0, maxval=2)
    }
    outputs = embedding_layer.apply({"params": params},
                                    rngs={"dropout": rng},
                                    **inputs)
    # Output is the input batch mapped into the d_model hidden space.
    self.assertEqual(outputs.shape, batch_shape + (frozen_config.d_model,))
def test_f_net_encoder_bad_long_seq(self):
    """F-Net init must reject long, non-power-of-2 sequence lengths."""
    config = dummy_config(model_arch=ModelArchitecture.F_NET)
    with config.unlocked():
        config.max_seq_length = 8194
    frozen_config = ml_collections.FrozenConfigDict(config)

    encoder = models.EncoderModel(config=frozen_config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config)
    with self.assertRaisesRegex(
        ValueError,
        "must be a power of 2 to take advantage of FFT optimizations"):
        _ = init_model_params(rng, encoder, init_batch)
def test_model_shape(
    self,
    text_length,
    n_mentions,
    n_linked_mentions,
    no_entity_attention,
):
    """Test model forward runs and produces expected shape."""
    config = copy.deepcopy(self.config)
    config['model_config']['encoder_config'][
        'no_entity_attention'] = no_entity_attention
    # BUG FIX: freeze the modified copy (`config`), not `self.config`.
    # Previously this line froze `self.config`, silently discarding the
    # `no_entity_attention` override set above, so every parameterization
    # of this test ran with the same (default) setting.
    config = ml_collections.FrozenConfigDict(config)
    model_config = config.model_config
    encoder_config = model_config.encoder_config
    max_length = model_config.encoder_config.max_length

    preprocess_fn = eae_task.EaETask.make_preprocess_fn(config)
    collater_fn = eae_task.EaETask.make_collater_fn(config)

    # Build a batch by tiling one synthetic example across the batch axis.
    raw_example = test_utils.gen_mention_pretraining_sample(
        text_length, n_mentions, n_linked_mentions, max_length=max_length)
    processed_example = preprocess_fn(raw_example)
    batch = {
        key: np.tile(value, (config.per_device_batch_size, 1))
        for key, value in processed_example.items()
    }
    batch = collater_fn(batch)
    batch = jax.tree_map(np.asarray, batch)

    model = eae_encoder.EaEEncoder(**model_config.encoder_config)
    init_rng = jax.random.PRNGKey(0)
    (encoded_output, loss_helpers, _), _ = model.init_with_output(
        init_rng,
        batch,
        deterministic=True,
        method=model.forward,
    )

    # Encoded output: one hidden vector per token per example.
    self.assertEqual(
        encoded_output.shape,
        (config.per_device_batch_size, encoder_config.max_length,
         encoder_config.hidden_size))
    # Target mention encodings: one entity-dim vector per mention target.
    self.assertEqual(
        loss_helpers['target_mention_encodings'].shape,
        (config.max_mention_targets * config.per_device_batch_size,
         encoder_config.entity_dim))
def test_dataset_standard_batching(self):
    """Checks that the dataset pipeline emits batches of the configured size."""
    data_dir = tempfile.mkdtemp()
    config = config_lib.get_config()
    with config.unlocked():
        config.dataset.name = 'control_flow_programs/decimal-L10'
        config.dataset.in_memory = True
        config.dataset.batch_size = 5
        config.dataset.representation = 'trace'
    config = ml_collections.FrozenConfigDict(config)

    dataset_info = dataset_utils.get_dataset(data_dir, config)
    first_item = next(iter(dataset_info.dataset))
    # Leading dimension must equal dataset.batch_size.
    self.assertEqual(first_item['cfg']['data'].shape[0], 5)
def test_output_expected(self):
    """Checks that the example task loss is exactly zero on all-zero input."""
    model_config_dict = {
        'dtype': 'float32',
        'features': [32, 32],
        'use_bias': False,
    }
    config = ml_collections.ConfigDict({
        'model_config': model_config_dict,
        'per_device_batch_size': 32,
        'seed': 0,
    })
    model_config = ml_collections.FrozenConfigDict(model_config_dict)

    # All-zero input; with use_bias=False the model output is zero too.
    batch = {
        'x':
            jnp.zeros(
                (config.per_device_batch_size,
                 config.model_config.features[0]),
                dtype=config.model_config.dtype,
            )
    }
    init_rng = jax.random.PRNGKey(config.seed)

    # Create dummy input
    dummy_input = example_task.ExampleTask.dummy_input(config)
    model = example_task.ExampleTask.build_model(model_config)
    initial_variables = jax.jit(model.init)(init_rng, dummy_input, False)

    loss_fn = example_task.ExampleTask.make_loss_fn(config)
    loss, _, _ = loss_fn(
        model_config=model_config,
        model_params=initial_variables['params'],
        model_vars={},
        batch=batch,
        deterministic=False,
    )
    self.assertEqual(loss, 0.0)
def frozen_config(sharded_params=False):
    """Creates a dummy model config that can be used by all tests."""
    config = default_config.get_config()
    # Tiny model dimensions keep test runtime low.
    config.model_arch = default_config.ModelArchitecture.LINEAR.name
    config.num_attention_layers = 0
    config.d_emb = 4
    config.d_model = 4
    config.d_ff = 4
    config.max_seq_length = 8
    config.num_layers = 1
    config.vocab_size = 16
    config.train_batch_size = 4
    config.dtype = jnp.float32
    config.pad_id = 3
    # MoE layers contain sharded parameters.
    config.num_moe_layers = 1 if sharded_params else 0
    config.num_experts = 1 if sharded_params else 0
    config.auxiliary_loss_factor = 0.01
    config.router_z_loss_factor = 0.01
    # Freeze so tests cannot accidentally mutate the shared config.
    return ml_collections.FrozenConfigDict(config)