def testEquals(self):
    """Verifies that __eq__() distinguishes hidden-mutability differences."""
    fcd = _test_frozenconfigdict()

    # Comparing against unrelated types (or a plain ConfigDict) is never equal.
    self.assertNotEqual(fcd, (1, 2))
    self.assertNotEqual(fcd, fcd.as_configdict())

    # A dict whose 'list' entry was converted to a tuple compares unequal
    # once frozen: FrozenConfigDict remembers the original container type.
    tuple_variant = _test_dict_deepcopy()
    tuple_variant['list'] = tuple(tuple_variant['list'])
    fcd_tuple = ml_collections.FrozenConfigDict(tuple_variant)

    # set vs. frozenset, by contrast, compare equal in Python.
    frozenset_variant = _test_dict_deepcopy()
    frozenset_variant['set'] = frozenset(frozenset_variant['set'])
    fcd_frozenset = ml_collections.FrozenConfigDict(frozenset_variant)

    self.assertNotEqual(fcd, fcd_tuple)
    self.assertEqual(fcd, fcd_frozenset)

    # items() ignores hidden mutability entirely.
    self.assertCountEqual(fcd.items(), fcd_tuple.items())
    self.assertCountEqual(fcd.items(), fcd_frozenset.items())
    def testEqualsAsConfigDict(self):
        """Checks eq_as_configdict(): mutability-aware but type-insensitive."""
        fcd = _test_frozenconfigdict()

        # Unlike __eq__, eq_as_configdict() treats an equivalent ConfigDict
        # as equal; non-config objects are still unequal.
        self.assertFalse(fcd.eq_as_configdict([1, 2]))
        self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict()))
        self.assertTrue(ml_collections.FrozenConfigDict().eq_as_configdict(
            ml_collections.ConfigDict()))

        # Hidden-mutability detection mirrors __eq__().
        tuple_variant = _test_dict_deepcopy()
        tuple_variant['list'] = tuple(tuple_variant['list'])
        fcd_tuple = ml_collections.FrozenConfigDict(tuple_variant)

        frozenset_variant = _test_dict_deepcopy()
        frozenset_variant['set'] = frozenset(frozenset_variant['set'])
        fcd_frozenset = ml_collections.FrozenConfigDict(frozenset_variant)

        # list vs. tuple differ; set vs. frozenset compare equal in Python.
        self.assertFalse(fcd.eq_as_configdict(fcd_tuple))
        self.assertTrue(fcd.eq_as_configdict(fcd_frozenset))
Example #3
0
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()

    # Scalar hyperparameters, applied in a fixed order.
    hyperparameters = {
        'rng_seed': 0,
        'num_trajectories': 1,
        'single_step_predictions': True,
        'num_samples': 1000,
        'split_on': 'times',
        'train_split_proportion': 80 / 1000,
        'time_delta': 1.,
        'train_time_jump_range': (1, 10),
        'test_time_jumps': (1, 2, 5, 10, 20, 50),
        'num_train_steps': 5000,
        'latent_size': 100,
        'activation': 'relu',
        'model': 'euler-update-network',
        'encoder_decoder_type': 'mlp',
        'scaler': 'identity',
        'learning_rate': 1e-3,
        'batch_size': 100,
        'eval_cadence': 50,
        'simulation': 'shm',
    }
    for name, value in hyperparameters.items():
        setattr(config, name, value)

    # Frozen sub-configs: no regularizations, plus sampling ranges for the
    # simulation parameters.
    config.regularizations = ml_collections.FrozenConfigDict()
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
    def testAsConfigDict(self):
        """Tests round-tripping FrozenConfigDict back to ConfigDict.

        In particular, ensures the conversion restores type_safe, lock, and
        attribute mutability the way ConfigDict expects.
        """
        # Conversion must work even for an empty FrozenConfigDict.
        self.assertEqual(
            ml_collections.ConfigDict(ml_collections.FrozenConfigDict()),
            ml_collections.ConfigDict())

        # Round trip: ConfigDict -> FrozenConfigDict -> ConfigDict.
        original = _test_configdict()
        round_tripped = ml_collections.ConfigDict(
            ml_collections.FrozenConfigDict(original))
        self.assertEqual(original, round_tripped)

        # Locking must survive the round trip.
        original.lock()
        self.assertEqual(
            original,
            ml_collections.ConfigDict(
                ml_collections.FrozenConfigDict(original)))

        # So must type_safe=False.
        original = ml_collections.ConfigDict(_TEST_DICT, type_safe=False)
        self.assertEqual(
            original,
            ml_collections.ConfigDict(
                ml_collections.FrozenConfigDict(original)))
 def testBasicEquality(self):
     """Tests basic equality with different types of initialization."""
     # The same underlying data should compare equal regardless of whether
     # it was frozen from a helper dict, a ConfigDict, or an existing
     # FrozenConfigDict.
     fcd = _test_frozenconfigdict()
     fcd_cd = ml_collections.FrozenConfigDict(_test_configdict())
     fcd_fcd = ml_collections.FrozenConfigDict(fcd)
     self.assertEqual(fcd, fcd_cd)
     self.assertEqual(fcd, fcd_fcd)
    def testInitInvalidAttributeName(self):
        """Ensure initialization fails on attributes with invalid names."""
        # Keys containing '.' cannot become attribute names.
        with self.assertRaises(ValueError):
            ml_collections.FrozenConfigDict({'dot.name': None})

        # Keys shadowing immutable attributes (e.g. __hash__) are rejected.
        with self.assertRaises(AttributeError):
            ml_collections.FrozenConfigDict({'__hash__': None})
