def test_addsubs_wo_fit(tempdir, identifiability):
    """Check ``add_subjects`` on a model that was never fitted.

    Synthetic data are generated with ``generate_data``. The returned ``X``
    and ``S`` are passed to ``add_subjects`` on a fresh
    ``IdentifiableFastSRM``. Each resulting basis must then match, to
    numpy's default precision, the corresponding ``W`` returned by the
    generator.
    """
    with tempfile.TemporaryDirectory() as datadir:
        X, W, S = generate_data(
            n_voxels,
            [24, 25],
            n_subjects,
            n_components,
            datadir,
            0,
            "list_of_list",
        )

        # When ``tempdir`` is truthy the model is given the same temporary
        # directory as the generated data; otherwise ``temp_dir`` is None.
        srm = IdentifiableFastSRM(
            identifiability=identifiability,
            n_subjects_ica=n_subjects,
            n_components=n_components,
            n_iter=10,
            temp_dir=datadir if tempdir else None,
        )
        srm.add_subjects(X, S)

        for subject, expected_basis in enumerate(W):
            assert_array_almost_equal(
                safe_load(srm.basis_list[subject]), expected_basis
            )


def test_fastsrm_class_correctness(
    input_format,
    low_ram,
    tempdir,
    atlas,
    n_jobs,
    n_timeframes,
    aggregate,
    identifiability,
):
    """End-to-end correctness checks for ``IdentifiableFastSRM``.

    On data from ``generate_ica_friendly_data`` this test checks that:

    * ``fit_transform`` and a subsequent ``transform`` give (loosely) the
      same shared response;
    * for every subject ``i`` and session ``j``,
      ``shared_response[j].T @ basis[i]`` reproduces ``XX[i][j].T``;
    * transforming every subject except subject 0 gives almost the same
      shared response as using all subjects;
    * adding subject 0 a second time with ``add_subjects`` yields the same
      basis as the one it already had.
    """
    with tempfile.TemporaryDirectory() as datadir:
        np.random.seed(0)
        X, W, S = generate_ica_friendly_data(n_voxels, n_timeframes,
                                             n_subjects, datadir, 0,
                                             input_format)

        XX, n_sessions = apply_input_format(X, input_format)

        temp_dir = datadir if tempdir else None

        srm = IdentifiableFastSRM(
            identifiability=identifiability,
            n_subjects_ica=n_subjects,
            atlas=atlas,
            n_components=n_components,
            n_iter=1000,
            tol=1e-7,
            temp_dir=temp_dir,
            low_ram=low_ram,
            n_jobs=n_jobs,
            aggregate=aggregate,
        )

        # Check that there is no difference between fit_transform
        # and fit then transform
        shared_response_fittransform = apply_aggregate(srm.fit_transform(X),
                                                       aggregate, input_format)
        prev_basis = srm.basis_list
        # NOTE(review): ``prev_basis`` is the very same list object as
        # ``srm.basis_list`` and nothing refits the model in between, so this
        # aligns the bases with themselves. The comment above suggests a
        # separate ``fit`` was meant to happen before the alignment. Confirm
        # the intent before relying on this step.
        srm.basis_list = align_basis(srm.basis_list, prev_basis)

        basis = [safe_load(b) for b in srm.basis_list]
        shared_response_raw = srm.transform(X)
        shared_response = apply_aggregate(shared_response_raw, aggregate,
                                          input_format)

        # Only a loose agreement (1 decimal) is required here.
        for j in range(n_sessions):
            assert_array_almost_equal(shared_response_fittransform[j],
                                      shared_response[j], 1)

        # Check that the decomposition works
        for i in range(n_subjects):
            for j in range(n_sessions):
                assert_array_almost_equal(shared_response[j].T.dot(basis[i]),
                                          XX[i][j].T)

        # Check that using all subjects but one (subject 0 is left out)
        # gives almost the same shared response. The slice and the index list
        # are derived from ``n_subjects`` so exactly one subject is dropped
        # whatever the subject count is.
        remaining_subjects = list(range(1, n_subjects))
        shared_response_partial_raw = srm.transform(
            X[1:n_subjects], subjects_indexes=remaining_subjects)

        shared_response_partial = apply_aggregate(shared_response_partial_raw,
                                                  aggregate, input_format)
        for j in range(n_sessions):
            assert_array_almost_equal(shared_response_partial[j],
                                      shared_response[j])

        # Check that adding subject 0 a second time (using the non-aggregated
        # shared response) gives the same basis as the original subject 0.
        srm.add_subjects(X[:1], shared_response_raw)
        assert_array_almost_equal(safe_load(srm.basis_list[0]),
                                  safe_load(srm.basis_list[-1]))