def test_full_transform_scalar(self):
        """Check the full transform built from a scalar transform value.

        A scalar of 1.5 with input slice ``5:`` and output indices
        ``[1, 2, 3]`` is expected to expand to a 4x8 matrix holding 1.5 at
        (1, 5), (2, 6) and (3, 7) and zeros everywhere else.
        """
        # Reference matrix over the whole (size_out, size_in) space
        out_rows = [1, 2, 3]
        in_cols = [5, 6, 7]
        expected = np.zeros((4, 8))
        expected[out_rows, in_cols] = 1.5

        # Transform under test
        xform = Transform(
            size_in=8,
            slice_in=slice(5, None),
            transform=1.5,
            size_out=4,
            slice_out=out_rows,
        )

        # Input sliced only: compare against the sliced input columns
        in_sliced = xform.full_transform(slice_in=True, slice_out=False)
        assert np.array_equal(in_sliced, expected[:, 5:])

        # Output sliced only: compare against the sliced output rows
        out_sliced = xform.full_transform(slice_in=False, slice_out=True)
        assert np.array_equal(out_sliced, expected[1:])

        # No slicing: compare against the complete matrix
        unsliced = xform.full_transform(slice_in=False, slice_out=False)
        assert np.array_equal(unsliced, expected)
    def test_full_transform_matrix(self):
        """Test extracting the full transform when a matrix is provided.

        A 3x3 matrix with input indices ``[3, 4, 5]`` and output slice
        ``:3`` is expected to occupy rows 0-2 and columns 3-5 of an
        otherwise zero 4x8 matrix.
        """
        # Construct the expected matrix. `reshape` is used rather than
        # assigning to `.shape`, which NumPy discourages.
        matrix = np.arange(9).reshape(3, 3)
        expected = np.zeros((4, 8))
        expected[:3, 3:6] = matrix

        # Create the transform
        t = Transform(size_in=8, slice_in=[3, 4, 5],
                      transform=matrix,
                      size_out=4, slice_out=slice(3))

        # Input sliced only: compare against the sliced input columns
        assert np.array_equal(
            t.full_transform(slice_in=True, slice_out=False),
            expected[:, 3:6]
        )

        # Output sliced only: compare against the sliced output rows
        assert np.array_equal(
            t.full_transform(slice_in=False, slice_out=True),
            expected[:3]
        )

        # No slicing: compare against the complete matrix
        assert np.array_equal(
            t.full_transform(slice_in=False, slice_out=False),
            expected
        )
    def test_full_transform_matrix(self):
        """Test extracting the full transform when a matrix is provided.

        The 3x3 matrix is mapped from input indices ``[3, 4, 5]`` onto the
        output slice ``:3``, so the expected 4x8 result is zero except for
        the block at rows 0-2, columns 3-5.
        """
        # Construct the expected matrix. The 3x3 block is created with
        # `reshape`; NumPy discourages setting `.shape` on an array.
        matrix = np.arange(9).reshape(3, 3)
        expected = np.zeros((4, 8))
        expected[:3, 3:6] = matrix

        # Create the transform
        t = Transform(size_in=8,
                      slice_in=[3, 4, 5],
                      transform=matrix,
                      size_out=4,
                      slice_out=slice(3))

        # Input sliced only: compare against the sliced input columns
        assert np.array_equal(t.full_transform(slice_in=True, slice_out=False),
                              expected[:, 3:6])

        # Output sliced only: compare against the sliced output rows
        assert np.array_equal(t.full_transform(slice_in=False, slice_out=True),
                              expected[:3])

        # No slicing: compare against the complete matrix
        assert np.array_equal(
            t.full_transform(slice_in=False, slice_out=False), expected)
    def test_full_transform_vector(self):
        """Check the full transform built from a vector transform value.

        The vector ``[1.0, 2.0, 3.0]`` with input indices ``[3, 4, 5]`` and
        output slice ``:3`` is expected to land at (0, 3), (1, 4) and
        (2, 5) of an otherwise zero 4x8 matrix.
        """
        # Reference matrix over the whole (size_out, size_in) space
        diag = np.array([1.0, 2.0, 3.0])
        in_cols = [3, 4, 5]
        expected = np.zeros((4, 8))
        expected[[0, 1, 2], in_cols] = diag

        # Transform under test
        xform = Transform(
            size_in=8,
            slice_in=in_cols,
            transform=diag,
            size_out=4,
            slice_out=slice(3),
        )

        # Input sliced only: compare against the sliced input columns
        in_sliced = xform.full_transform(slice_in=True, slice_out=False)
        assert np.array_equal(in_sliced, expected[:, 3:6])

        # Output sliced only: compare against the sliced output rows
        out_sliced = xform.full_transform(slice_in=False, slice_out=True)
        assert np.array_equal(out_sliced, expected[:3])

        # No slicing: compare against the complete matrix
        unsliced = xform.full_transform(slice_in=False, slice_out=False)
        assert np.array_equal(unsliced, expected)
    def test_concat_empty_transform_zero_transform_trivial(self):
        """Check that concatenating onto a zero transform yields None.

        The first transform has a transform value of 0, so the product of
        the two full transforms is all zeros.
        """
        # The two transforms to chain
        first = Transform(size_in=2, transform=0, size_out=2)
        second = Transform(size_in=2, transform=1, size_out=2)

        # Sanity check of the test itself: the product must be all zero
        second_full = second.full_transform(False, False)
        first_full = first.full_transform(False, False)
        product = np.dot(second_full, first_full)
        assert not np.any(product), "Test broken"

        # An empty combined transform is signalled by returning None
        combined = first.concat(second)
        assert combined is None
    def test_concat_empty_transform_mismatched_slicing(self):
        """Check that concatenating transforms with disjoint slices is None.

        The first transform uses ``slice_out=[0]`` while the second uses
        ``slice_in=[1]``, so the product of their full transforms is all
        zeros.
        """
        # The two transforms to chain
        first = Transform(size_in=1, transform=1, size_out=2, slice_out=[0])
        second = Transform(size_in=2, slice_in=[1], transform=1, size_out=1)

        # Sanity check of the test itself: the product must be all zero
        second_full = second.full_transform(False, False)
        first_full = first.full_transform(False, False)
        product = np.dot(second_full, first_full)
        assert not np.any(product), "Test broken"

        # An empty combined transform is signalled by returning None
        combined = first.concat(second)
        assert combined is None
    def test_concat_empty_transform_zero_transform_trivial(self):
        """Concatenation involving an all-zero transform must return None.

        With a transform value of 0 on the upstream transform, the product
        of the two full transforms contains no non-zero entries.
        """
        # Upstream (zero) and downstream (unit) transforms
        upstream = Transform(size_in=2, transform=0, size_out=2)
        downstream = Transform(size_in=2, transform=1, size_out=2)

        # Guard against a broken test: the matrix product must be empty
        downstream_full = downstream.full_transform(False, False)
        upstream_full = upstream.full_transform(False, False)
        product = np.dot(downstream_full, upstream_full)
        assert not np.any(product), "Test broken"

        # None is the marker for an empty combined transform
        result = upstream.concat(downstream)
        assert result is None
    def test_concat_empty_transform_mismatched_slicing(self):
        """Concatenation of transforms with disjoint slices must return None.

        The upstream transform is built with ``slice_out=[0]`` and the
        downstream one with ``slice_in=[1]``; the product of their full
        transforms contains no non-zero entries.
        """
        # Upstream and downstream transforms with non-overlapping slices
        upstream = Transform(size_in=1, transform=1, size_out=2,
                             slice_out=[0])
        downstream = Transform(size_in=2, slice_in=[1], transform=1,
                               size_out=1)

        # Guard against a broken test: the matrix product must be empty
        downstream_full = downstream.full_transform(False, False)
        upstream_full = upstream.full_transform(False, False)
        product = np.dot(downstream_full, upstream_full)
        assert not np.any(product), "Test broken"

        # None is the marker for an empty combined transform
        result = upstream.concat(downstream)
        assert result is None
    def test_concat(self, a_params, b_params):
        """Check that concat matches the product of the full transforms.

        `a_params` and `b_params` are keyword arguments forwarded to the
        `Transform` constructor for the first (8 -> 8) and second (8 -> 4)
        transform respectively.
        """
        # The two transforms to chain
        first = Transform(size_in=8, size_out=8, **a_params)
        second = Transform(size_in=8, size_out=4, **b_params)

        # Reference result: matrix product of the unsliced transforms
        second_full = second.full_transform(False, False)
        first_full = first.full_transform(False, False)
        expected = np.dot(second_full, first_full)

        # Result under test
        combined = first.concat(second)

        # The combined transform must expand to the reference matrix
        assert np.array_equal(expected, combined.full_transform(False, False))
    def test_concat(self, a_params, b_params):
        """Concatenated transform must equal the product of both matrices.

        `a_params` / `b_params` are expanded as keyword arguments when
        constructing the upstream (8 -> 8) and downstream (8 -> 4)
        transforms.
        """
        # Upstream and downstream transforms
        upstream = Transform(size_in=8, size_out=8, **a_params)
        downstream = Transform(size_in=8, size_out=4, **b_params)

        # Reference: downstream matrix applied after the upstream matrix
        downstream_full = downstream.full_transform(False, False)
        upstream_full = upstream.full_transform(False, False)
        expected = np.dot(downstream_full, upstream_full)

        # Concatenate and expand the result without any slicing
        chained = upstream.concat(downstream)
        actual = chained.full_transform(False, False)

        assert np.array_equal(expected, actual)
    def test_full_transform_scalar(self):
        """Verify full-transform extraction for a scalar transform value.

        The scalar 1.5, applied from input slice ``5:`` to output indices
        ``[1, 2, 3]``, should appear at (1, 5), (2, 6) and (3, 7) of a 4x8
        matrix that is zero elsewhere.
        """
        # Build the reference matrix
        row_idx = [1, 2, 3]
        col_idx = [5, 6, 7]
        reference = np.zeros((4, 8))
        reference[row_idx, col_idx] = 1.5

        # Build the transform under test
        tform = Transform(size_in=8, slice_in=slice(5, None),
                          transform=1.5,
                          size_out=4, slice_out=row_idx)

        # Sliced on the input side only
        got = tform.full_transform(slice_in=True, slice_out=False)
        assert np.array_equal(got, reference[:, 5:])

        # Sliced on the output side only
        got = tform.full_transform(slice_in=False, slice_out=True)
        assert np.array_equal(got, reference[1:])

        # Not sliced at all
        got = tform.full_transform(slice_in=False, slice_out=False)
        assert np.array_equal(got, reference)
    def test_full_transform_vector(self):
        """Verify full-transform extraction for a vector transform value.

        The entries of ``[1.0, 2.0, 3.0]``, applied from input indices
        ``[3, 4, 5]`` to output slice ``:3``, should appear at (0, 3),
        (1, 4) and (2, 5) of a 4x8 matrix that is zero elsewhere.
        """
        # Build the reference matrix
        diag = np.array([1.0, 2.0, 3.0])
        col_idx = [3, 4, 5]
        reference = np.zeros((4, 8))
        reference[[0, 1, 2], col_idx] = diag

        # Build the transform under test
        tform = Transform(size_in=8, slice_in=col_idx,
                          transform=diag,
                          size_out=4, slice_out=slice(3))

        # Sliced on the input side only
        got = tform.full_transform(slice_in=True, slice_out=False)
        assert np.array_equal(got, reference[:, 3:6])

        # Sliced on the output side only
        got = tform.full_transform(slice_in=False, slice_out=True)
        assert np.array_equal(got, reference[:3])

        # Not sliced at all
        got = tform.full_transform(slice_in=False, slice_out=False)
        assert np.array_equal(got, reference)