Example #7
0
  def test_loss_fn(self, text_length, n_mentions, n_linked_mentions,
                   no_entity_attention):
    """Test loss function runs and produces expected values."""

    # Deep-copy the shared config so the per-case override below does not
    # leak into other parameterized test cases.
    model_config = copy.deepcopy(self.model_config)
    model_config['encoder_config']['no_entity_attention'] = no_entity_attention
    model_config = ml_collections.FrozenConfigDict(model_config)
    config = ml_collections.FrozenConfigDict(self.config)

    max_length = model_config.encoder_config.max_length
    # Task-level processing functions derived from the config.
    preprocess_fn = eae_task.EaETask.make_preprocess_fn(config)
    collater_fn = eae_task.EaETask.make_collater_fn(config)
    postprocess_fn = eae_task.EaETask.make_output_postprocess_fn(config)

    model = eae_task.EaETask.build_model(model_config)
    dummy_input = eae_task.EaETask.dummy_input(config)
    init_rng = jax.random.PRNGKey(0)
    init_parameters = model.init(init_rng, dummy_input, True)

    # Build a batch by tiling one synthetic example across the batch axis.
    raw_example = test_utils.gen_mention_pretraining_sample(
        text_length, n_mentions, n_linked_mentions, max_length=max_length)
    processed_example = preprocess_fn(raw_example)
    batch = {
        key: np.tile(value, (config.per_device_batch_size, 1))
        for key, value in processed_example.items()
    }
    batch = collater_fn(batch)
    batch = jax.tree_map(np.asarray, batch)

    loss_fn = eae_task.EaETask.make_loss_fn(config)
    _, metrics, auxiliary_output = loss_fn(
        model_config=model_config,
        model_params=init_parameters['params'],
        model_vars={},
        batch=batch,
        deterministic=True,
    )

    # Metric denominators should count exactly the weighted targets present
    # in the batch.
    self.assertEqual(metrics['mlm']['denominator'],
                     batch['mlm_target_weights'].sum())
    self.assertEqual(metrics['el_intermediate']['denominator'],
                     batch['mention_target_weights'].sum())
    # Losses are only well-defined when at least one mention target exists.
    if batch['mention_target_weights'].sum() > 0:
      self.assertFalse(np.isnan(metrics['el_intermediate']['loss']))
    self.assertEqual(metrics['el_final']['denominator'],
                     batch['mention_target_weights'].sum())
    if batch['mention_target_weights'].sum() > 0:
      self.assertFalse(np.isnan(metrics['el_final']['loss']))

    features = postprocess_fn(batch, auxiliary_output)
    # Check features are JSON-serializable
    json.dumps(features)
    # Check features match the original batch
    for key in batch.keys():
      self.assertArrayEqual(np.array(features[key]), batch[key])
    def testHash(self):
        """Ensures __hash__() respects hidden mutability."""
        # Identical source dicts must hash identically...
        self.assertEqual(
            hash(_test_frozenconfigdict()),
            hash(ml_collections.FrozenConfigDict(_test_dict_deepcopy())))

        # ...but converting the 'list' entry to a tuple changes the hash,
        # since the original container type is part of the hashed state.
        tuple_variant = _test_dict_deepcopy()
        tuple_variant['list'] = tuple(tuple_variant['list'])
        self.assertNotEqual(
            hash(_test_frozenconfigdict()),
            hash(ml_collections.FrozenConfigDict(tuple_variant)))

        # FrozenConfigDict registers as Hashable with the ABC machinery.
        self.assertIsInstance(_test_frozenconfigdict(),
                              collections_abc.Hashable)
    def test_parameterized_mixing(self, model_arch, mixing_layer_name):
        """Parameterized mixing layers contribute params; shapes preserved."""
        config = dummy_config(model_arch=model_arch)
        with config.unlocked():
            config.use_fft = True
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config=frozen_config)
        params = init_model_params(rng, encoder, init_batch)
        # The parameterized mixing layer shows up once per encoder layer.
        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "feed_forward_0",
            "feed_forward_1", f"{mixing_layer_name}_0",
            f"{mixing_layer_name}_1", "pooler"
        }
        self.assertEqual(params.keys(), expected_keys)

        inputs = dummy_inputs(rng, config=frozen_config)
        encoder_output = encoder.apply({"params": params},
                                       rngs={"dropout": rng},
                                       **inputs)
        # Output shapes: (batch, seq, d_model) and (batch, d_model).
        expected_sequence_output_shape = (config.train_batch_size,
                                          config.max_seq_length,
                                          config.d_model)
        self.assertEqual(encoder_output.sequence_output.shape,
                         expected_sequence_output_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(encoder_output.pooled_output.shape,
                         expected_pooled_output_shape)
Example #10
0
def render_and_save():
    """Renders the image according to the configuration and saves it to disk."""
    config = ml_collections.FrozenConfigDict(configuration.get_config())

    # Derive the output resolution from the configured aspect ratio.
    aspect_ratio = config.aspect_ratio
    height = config.height
    width = int(aspect_ratio * height)

    scene_camera = build_camera(config, aspect_ratio)
    world = build_world(config)

    # JIT-compile the renderer; dimensions and config are static arguments.
    logging.info("Tracing rays...")
    render_fn = jax.jit(render.generate_image,
                        static_argnames=["height", "width", "config"])
    image = render_fn(height, width, scene_camera, world, config)

    # Gamma-correct before export.
    image = render.correct_gamma(image, gamma=config.gamma_correction)

    logging.info("Saving to file...")
    output.export_as_ppm(image, config.output_file)

    return image
Example #11
0
    def test_unparametrized_mixing_encoder(self, model_arch):
        """Mixing layers without params must not appear in the param tree."""
        config = dummy_config(model_arch=model_arch)
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, encoder, init_batch)
        # Unparameterized mixing encoders do not have any parameters in their mixing
        # layers, so their mixing layer names do not show up in params.
        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "feed_forward_0",
            "feed_forward_1", "pooler"
        }
        self.assertEqual(params.keys(), expected_keys)

        inputs = dummy_inputs(rng, config)
        hidden_states, pooled_output = encoder.apply({"params": params},
                                                     rngs={"dropout": rng},
                                                     **inputs)
        # Output shapes: (batch, seq, d_model) and (batch, d_model).
        expected_hidden_states_shape = (config.train_batch_size,
                                        config.max_seq_length, config.d_model)
        self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
