def test_cross_entropy_loss_fn(self):
  """Sigmoid BCE on logit differences must match softmax CE on two classes.

  For each case in CROSS_ENTROPY_TEST_DATA, the sigmoid binary cross entropy
  of (logits[0] - logits[1]) against targets[0] is compared to the
  'cross_entropy' loss on the raw logits/targets, using the same weights.
  """
  # The registry lookups do not depend on the test case, so resolve them once
  # instead of on every loop iteration.
  sigmoid_binary_ce_fn = losses.get_loss_fn('sigmoid_binary_cross_entropy')
  ce_fn = losses.get_loss_fn('cross_entropy')
  for data in CROSS_ENTROPY_TEST_DATA:
    # softmax([l0, l1])[0] == sigmoid(l0 - l1), so a two-class problem can be
    # rewritten as a binary one with the first class as the positive target.
    # (assumes each example has exactly two classes — TODO confirm data)
    binary_logits = np.array(
        [logits[0] - logits[1] for logits in data['logits']])
    binary_targets = np.array([targets[0] for targets in data['targets']])
    self.assertAlmostEqual(
        sigmoid_binary_ce_fn(binary_logits, binary_targets, data['weights']),
        ce_fn(data['logits'], data['targets'], data['weights']),
        places=5)
def __init__(self, hps, dataset_meta_data, loss_name, metrics_name):
  """Store configuration, resolve the loss and metrics, then build the module.

  Args:
    hps: hyperparameters, kept as `self.hps`.
    dataset_meta_data: dataset metadata, kept as `self.dataset_meta_data`.
    loss_name: registry key passed to both `losses.get_loss_fn` and
      `losses.get_output_activation_fn`.
    metrics_name: registry key passed to `metrics.get_metrics`.
  """
  self.hps, self.dataset_meta_data = hps, dataset_meta_data
  # The loss and its output activation are resolved from the same name.
  self.loss_fn = losses.get_loss_fn(loss_name)
  self.output_activation_fn = losses.get_output_activation_fn(loss_name)
  self.metrics_bundle = metrics.get_metrics(metrics_name)
  # Kept last; presumably build_flax_module reads the attributes set above —
  # TODO confirm before reordering.
  self.flax_module_def = self.build_flax_module()
def test_meta_loss(self, model_name):
  """Test that meta_init does not update the bias scalars.

  Args:
    model_name: name passed to `_load_model` to build the model under test.
  """
  rng = jax.random.PRNGKey(0)
  flax_module, params, input_shape, _ = _load_model(model_name)
  # `jax.tree_map` is deprecated (and removed in newer JAX releases);
  # `jax.tree_util.tree_map` is the long-standing equivalent.
  norms = jax.tree_util.tree_map(
      lambda node: jnp.linalg.norm(node.reshape(-1)), params)
  normalized_params = jax.tree_util.tree_map(meta_init.normalize, params)
  loss_fn = losses.get_loss_fn('cross_entropy')
  learned_norms, _ = meta_init.meta_optimize_scales(
      loss_fn=loss_fn,
      fprop=flax_module.apply,
      normalized_params=normalized_params,
      norms=norms,
      hps=meta_init.DEFAULT_HPARAMS,
      input_shape=input_shape[1:],
      output_shape=OUTPUT_SHAPE,
      rng_key=rng)
  # Check that all learned bias scales are 0, the meta loss should be
  # independent of these terms.
  learned_norms_flat = model_utils.flatten_dict(learned_norms)
  for layer_key, learned_norm in learned_norms_flat.items():
    if 'bias' in layer_key:
      self.assertEqual(learned_norm, 0.0)
def test_sparse_init(self):
  """Test that sparse_init produces sparse params."""
  import copy  # pylint: disable=g-import-not-at-top
  rng = jax.random.PRNGKey(0)
  flax_module, params, input_shape, model_hps = _load_model(
      'fully_connected')
  non_zero_connection_weights = 3
  # Work on a private copy: assigning into / update()-ing the module-level
  # sparse_init.DEFAULT_HPARAMS directly would leak these settings into every
  # other test or caller that reads the shared defaults.
  init_hps = copy.deepcopy(sparse_init.DEFAULT_HPARAMS)
  init_hps['non_zero_connection_weights'] = non_zero_connection_weights
  init_hps.update(model_hps)
  loss_fn = losses.get_loss_fn('cross_entropy')
  new_params = sparse_init.sparse_init(
      loss_fn=loss_fn,
      flax_module=flax_module,
      params=params,
      hps=init_hps,
      input_shape=input_shape[1:],
      output_shape=OUTPUT_SHAPE,
      rng_key=rng)
  # Check new params are sparse: each kernel's total non-zero count must be
  # kernel.shape[0] * non_zero_connection_weights, and every bias must be 0.
  for layer_params in new_params.values():
    kernel = layer_params['kernel']
    num_units = kernel.shape[0]
    self.assertEqual(jnp.count_nonzero(kernel),
                     num_units * non_zero_connection_weights)
    self.assertEqual(jnp.count_nonzero(layer_params['bias']), 0)
