def test_model_number(self):
    """Test that a scalar model splits ``flat`` into its first element and the rest.

    When the model is a number, ``unflatten_tf`` should return the first
    element of ``self.flat`` followed by a flat tensor holding all remaining
    elements.
    """
    unflattened = unflatten_tf(self.flat, 0)
    # np.array_equal compares shape as well as values, unlike iterating over
    # the (broadcasting) result of tf.equal.
    # Expected values mirror the fixture ``self.flat`` (presumably 0..11;
    # it is defined outside this test).
    assert np.array_equal(np.asarray(unflattened[0]), 0)
    assert np.array_equal(np.asarray(unflattened[1]), np.arange(1, 12))
def test_model_iterable(self):
    """Test unflattening when the model is a list of numbers.

    Each number in the model should yield one scalar tensor, so
    ``unflatten_tf`` returns a list with one shape-``()`` tensor per model
    entry and an empty remainder.
    """
    model = [1] * 12
    unflattened = unflatten_tf(self.flat, model)
    # Check the length first: all() over an empty sequence is vacuously True,
    # so without this an empty result would pass the shape check.
    assert len(unflattened[0]) == len(model)
    assert all(u.numpy().shape == () for u in unflattened[0])
    assert unflattened[1].numpy().size == 0
def test_model_nested_tensor(self):
    """Test unflattening when the model is a list of tensors.

    ``unflatten_tf`` should return a list containing one tensor per model
    tensor, each with the same shape as its model counterpart, and leave an
    empty remainder.
    """
    model = [tf.ones(3), tf.ones((2, 2)), tf.ones((3, 1)), tf.ones((1, 2))]
    unflattened = unflatten_tf(self.flat, model)
    # zip() stops at the shorter input and all() of nothing is True, so the
    # length check is what makes the per-element shape comparison meaningful.
    assert len(unflattened[0]) == len(model)
    assert all(
        u.numpy().shape == m.numpy().shape
        for u, m in zip(unflattened[0], model)
    )
    assert unflattened[1].numpy().size == 0
def test_model_tensor(self):
    """Test unflattening when the model is a single tensor.

    The first ``model``-sized chunk of ``self.flat`` should be reshaped into
    the model's shape, and every element after it returned as a flat tensor.
    """
    model = tf.ones((3, 3))
    unflattened = unflatten_tf(self.flat, model)
    n_model = model.numpy().size
    target = tf.reshape(self.flat[:n_model], model.shape)
    # Slice the expected remainder from the model size rather than a
    # hard-coded [-3:], so it stays correct for any length of self.flat.
    remaining = self.flat[n_model:]
    # np.allclose broadcasts, so compare shapes explicitly as well.
    assert unflattened[0].numpy().shape == target.numpy().shape
    assert np.allclose(unflattened[0].numpy(), target.numpy())
    assert unflattened[1].numpy().shape == remaining.numpy().shape
    assert np.allclose(unflattened[1].numpy(), remaining.numpy())