Example #12
0
    def test_hybrid_encoder(self, attention_layout, num_attention_layers,
                            expected_attention_layers):
        """Hybrid encoders place self-attention at the configured layers."""
        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            # Extra layers let us exercise the different attention layouts.
            config.num_layers = 4
            config.attention_layout = attention_layout
            config.num_attention_layers = num_attention_layers
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, encoder, init_batch)

        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
            "feed_forward_0", "feed_forward_1", "feed_forward_2",
            "feed_forward_3", "pooler"
        }
        # Self-attention params appear only on the layers the layout selects.
        for expected_attention_layer in expected_attention_layers:
            expected_keys.add(f"self_attention_{expected_attention_layer}")

        self.assertEqual(params.keys(), expected_keys)

        inputs = dummy_inputs(rng, config)
        hidden_states, pooled_output = encoder.apply({"params": params},
                                                     rngs={"dropout": rng},
                                                     **inputs)
        # Output shapes: (batch, seq, d_model) and (batch, d_model).
        expected_hidden_states_shape = (config.train_batch_size,
                                        config.max_seq_length, config.d_model)
        self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
Example #13
0
    def test_classification_model(self):
        """Classification head: logits for eval, metric dict for training."""
        n_classes = 2

        config = dummy_config(model_arch=ModelArchitecture.BERT)
        with config.unlocked():
            config.dataset_name = "dummy/classification_dataset"
        frozen_config = ml_collections.FrozenConfigDict(config)

        model = models.SequenceClassificationModel(config=frozen_config,
                                                   n_classes=n_classes)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, model, init_batch)
        self.assertEqual(params.keys(), {"encoder", "classification"})

        # Logits for eval/prediction (no labels supplied).
        eval_inputs = dummy_inputs(rng, config)
        eval_inputs["deterministic"] = True
        logits = model.apply({"params": params}, **eval_inputs)
        expected_logits_shape = (config.train_batch_size, n_classes)
        self.assertEqual(jnp.shape(logits), expected_logits_shape)

        # Metrics for training (labels supplied).
        train_inputs = dummy_inputs(rng, config)
        train_inputs["labels"] = jnp.ones(config.train_batch_size, jnp.int32)
        metrics = model.apply({"params": params},
                              rngs={"dropout": rng},
                              **train_inputs)
        self.assertEqual(metrics.keys(),
                         {"loss", "correct_predictions", "num_labels"})