def test_initializers(self, init):
  """Test that each initializer runs, and the output is a valid pytree.

  Args:
    init: registry name of the initializer under test.
  """
  import copy  # pylint: disable=g-import-not-at-top
  rng = jax.random.PRNGKey(0)
  flax_module, params, input_shape, model_hps = _load_model(
      'fully_connected')
  _, init_rng = jax.random.split(rng)
  initializer = initializers.get_initializer(init)
  # NOTE(review): get_initializer_hparams may hand back the registry's shared
  # default hparams object; deep-copy before update() so model_hps cannot leak
  # into later tests. The copy is harmless if a fresh object is returned.
  init_hps = copy.deepcopy(initializers.get_initializer_hparams(init))
  init_hps.update(model_hps)
  loss_fn = losses.get_loss_fn('cross_entropy')
  new_params = initializer(
      loss_fn=loss_fn,
      flax_module=flax_module,
      params=params,
      hps=init_hps,
      input_shape=input_shape[1:],
      output_shape=OUTPUT_SHAPE,
      rng_key=init_rng)
  # Check new params are still valid params: a forward pass must succeed and
  # yield an output of shape (input_shape[0], OUTPUT_SHAPE[-1]).
  outputs = flax_module.apply(
      {'params': new_params}, jnp.ones(input_shape), train=True)
  utils.log_pytree_shape_and_statistics(new_params)
  self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1]))
def test_ctc_loss(self, logits, labels, result):
  """Tests the CTC loss computation.

  Args:
    logits: logits array; only its first two dimensions size the zeros
      passed alongside it.
    labels: label array; a zeros array of the same shape is passed with it.
    result: expected scalar loss value.
  """
  ctc_fn = losses.get_loss_fn('ctc')
  # All-zero companions for logits and labels. NOTE(review): these look like
  # padding masks where 0 means "not padded" — confirm against the ctc API.
  logit_paddings = np.zeros(logits.shape[:2])
  label_paddings = np.zeros(labels.shape)
  computed = ctc_fn(logits, logit_paddings, labels, label_paddings)
  expected = jax.numpy.array([result])
  self.assertAlmostEqual(computed, expected, places=6)
def test_regression_losses(self, loss_name):
  """Each reconstruction case must reproduce the value stored under loss_name.

  Args:
    loss_name: registry name of the loss; also the key of the expected value
      in each RECONSTRUCTION_TEST_DATA entry.
  """
  regression_loss = losses.get_loss_fn(loss_name)
  for case in RECONSTRUCTION_TEST_DATA:
    actual = regression_loss(case['logits'], case['targets'], case['weights'])
    self.assertAlmostEqual(actual, case[loss_name], places=6)
def test_classification_losses(self, loss_name):
  """Each classification case must reproduce the value stored under loss_name.

  Args:
    loss_name: registry name of the loss; also the key of the expected value
      in each CLASSIFICATION_TEST_DATA entry.
  """
  classification_loss = losses.get_loss_fn(loss_name)
  for case in CLASSIFICATION_TEST_DATA:
    actual = classification_loss(
        case['logits'], case['one_hot_targets'], case['weights'])
    self.assertAlmostEqual(actual, case[loss_name], places=5)
def test_sigmoid_cross_entropy_per_label_weights(self):
  """Tests whether per label weights mask the correct entries.

  Zeroing the last label column and the last example via per-label weights
  should match slicing that column away and using per-example weights,
  up to the different weight-sum normalization.
  """
  sigmoid_binary_ce_fn = losses.get_loss_fn('sigmoid_binary_cross_entropy')
  num_examples, num_labels = 3, 5
  kept_labels = 4
  logits = np.arange(15).reshape(num_examples, num_labels)
  targets = np.arange(15, 30).reshape(num_examples, num_labels)
  per_label_weights = np.array([
      [1, 1, 1, 1, 0],
      [1, 1, 1, 1, 0],
      [0, 0, 0, 0, 0],
  ])
  per_example_weights = np.array([1, 1, 0])
  masked_loss = sigmoid_binary_ce_fn(logits, targets, per_label_weights)
  # Both calls normalize by the sum of weights, which is higher in the
  # per-label case (by a factor of the number of kept labels).
  sliced_loss = sigmoid_binary_ce_fn(
      logits[:, :kept_labels], targets[:, :kept_labels], per_example_weights)
  self.assertAlmostEqual(masked_loss, sliced_loss / kept_labels)
def test_loss_fn_registry(self):
  """Every registered loss name resolves to a function; unknown names raise."""
  registered_names = losses._ALL_LOSS_FUNCTIONS  # pylint: disable=protected-access
  for name in registered_names:
    self.assertIsInstance(losses.get_loss_fn(name), types.FunctionType)
  # An unregistered name must be rejected with ValueError.
  with self.assertRaises(ValueError):
    losses.get_loss_fn('__test__loss__name__')
def test_weighted_mean_absolute_error(self, logits, targets, result):
  """Tests computing MAE.

  Args:
    logits: predictions passed to the loss.
    targets: targets passed to the loss.
    result: expected scalar loss value.
  """
  # NOTE(review): despite the test name, no weights argument is passed here.
  loss_value = losses.get_loss_fn('mean_absolute_error')(logits, targets)
  expected = jax.numpy.array([result])
  self.assertAlmostEqual(loss_value, expected)