Example #1
0
 def test_model_number(self):
     """Test that a scalar model splits ``flat`` into its first element and the rest.

     When the model is a number, ``unflatten_tf`` should return the first element
     of ``self.flat`` as a scalar and all remaining elements as a 1-D tensor.
     """
     unflattened = unflatten_tf(self.flat, 0)
     first = unflattened[0].numpy()
     remaining = unflattened[1].numpy()

     # The first output must be a single scalar (not a length-1 vector) equal to 0.
     assert first.shape == ()
     assert first == 0
     # Compare via NumPy rather than tf.equal: this is independent of the tensor
     # dtype and does not broadcast, so both shape and values must match exactly.
     assert np.array_equal(remaining, np.arange(1, 12))
Example #2
0
    def test_model_iterable(self):
        """Test that a list-of-numbers model unflattens into a list of scalar tensors.

        Each number in the model consumes one element of ``self.flat``. A model of
        twelve numbers should therefore yield twelve scalars holding the elements of
        ``flat`` in order, with nothing left over.
        """
        model = [1] * 12
        unflattened = unflatten_tf(self.flat, model)
        parts = unflattened[0]

        # Check the length first: ``all`` over an empty or short list is vacuously
        # True, so without this a missing result would still pass the shape check.
        assert len(parts) == len(model)
        assert all(p.numpy().shape == () for p in parts)
        # The scalars must come out of ``flat`` in their original order.
        assert np.array_equal(np.array([p.numpy() for p in parts]), self.flat.numpy())
        assert unflattened[1].numpy().size == 0
Example #3
0
    def test_model_nested_tensor(self):
        """Test that a list-of-tensors model unflattens into tensors of matching shapes.

        Each tensor in the model consumes as many elements of ``self.flat`` as it has
        entries. The outputs should therefore have the same shapes as the model
        tensors, together contain the elements of ``flat`` in order, and leave
        nothing over.
        """
        model = [tf.ones(3), tf.ones((2, 2)), tf.ones((3, 1)), tf.ones((1, 2))]
        unflattened = unflatten_tf(self.flat, model)
        parts = unflattened[0]

        # Guard against a short result: zip() would otherwise stop early and the
        # shape check below would pass vacuously.
        assert len(parts) == len(model)
        assert all(
            p.numpy().shape == m.numpy().shape for p, m in zip(parts, model)
        )
        # Flattening the pieces back out must reproduce ``flat`` in order.
        values = np.concatenate([p.numpy().ravel() for p in parts])
        assert np.array_equal(values, self.flat.numpy())
        assert unflattened[1].numpy().size == 0
Example #4
0
    def test_model_tensor(self):
        """Test that a tensor model takes its leading elements from ``flat`` and reshapes them.

        As many leading elements of ``self.flat`` as the model has entries should be
        reshaped into the model's shape. Every element after them should be returned
        as a flat tensor.
        """
        model = tf.ones((3, 3))
        unflattened = unflatten_tf(self.flat, model)

        # Derive the split point from the model instead of hard-coding how many
        # elements remain, so the test does not depend on the length of ``flat``.
        size = model.numpy().size
        target = tf.reshape(self.flat[:size], model.shape)
        remaining = self.flat[size:]

        # np.allclose broadcasts, so check the shapes explicitly before the values.
        assert unflattened[0].numpy().shape == target.numpy().shape
        assert np.allclose(unflattened[0].numpy(), target.numpy())
        assert unflattened[1].numpy().shape == remaining.numpy().shape
        assert np.allclose(unflattened[1].numpy(), remaining.numpy())