Example #14
0
    def test_regression_model(self):
        """Regression head: logits for eval, loss metrics for training."""
        n_classes = 1  # Only one label for regression

        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            config.dataset_name = "glue/stsb"  # regression task dataset
        frozen_config = ml_collections.FrozenConfigDict(config)

        model = models.SequenceClassificationModel(config=frozen_config,
                                                   n_classes=n_classes)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, model, init_batch)
        self.assertEqual(params.keys(), {"encoder", "classification"})

        # Logits for eval/prediction (no labels supplied).
        eval_inputs = dummy_inputs(rng, config)
        eval_inputs["deterministic"] = True
        logits = model.apply({"params": params}, **eval_inputs)
        expected_logits_shape = (config.train_batch_size, n_classes)
        self.assertEqual(jnp.shape(logits), expected_logits_shape)

        # Metrics for training (labels supplied).
        train_inputs = dummy_inputs(rng, config)
        _, label_key = jax.random.split(rng)
        # Random float labels in [0, 1) mimic a regression target.
        train_inputs["labels"] = jax.random.uniform(
            label_key, (config.train_batch_size, ), minval=0., maxval=1.)

        metrics = model.apply({"params": params},
                              rngs={"dropout": rng},
                              **train_inputs)
        self.assertEqual(metrics.keys(), {"loss", "num_labels"})
    def test_regression_model(self):
        """Regression with one MoE layer: loss, labels, and expert metrics."""
        n_classes = 1  # Only one label for regression

        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            config.dataset_name = "glue/stsb"  # regression task dataset
            config.num_moe_layers = 1  # Add moe layer to verify expert metrics
            num_tokens = config.train_batch_size * config.max_seq_length
            config.max_group_size = num_tokens
        frozen_config = ml_collections.FrozenConfigDict(config)

        model = models.SequenceClassificationModel(config=frozen_config,
                                                   n_classes=n_classes)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config=frozen_config)
        params = init_model_params(rng, model, init_batch)
        self.assertEqual(params.keys(), {"encoder", "classification"})

        # Logits for eval/prediction (no labels supplied).
        eval_inputs = dummy_inputs(rng, config=frozen_config)
        eval_inputs["deterministic"] = True
        logits = model.apply({"params": params}, **eval_inputs)
        expected_logits_shape = (config.train_batch_size, n_classes)
        self.assertEqual(jnp.shape(logits), expected_logits_shape)

        # Metrics for training (labels supplied).
        train_inputs = dummy_inputs(rng, config=frozen_config)
        label_key, dropout_key, jitter_key = jax.random.split(rng, num=3)
        train_inputs["labels"] = jax.random.uniform(
            label_key, (config.train_batch_size, ), minval=0., maxval=1.)

        # "intermediates" must be marked mutable so expert diversity metrics
        # are captured in the returned state.
        metrics, state = model.apply({"params": params},
                                     rngs={
                                         "dropout": dropout_key,
                                         "jitter": jitter_key
                                     },
                                     mutable=["intermediates"],
                                     **train_inputs)

        # NOTE(review): these golden values assume the fixed PRNG seed above
        # and deterministic kernels — confirm they hold on target hardware.
        self.assertAlmostEqual(metrics.batch_loss, 1.0100806, places=6)
        self.assertEqual(metrics.num_labels, 3)

        self.assertIn("intermediates", state)
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-6),
            state["intermediates"],
            FrozenDict({
                "encoder": {
                    "moe_1": {
                        "diversity_metrics":
                        models.DiversityMetrics(
                            auxiliary_loss=1.0073242,
                            router_z_loss=0.45751953,
                            fraction_tokens_left_behind=0.25,
                            expert_usage=0.75,
                            router_confidence=0.51171875)
                    }
                }
            }))
    def testToDict(self):
        """Ensure to_dict() does not care about hidden mutability."""
        # Two frozen configs differing only in list-vs-tuple hidden
        # mutability still produce the same plain dict.
        tuple_variant = _test_dict_deepcopy()
        tuple_variant['list'] = tuple(tuple_variant['list'])
        self.assertEqual(
            _test_frozenconfigdict().to_dict(),
            ml_collections.FrozenConfigDict(tuple_variant).to_dict())
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()
    # Data generation / splitting.
    config.rng_seed = 0
    config.num_trajectories = 1
    config.single_step_predictions = True
    config.num_samples = 1000
    config.split_on = 'times'
    config.train_split_proportion = 80 / 1000
    config.time_delta = 1.
    config.train_time_jump_range = (1, 10)
    config.test_time_jumps = (1, 2, 5, 10, 20, 50)
    # Model architecture.
    config.num_train_steps = 5000
    config.latent_size = 100
    config.activation = 'sigmoid'
    config.model = 'action-angle-network'
    config.encoder_decoder_type = 'flow'
    config.flow_type = 'shear'
    config.num_flow_layers = 10
    config.num_coordinates = 2
    # NOTE(review): flow_type is set to 'shear' above, so this branch is
    # inert; it takes effect only if flow_type is edited to
    # 'masked_coupling'.
    if config.flow_type == 'masked_coupling':
        config.flow_spline_range_min = -3
        config.flow_spline_range_max = 3
        config.flow_spline_bins = 100
    config.polar_action_angles = True
    config.scaler = 'identity'
    # Optimization / evaluation.
    config.learning_rate = 1e-3
    config.batch_size = 100
    config.eval_cadence = 50
    config.simulation = 'shm'
    # Regularization weights for the individual loss terms.
    config.regularizations = ml_collections.FrozenConfigDict({
        'actions':
        1.,
        'angular_velocities':
        0.,
        'encoded_decoded_differences':
        0.,
    })
    # Sampling ranges for the simulated system's parameters.
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
 def testFieldReferenceResolved(self):
     """Tests that FieldReferences are resolved."""
     cfg = ml_collections.ConfigDict(
         {'fr': ml_collections.FieldReference(1)})
     frozen_cfg = ml_collections.FrozenConfigDict(cfg)
     # Freezing replaces the FieldReference with its resolved value...
     self.assertNotIsInstance(frozen_cfg._fields['fr'],
                              ml_collections.FieldReference)
     # ...which is what makes the frozen config hashable.
     hash(
         frozen_cfg)  # with FieldReference resolved, frozen_cfg is hashable
    def testFieldReferenceCycle(self):
        """Tests that FieldReferences may not contain reference cycles."""
        # FieldReferences pointing at acyclic containers freeze fine.
        frozenset_fr = {'frozenset': frozenset({1, 2})}
        frozenset_fr['fr'] = ml_collections.FieldReference(
            frozenset_fr['frozenset'])
        _ = ml_collections.FrozenConfigDict(frozenset_fr)

        list_fr = {'list': [1, 2]}
        list_fr['fr'] = ml_collections.FieldReference(list_fr['list'])
        _ = ml_collections.FrozenConfigDict(list_fr)

        # A reference back to the (grand)parent dict forms a cycle and must
        # be rejected at freeze time.
        cyclic_fr = {'a': 1}
        cyclic_fr['fr'] = ml_collections.FieldReference(cyclic_fr)
        cyclic_fr_parent = {'dict': {}}
        cyclic_fr_parent['dict']['fr'] = ml_collections.FieldReference(
            cyclic_fr_parent)
        self.assertFrozenRaisesValueError([cyclic_fr, cyclic_fr_parent])
    def testPickle(self):
        """Make sure FrozenConfigDict can be dumped and loaded with pickle."""
        def round_trip(value):
            # Serialize and immediately deserialize.
            return pickle.loads(pickle.dumps(value))

        fcd = _test_frozenconfigdict()
        self.assertEqual(fcd, round_trip(fcd))

        # Locked configs must survive pickling too.
        locked_fcd = ml_collections.FrozenConfigDict(_test_configdict().lock())
        self.assertEqual(locked_fcd, round_trip(locked_fcd))
    def testInitConfigDict(self):
        """Tests that ConfigDict initialization handles FrozenConfigDict.

        Initializing a ConfigDict on a dictionary with FrozenConfigDict
        values should unfreeze these values.
        """
        # Build two source dicts that differ only in whether the 'dict' node
        # is pre-frozen.
        plain = _test_dict_deepcopy()
        plain.pop('ref')
        with_frozen_node = copy.deepcopy(plain)
        with_frozen_node['dict'] = ml_collections.FrozenConfigDict(
            with_frozen_node['dict'])

        # Both constructors must treat the two sources identically.
        self.assertEqual(ml_collections.ConfigDict(plain),
                         ml_collections.ConfigDict(with_frozen_node))
        self.assertEqual(ml_collections.FrozenConfigDict(plain),
                         ml_collections.FrozenConfigDict(with_frozen_node))
