def test_MarginInstanceEmbeddingLoss_training():
    '''Verifies that the MarginInstanceEmbeddingLoss can be used to learn a
    simple thresholding operation.

    A small stack of 1x1 convolutions maps every pixel of a random image to a
    2-d embedding. Pixels with raw <= 0 are labelled 1 and pixels with
    raw > 0 are labelled 2. After one epoch (the single sample is repeated
    100 times), the test requires three things: the loss has dropped by at
    least 5%, the mean embeddings of the two instances have moved further
    apart, and the final loss is nearly zero.
    '''

    def compute_instance_dist(model, raw, yt):
        # Euclidean distance between the mean embedding of instance 1 and
        # the mean embedding of instance 2.
        labels = yt.astype(int).squeeze(axis=-1)
        pred = model(raw, training=False).numpy()
        c1 = pred[labels == 1].mean(axis=0)
        c2 = pred[labels == 2].mean(axis=0)
        return np.linalg.norm(c1 - c2)

    set_seeds(25)
    raw = np.random.normal(size=(1, 10, 10, 1)).astype(np.float32)
    # Two instances: label 1 where raw <= 0, label 2 where raw > 0.
    yt = (raw > 0.0).astype(np.int32) + 1
    dataset = tf.data.Dataset.from_tensors((raw, yt)).repeat(100)

    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(10,
                               kernel_size=1,
                               padding='same',
                               activation='relu'),
        tf.keras.layers.Conv2D(10,
                               kernel_size=1,
                               padding='same',
                               activation='relu'),
        tf.keras.layers.Conv2D(2,
                               kernel_size=1,
                               padding='same',
                               activation=None),
    ])
    # Positional margins, in the same (intra_margin, inter_margin) order used
    # by the parametrized MarginInstanceEmbeddingLoss test.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
                  loss=MarginInstanceEmbeddingLoss(2, 6))

    mean_dist_before = compute_instance_dist(model, raw, yt)
    loss_before = model.evaluate(dataset)

    model.fit(dataset, epochs=1)

    mean_dist_after = compute_instance_dist(model, raw, yt)
    loss_after = model.evaluate(dataset)

    assert loss_before * 0.95 >= loss_after
    assert mean_dist_before < mean_dist_after
    assert loss_after < 0.005
def test__unbatched_label_to_hot():
    '''Checks that `_unbatched_label_to_hot` produces one binary channel per
    foreground label, with background label 0 getting no channel.'''
    to_hot = DummySpatialInstanceEmbeddingLoss()._unbatched_label_to_hot

    set_seeds(25)
    label_img = np.random.choice(range(5), size=(10, 10, 1)).astype(np.int32)
    one_hot = to_hot(label_img)

    # Labels are drawn from 0..4. Background (0) is dropped, which leaves 4
    # channels.
    assert one_hot.shape == (10, 10, 4)

    flat_labels = label_img.squeeze()
    for channel, label_value in zip(range(4), (1, 2, 3, 4)):
        actual_mask = one_hot[..., channel].numpy().astype(bool)
        expected_mask = flat_labels == label_value
        np.testing.assert_array_equal(actual_mask, expected_mask)
def test_relabel_sequential():
    '''Compares `relabel_sequential` against scikit-image's implementation.

    Before calling scikit-image, the labels are shifted by +1 so that the
    smallest value (-1) becomes 0; the result is shifted back by -1
    afterwards. Presumably this is because scikit-image expects non-negative
    labels. Only the sets of resulting label values are compared, not the
    label maps themselves.
    '''

    def check_same_label_values(lbl):
        expected = sk_relabel_sequential(lbl + 1)[0] - 1
        actual = relabel_sequential(lbl)
        assert set(np.unique(expected)) == set(np.unique(actual))

    set_seeds(25)
    labels = np.random.choice([-1, 0, 2, 3, 4, 5],
                              size=(10, 10, 1)).astype(np.int32)

    # Values come from {-1, 0, 2, 3, 4, 5}, so there is a single gap at 1.
    check_same_label_values(labels)

    # Widen the gaps in place: label 2 becomes 0 and label 4 becomes -1, which
    # leaves only 3 and 5 above 0.
    labels[labels == 2] = 0
    labels[labels == 4] = -1
    check_same_label_values(labels)
