Example #1
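Checks that sigmoid binary cross-entropy applied to the difference of the two class logits agrees with the cross-entropy loss on the same data and weights.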
 def test_cross_entropy_loss_fn(self):
   for data in CROSS_ENTROPY_TEST_DATA:
     sigmoid_binary_ce_fn = losses.get_loss_fn('sigmoid_binary_cross_entropy')
     ce_fn = losses.get_loss_fn('cross_entropy')
     self.assertAlmostEqual(
         sigmoid_binary_ce_fn(
             np.array([logits[0] - logits[1] for logits in data['logits']]),
             np.array([targets[0] for targets in data['targets']]),
             data['weights']),
         ce_fn(data['logits'], data['targets'], data['weights']),
         places=5)
Example #2
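A model constructor that resolves the loss function and its matching output activation by name, along with the metrics bundle, before building the Flax module definition.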
 def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
     self.hps = hps
     self.dataset_meta_data = dataset_meta_data
     self.loss_fn = losses.get_loss_fn(loss_name)
     self.output_activation_fn = losses.get_output_activation_fn(loss_name)
     self.metrics_bundle = metrics.get_metrics(metrics_name)
     self.flax_module_def = self.build_flax_module()
Example #3
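Feeds the cross-entropy loss into meta_optimize_scales and asserts that the learned bias scales remain zero, since the meta loss should not depend on them.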
    def test_meta_loss(self, model_name):
        """Test that meta_init does not update the bias scalars."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, _ = _load_model(model_name)
        norms = jax.tree_map(lambda node: jnp.linalg.norm(node.reshape(-1)),
                             params)
        normalized_params = jax.tree_map(meta_init.normalize, params)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        learned_norms, _ = meta_init.meta_optimize_scales(
            loss_fn=loss_fn,
            fprop=flax_module.apply,
            normalized_params=normalized_params,
            norms=norms,
            hps=meta_init.DEFAULT_HPARAMS,
            input_shape=input_shape[1:],
            output_shape=OUTPUT_SHAPE,
            rng_key=rng)

        # Check that all learned bias scales are 0; the meta loss should be
        # independent of these terms.
        learned_norms_flat = model_utils.flatten_dict(learned_norms)
        for layer_key in learned_norms_flat:
            if 'bias' in layer_key:
                self.assertEqual(learned_norms_flat[layer_key], 0.0)
Example #4
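Runs sparse_init with the cross-entropy loss and verifies that each kernel ends up with exactly non_zero_connection_weights non-zero entries per unit and that all biases are zero.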
    def test_sparse_init(self):
        """Test that sparse_init produces sparse params."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, model_hps = _load_model(
            'fully_connected')
        non_zero_connection_weights = 3
        init_hps = sparse_init.DEFAULT_HPARAMS
        init_hps['non_zero_connection_weights'] = non_zero_connection_weights
        init_hps.update(model_hps)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        new_params = sparse_init.sparse_init(loss_fn=loss_fn,
                                             flax_module=flax_module,
                                             params=params,
                                             hps=init_hps,
                                             input_shape=input_shape[1:],
                                             output_shape=OUTPUT_SHAPE,
                                             rng_key=rng)

        # Check new params are sparse
        for key in new_params:
            num_units = new_params[key]['kernel'].shape[0]
            self.assertEqual(jnp.count_nonzero(new_params[key]['kernel']),
                             num_units * non_zero_connection_weights)
            self.assertEqual(jnp.count_nonzero(new_params[key]['bias']), 0)
Example #5
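Runs a named initializer with the cross-entropy loss and checks that the resulting parameters still yield outputs of the expected shape when the Flax module is applied.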
    def test_initializers(self, init):
        """Test that each initializer runs, and the output is a valid pytree."""

        rng = jax.random.PRNGKey(0)
        flax_module, params, input_shape, model_hps = _load_model(
            'fully_connected')
        _, init_rng = jax.random.split(rng)
        initializer = initializers.get_initializer(init)
        init_hps = initializers.get_initializer_hparams(init)
        init_hps.update(model_hps)
        loss_name = 'cross_entropy'
        loss_fn = losses.get_loss_fn(loss_name)
        new_params = initializer(loss_fn=loss_fn,
                                 flax_module=flax_module,
                                 params=params,
                                 hps=init_hps,
                                 input_shape=input_shape[1:],
                                 output_shape=OUTPUT_SHAPE,
                                 rng_key=init_rng)

        # Check new params are still valid params
        outputs = flax_module.apply({'params': new_params},
                                    jnp.ones(input_shape),
                                    train=True)
        utils.log_pytree_shape_and_statistics(new_params)
        self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1]))
Example #6
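Evaluates the CTC loss on the given logits and labels, passing all-zero arrays as the padding masks, and compares the result to the expected value.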
  def test_ctc_loss(self, logits, labels, result):
    """Tests the CTC loss computation."""
    ctc_loss = losses.get_loss_fn('ctc')
    loss_value = ctc_loss(logits, np.zeros(logits.shape[:2]), labels,
                          np.zeros(labels.shape))

    self.assertAlmostEqual(loss_value, jax.numpy.array([result]), places=6)
Example #7
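Evaluates a regression loss against the precomputed values stored in RECONSTRUCTION_TEST_DATA.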
 def test_regression_losses(self, loss_name):
   loss_fn = losses.get_loss_fn(loss_name)
   for data in RECONSTRUCTION_TEST_DATA:
     self.assertAlmostEqual(
         loss_fn(data['logits'], data['targets'], data['weights']),
         data[loss_name],
         places=6)
Example #8
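Evaluates a classification loss against the precomputed values stored in CLASSIFICATION_TEST_DATA.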
 def test_classification_losses(self, loss_name):
   loss_fn = losses.get_loss_fn(loss_name)
   for data in CLASSIFICATION_TEST_DATA:
     self.assertAlmostEqual(
         loss_fn(data['logits'], data['one_hot_targets'], data['weights']),
         data[loss_name],
         places=5)
Example #9
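Shows that per-label weights mask the intended entries: the weighted loss equals the loss over the first four columns with per-example weights, divided by four to account for the different weight normalization.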
  def test_sigmoid_cross_entropy_per_label_weights(self):
    """Tests whether per label weights mask the correct entries."""
    sigmoid_binary_ce_fn = losses.get_loss_fn('sigmoid_binary_cross_entropy')
    logits = np.arange(15).reshape(3, 5)
    targets = np.arange(15, 30).reshape(3, 5)

    per_label_weights = np.array([
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [0, 0, 0, 0, 0],
    ])
    per_example_weights = np.array([1, 1, 0])

    # Both calls normalize by the sum of weights, which is higher in the
    # per-label case.
    self.assertAlmostEqual(
        sigmoid_binary_ce_fn(logits, targets, per_label_weights),
        sigmoid_binary_ce_fn(logits[:, :4], targets[:, :4],
                             per_example_weights) / 4)
Example #10
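Iterates over the loss registry, checks that every registered name resolves to a function, and confirms that an unknown name raises ValueError.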
 def test_loss_fn_registry(self):
   for loss_name in losses._ALL_LOSS_FUNCTIONS:  # pylint: disable=protected-access
     loss_fn = losses.get_loss_fn(loss_name)
     self.assertIsInstance(loss_fn, types.FunctionType)
   with self.assertRaises(ValueError):
     losses.get_loss_fn('__test__loss__name__')
Example #11
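Looks up the mean absolute error loss and compares the computed value with the expected result.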
  def test_weighted_mean_absolute_error(self, logits, targets, result):
    """Tests computing MAE."""
    mae = losses.get_loss_fn('mean_absolute_error')
    loss_value = mae(logits, targets)

    self.assertAlmostEqual(loss_value, jax.numpy.array([result]))