Example #22
0
def dummy_frozen_config():
    """Creates a dummy model config that can be used by all tests."""
    config = default_config.get_config()
    # Tiny model dimensions keep the tests fast.
    overrides = {
        'model_arch': default_config.ModelArchitecture.FF_ONLY,
        'd_emb': 4,
        'd_model': 4,
        'd_ff': 4,
        'max_seq_length': 8,
        'num_layers': 1,
        'vocab_size': 1000,
        'train_batch_size': 2,
    }
    for field, value in overrides.items():
        setattr(config, field, value)
    return ml_collections.FrozenConfigDict(config)
Example #23
0
def main(_):
    """Demonstrates ConfigDict vs. FrozenConfigDict attribute semantics."""
    print_section('Attribute Types.')
    # Build a ConfigDict holding every container type of interest, including
    # nested mutable containers.
    cfg = ml_collections.ConfigDict()
    cfg.int = 1
    cfg.list = [1, 2, 3]
    cfg.tuple = (1, 2, 3)
    cfg.set = {1, 2, 3}
    cfg.frozenset = frozenset({1, 2, 3})
    cfg.dict = {
        'nested_int': 4,
        'nested_list': [4, 5, 6],
        'nested_tuple': ([4], 5, 6),
    }

    print('Types of cfg fields:')
    print('list: ', type(cfg.list))  # List
    print('set: ', type(cfg.set))  # Set
    print('nested_list: ', type(cfg.dict.nested_list))  # List
    print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0]))  # List

    # Freezing converts mutable containers to their immutable counterparts.
    frozen_cfg = ml_collections.FrozenConfigDict(cfg)
    print('\nTypes of FrozenConfigDict(cfg) fields:')
    print('list: ', type(frozen_cfg.list))  # Tuple
    print('set: ', type(frozen_cfg.set))  # Frozenset
    print('nested_list: ', type(frozen_cfg.dict.nested_list))  # Tuple
    print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0]))  # Tuple

    # Unfreezing restores the original mutable container types.
    cfg_from_frozen = ml_collections.ConfigDict(frozen_cfg)
    print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:')
    print('list: ', type(cfg_from_frozen.list))  # List
    print('set: ', type(cfg_from_frozen.set))  # Set
    print('nested_list: ', type(cfg_from_frozen.dict.nested_list))  # List
    print('nested_tuple[0]: ',
          type(cfg_from_frozen.dict.nested_tuple[0]))  # List

    print(
        '\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
    print(cfg_from_frozen == frozen_cfg.as_configdict())  # True

    print_section('Immutability.')
    try:
        frozen_cfg.new_field = 1  # Raises AttributeError because of immutability.
    except AttributeError as e:
        print(e)

    print_section('"==" and eq_as_configdict().')
    # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict
    print(frozen_cfg == cfg)  # False
    # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to
    # ConfigDict
    print(frozen_cfg.eq_as_configdict(cfg))  # True
    # .eq_as_configdict() is also a method of ConfigDict
    print(cfg.eq_as_configdict(frozen_cfg))  # True
    def test_moe(self, moe_layout, num_moe_layers, expected_ff_layer_keys,
                 expected_moe_layer_keys, expected_moe_keys):
        """MoE layouts place experts correctly and report diversity metrics."""
        config = dummy_config(model_arch=ModelArchitecture.T_NET)
        with config.unlocked():
            config.num_layers = 4  # More layers so we can test different layouts
            config.moe_layout = moe_layout
            config.num_moe_layers = num_moe_layers
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config=frozen_config)
        params = init_model_params(rng, encoder, init_batch)

        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
            "pooler", "toeplitz_transform_0", "toeplitz_transform_1",
            "toeplitz_transform_2", "toeplitz_transform_3"
        }
        # Feed-forward vs. MoE param keys depend on the parameterized layout.
        expected_keys.update(expected_ff_layer_keys)
        expected_keys.update(expected_moe_keys)

        self.assertEqual(params.keys(), expected_keys)

        rng, dropout_key, jitter_key = jax.random.split(rng, num=3)
        inputs = dummy_inputs(rng, config=frozen_config)
        # "intermediates" is mutable so MoE diversity metrics are captured.
        encoder_output, state = encoder.apply({"params": params},
                                              rngs={
                                                  "dropout": dropout_key,
                                                  "jitter": jitter_key
                                              },
                                              mutable=["intermediates"],
                                              **inputs)

        expected_sequence_output_shape = (config.train_batch_size,
                                          config.max_seq_length,
                                          config.d_model)
        self.assertEqual(encoder_output.sequence_output.shape,
                         expected_sequence_output_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(encoder_output.pooled_output.shape,
                         expected_pooled_output_shape)

        # Diversity metrics are recorded only when MoE layers are present.
        if num_moe_layers > 0:
            self.assertIn("intermediates", state)
            for layer in expected_moe_layer_keys:
                self.assertIn(layer, state["intermediates"])
                self.assertIn("diversity_metrics",
                              state["intermediates"][layer])
        else:
            self.assertNotIn("intermediates", state)
Example #25
0
    def test_embedding_layer(self):
        """Embedding layer exposes the expected params and output shape."""
        config = ml_collections.ConfigDict({
            "batch_size": 3,
            "vocab_size": 1000,
            "d_emb": 32,
            "max_seq_length": 64,
            "type_vocab_size": 2,
            "d_model": 4,
            "dropout_rate": 0.1,
            "dtype": jnp.float32
        })
        frozen_config = ml_collections.FrozenConfigDict(config)
        rng = jax.random.PRNGKey(100)

        embedding_layer = layers.EmbeddingLayer(config=frozen_config)
        # A single-example batch suffices to initialize the parameters.
        init_batch = {
            "input_ids": jnp.ones((1, frozen_config.max_seq_length),
                                  jnp.int32),
            "type_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32)
        }
        params = init_layer_variables(rng, embedding_layer,
                                      init_batch)["params"]

        expected_keys = {
            "word", "position", "type", "layer_norm", "hidden_mapping_in"
        }
        self.assertEqual(params.keys(), expected_keys)

        # Random token/type ids at the full configured batch size.
        rng, init_rng = jax.random.split(rng)
        inputs = {
            "input_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=13),
            "type_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=2)
        }
        outputs = embedding_layer.apply({"params": params},
                                        rngs={"dropout": rng},
                                        **inputs)
        # Embeddings are projected to d_model: (batch, seq, d_model).
        self.assertEqual(outputs.shape,
                         (frozen_config.batch_size,
                          frozen_config.max_seq_length, frozen_config.d_model))
Example #26
0
    def test_f_net_encoder_bad_long_seq(self):
        """F-Net rejects long sequence lengths that are not powers of 2."""
        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            config.max_seq_length = 8194  # not a power of 2
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)
        init_batch = init_encoder_batch(config)
        rng = jax.random.PRNGKey(0)

        # Initialization itself should fail with a descriptive error.
        with self.assertRaisesRegex(
                ValueError,
                "must be a power of 2 to take advantage of FFT optimizations"):
            _ = init_model_params(rng, encoder, init_batch)