def test_MarginInstanceEmbeddingLoss(intra_margin, inter_margin):
    '''Checks the zero/positive regimes of MarginInstanceEmbeddingLoss.

    - Embedding every label value scaled by 1.1 * inter_margin gives zero
      loss.
    - Within a single instance, two pixels up to 2 * intra_margin apart give
      zero loss; 2.1 * intra_margin or more gives a positive loss.
    - For two instances of one pixel each, a separation of at least
      inter_margin gives zero loss; smaller separations give a positive loss.
    '''
    # NOTE(review): intra_margin / inter_margin are presumably supplied by a
    # pytest.mark.parametrize decorator that is not visible here. Confirm it
    # exists, otherwise pytest will report missing fixtures.
    margin_loss = MarginInstanceEmbeddingLoss(intra_margin, inter_margin)

    def loss_1d(labels, values):
        # Batch of one 1-d sample with two pixels and a 1-d embedding.
        yt_1d = np.array([[[labels[0]], [labels[1]]]], dtype=np.int32)
        yp_1d = np.array([[[values[0]], [values[1]]]], dtype=np.float32)
        return margin_loss(yt_1d, yp_1d)

    # Random labels drawn from 0..4, batch size 4.
    set_seeds(11)
    yt = np.random.choice(range(5), size=(4, 10, 10, 1)).astype(np.int32)

    # "Perfect" 10-d embedding: the label value times 1.1 * inter_margin in
    # every dimension, so distinct labels are more than inter_margin apart.
    yp_perfect = (np.tile(yt, (1, 1, 1, 10)) * 1.1 * inter_margin)
    yp_perfect = yp_perfect.astype(np.float32)
    np.testing.assert_almost_equal(margin_loss(yt, yp_perfect), 0.)

    # Single instance: both pixels belong to label 1.
    same = (1, 1)
    np.testing.assert_almost_equal(loss_1d(same, (1, 1)), 0.)
    np.testing.assert_almost_equal(loss_1d(same, (1, 1 + intra_margin)), 0.)
    np.testing.assert_almost_equal(
        loss_1d(same, (1, 1 + 2 * intra_margin)), 0.)
    assert loss_1d(same, (1, 1 + 2.1 * intra_margin)) > 0
    assert loss_1d(same, (1, 1 + 10 * intra_margin)) > 0

    # Two instances: one pixel of label 1 and one pixel of label 2.
    pair = (1, 2)
    assert loss_1d(pair, (1, 1)) > 0.
    assert loss_1d(pair, (1, 1 + 0.5 * inter_margin)) > 0
    np.testing.assert_almost_equal(
        loss_1d(pair, (1, 1 + 1. * inter_margin)), 0.)
    np.testing.assert_almost_equal(
        loss_1d(pair, (1, 1 + 2. * inter_margin)), 0.)
def test__unbatched_embedding_center():
    '''Checks `_unbatched_embedding_center` against a per-label reference.

    Each of the 4 foreground labels (1..4) and each of the 3 embedding
    channels should match the value `label_mean` returns for that channel
    (presumably a per-label mean such as scipy.ndimage.mean).
    '''
    loss = DummySpatialInstanceEmbeddingLoss()
    to_hot = loss._unbatched_label_to_hot
    embedding_center = loss._unbatched_embedding_center

    set_seeds(25)
    label_img = np.random.choice(range(5), size=(10, 10, 1)).astype(np.int32)
    one_hot = to_hot(label_img)
    embedding = np.random.rand(10, 10, 3).astype(np.float32)

    centers = embedding_center(one_hot, embedding)
    # Two leading singleton axes, then 4 instances x 3 embedding channels.
    assert centers.shape == (1, 1, 4, 3)

    flat_labels = label_img.squeeze()
    fg_labels = [1, 2, 3, 4]
    per_channel = [
        label_mean(embedding[..., channel], flat_labels, fg_labels)
        for channel in range(embedding.shape[-1])
    ]
    expected_centers = np.stack(per_channel, axis=-1)

    np.testing.assert_array_almost_equal(centers.numpy().squeeze(),
                                         expected_centers)
def test_BinaryJaccardLoss_training():
    '''Trains a single sigmoid 1x1 convolution with BinaryJaccardLoss to
    reproduce a threshold at 0. Requires the loss to drop by at least 5% and
    to end below 0.001.'''
    set_seeds(25)
    raw = np.random.normal(size=(1, 10, 10, 1)).astype(np.float32)
    target = (raw > 0.0).astype(np.float32)
    dataset = tf.data.Dataset.from_tensors((raw, target))

    threshold_layer = tf.keras.layers.Conv2D(1,
                                             kernel_size=1,
                                             padding='same',
                                             activation='sigmoid')
    model = tf.keras.models.Sequential([threshold_layer])
    # A very large learning rate, so that 100 single-sample epochs converge.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=10.),
                  loss=BinaryJaccardLoss())

    initial_loss = model.evaluate(dataset)
    model.fit(dataset, epochs=100)
    final_loss = model.evaluate(dataset)

    assert final_loss <= initial_loss * 0.95
    assert final_loss < 0.001
def test_InstanceMeanIoUEmbeddingLoss():
    '''Checks InstanceMeanIoUEmbeddingLoss on a "perfect" embedding.

    The embedding is the label value itself, as a 1-d float channel. With a
    tiny margin the loss is ~0. Clipping probabilities, or using a larger
    margin, gives a strictly larger loss, and the loss grows with the margin.
    '''
    set_seeds(25)
    n_classes = 5
    # Random labels drawn from 0..4, batch size 4.
    yt = np.random.choice(range(n_classes),
                          size=(4, 10, 10, 1)).astype(np.int32)
    # Same shape as yt, so a plain dtype conversion is enough (no broadcast).
    yp_perfect = yt.astype(np.float32)

    def loss_for(**loss_kwargs):
        return InstanceMeanIoUEmbeddingLoss(**loss_kwargs)(
            yt, yp_perfect).numpy()

    loss_perfect = loss_for(margin=0.001)
    loss_clipped = loss_for(margin=0.001, clip_probs=(0.01, 0.99))
    loss_margin_a = loss_for(margin=0.5)
    loss_margin_b = loss_for(margin=0.7)

    np.testing.assert_almost_equal(loss_perfect, 0.)
    assert loss_clipped > loss_perfect
    assert loss_margin_a > loss_perfect
    assert loss_margin_b > loss_margin_a