Example #27
0
    def test_model_shape(
        self,
        text_length,
        n_mentions,
        n_linked_mentions,
        no_entity_attention,
    ):
        """Test model forward runs and produces expected shape.

        Args:
          text_length: Length of the synthetic text sample.
          n_mentions: Number of mentions in the sample.
          n_linked_mentions: Number of linked (entity-resolved) mentions.
          no_entity_attention: Per-case override for the encoder's
            no_entity_attention flag.
        """
        config = copy.deepcopy(self.config)
        config['model_config']['encoder_config'][
            'no_entity_attention'] = no_entity_attention
        # Bug fix: freeze the modified copy, not self.config — previously
        # the no_entity_attention override above was silently discarded.
        config = ml_collections.FrozenConfigDict(config)
        model_config = config.model_config
        encoder_config = model_config.encoder_config

        max_length = model_config.encoder_config.max_length
        preprocess_fn = eae_task.EaETask.make_preprocess_fn(config)
        collater_fn = eae_task.EaETask.make_collater_fn(config)

        # Build a batch by tiling one synthetic example across the batch axis.
        raw_example = test_utils.gen_mention_pretraining_sample(
            text_length, n_mentions, n_linked_mentions, max_length=max_length)
        processed_example = preprocess_fn(raw_example)
        batch = {
            key: np.tile(value, (config.per_device_batch_size, 1))
            for key, value in processed_example.items()
        }
        batch = collater_fn(batch)
        batch = jax.tree_map(np.asarray, batch)

        model = eae_encoder.EaEEncoder(**model_config.encoder_config)
        init_rng = jax.random.PRNGKey(0)
        (encoded_output, loss_helpers, _), _ = model.init_with_output(
            init_rng,
            batch,
            deterministic=True,
            method=model.forward,
        )

        # Encoded output: (batch, max_length, hidden_size).
        self.assertEqual(
            encoded_output.shape,
            (config.per_device_batch_size, encoder_config.max_length,
             encoder_config.hidden_size))

        # Target mention encodings: (max_targets * batch, entity_dim).
        self.assertEqual(
            loss_helpers['target_mention_encodings'].shape,
            (config.max_mention_targets * config.per_device_batch_size,
             encoder_config.entity_dim))
Example #28
0
    def test_dataset_standard_batching(self):
        """Items from the standard pipeline carry the configured batch size."""
        batch_size = 5
        config = config_lib.get_config()
        with config.unlocked():
            config.dataset.name = 'control_flow_programs/decimal-L10'
            config.dataset.in_memory = True
            config.dataset.batch_size = batch_size
            config.dataset.representation = 'trace'
        config = ml_collections.FrozenConfigDict(config)

        data_dir = tempfile.mkdtemp()
        dataset_info = dataset_utils.get_dataset(data_dir, config)

        # The leading dimension of the first item equals the batch size.
        first_item = next(iter(dataset_info.dataset))
        self.assertEqual(first_item['cfg']['data'].shape[0], batch_size)
    def test_output_expected(self):
        """Loss on an all-zeros batch should be exactly zero."""

        model_config_dict = {
            'dtype': 'float32',
            'features': [32, 32],
            'use_bias': False,
        }

        config_dict = {
            'model_config': model_config_dict,
            'per_device_batch_size': 32,
            'seed': 0,
        }
        config = ml_collections.ConfigDict(config_dict)
        model_config = ml_collections.FrozenConfigDict(model_config_dict)

        # An all-zeros input batch fully determines the loss asserted below.
        batch = {
            'x':
            jnp.zeros(
                (config.per_device_batch_size,
                 config.model_config.features[0]),
                dtype=config.model_config.dtype,
            )
        }

        init_rng = jax.random.PRNGKey(config.seed)

        # Create dummy input
        dummy_input = example_task.ExampleTask.dummy_input(config)

        model = example_task.ExampleTask.build_model(model_config)
        initial_variables = jax.jit(model.init)(init_rng, dummy_input, False)
        loss_fn = example_task.ExampleTask.make_loss_fn(config)

        loss, _, _ = loss_fn(
            model_config=model_config,
            model_params=initial_variables['params'],
            model_vars={},
            batch=batch,
            deterministic=False,
        )

        self.assertEqual(loss, 0.0)
def frozen_config(sharded_params=False):
    """Creates a dummy model config that can be used by all tests.

    Args:
      sharded_params: If True, include one MoE layer/expert so the config
        exercises sharded parameters.

    Returns:
      An immutable FrozenConfigDict with tiny model dimensions.
    """
    config = default_config.get_config()
    num_moe = 1 if sharded_params else 0
    overrides = {
        'model_arch': default_config.ModelArchitecture.LINEAR.name,
        'num_attention_layers': 0,
        'd_emb': 4,
        'd_model': 4,
        'd_ff': 4,
        'max_seq_length': 8,
        'num_layers': 1,
        'vocab_size': 16,
        'train_batch_size': 4,
        'dtype': jnp.float32,
        'pad_id': 3,
        # MoE layers contain sharded parameters.
        'num_moe_layers': num_moe,
        'num_experts': num_moe,
        'auxiliary_loss_factor': 0.01,
        'router_z_loss_factor': 0.01,
    }
    for field, value in overrides.items():
        setattr(config, field, value)
    return ml_collections.FrozenConfigDict(config)