def test_single_patch(self):
    """Check dsm_array when no searchlight patches are supplied.

    Every item in the fixture equals the first item shifted by a constant,
    so all pairwise correlation distances are expected to be (numerically)
    zero.
    """
    items = np.array([
        [[1, 2, 3], [2, 3, 4]],
        [[2, 3, 4], [3, 4, 5]],
        [[3, 4, 5], [4, 5, 6]],
    ])
    result = dsm_array(items, dist_metric='correlation')

    # Only a single DSM is produced; with three items it has three
    # pairwise entries (presumably the condensed upper triangle — confirm).
    assert len(result) == 1
    assert result.shape == (3, )

    expected = [[0, 0, 0]]
    assert_allclose(list(result), expected, atol=1E-15)
def test_temporal(self):
    """Check dsm_array driven by a temporal-only searchlight.

    One DSM must be produced per searchlight patch. Because both rows of the
    fixture are identical, every Euclidean distance is exactly zero.
    """
    samples = np.array([
        [1, 2, 3, 4],
        [1, 2, 3, 4],
    ])
    windows = searchlight(samples.shape, temporal_radius=1)
    result = dsm_array(samples, windows, dist_metric='euclidean')

    assert len(result) == len(windows)
    assert result.shape == (2, 1)

    expected = [0, 0]
    assert_equal(list(result), expected)
def test_spatial(self):
    """Check dsm_array driven by a spatial-only searchlight.

    A 4x4 distance matrix accompanies data with four entries along axis 1.
    The two rows of the fixture are identical, so each patch's Euclidean
    distance is zero.
    """
    distances = np.array([
        [0, 1, 2, 3],
        [1, 0, 1, 2],
        [2, 1, 0, 1],
        [3, 2, 1, 0],
    ])
    samples = np.array([
        [1, 2, 3, 4],
        [1, 2, 3, 4],
    ])
    neighbourhoods = searchlight(samples.shape, distances, spatial_radius=1)
    result = dsm_array(samples, neighbourhoods, dist_metric='euclidean')

    assert len(result) == len(neighbourhoods)
    assert result.shape == (4, 1)

    expected = [0, 0, 0, 0]
    assert_equal(list(result), expected)
def test_spatio_temporal(self):
    """Test computing DSMs using a spatio-temporal searchlight.

    ``data`` has shape (3, 2, 3). Each item equals the first item shifted
    by a constant, so all correlation distances are expected to be zero.
    """
    data = np.array([[[1, 2, 3], [2, 3, 4]],
                     [[2, 3, 4], [3, 4, 5]],
                     [[3, 4, 5], [4, 5, 6]]])
    # The distance matrix needs one row/column per entry along
    # data.shape[1] (here 2), just as test_spatial pairs a 4x4 matrix with
    # data.shape[1] == 4. A 3x3 matrix would describe a third series that
    # is not present in ``data``.
    dist = np.array([[0, 1],
                     [1, 0]])
    patches = searchlight(data.shape, dist, spatial_radius=1,
                          temporal_radius=1)
    dsms = dsm_array(data, patches, dist_metric='correlation')
    assert len(dsms) == len(patches)
    # NOTE(review): presumably (n_series, n_time_patches, n_pairs) — confirm.
    assert dsms.shape == (2, 1, 3)
    assert_allclose(list(dsms), [[0, 0, 0], [0, 0, 0]], atol=1E-15)
def test_crossvalidation(self):
    """Test computing DSMs using a searchlight and cross-validation.

    ``data`` has shape (6, 2, 3): the three distinct items appear twice, and
    ``y`` labels the two copies of each item identically. Since every item
    equals the first shifted by a constant, the expected distances are zero.
    """
    data = np.array([[[1, 2, 3], [2, 3, 4]],
                     [[2, 3, 4], [3, 4, 5]],
                     [[3, 4, 5], [4, 5, 6]],
                     [[1, 2, 3], [2, 3, 4]],
                     [[2, 3, 4], [3, 4, 5]],
                     [[3, 4, 5], [4, 5, 6]]])
    # The distance matrix needs one row/column per entry along
    # data.shape[1] (here 2); a 3x3 matrix would describe a third series
    # that is not present in ``data``.
    dist = np.array([[0, 1],
                     [1, 0]])
    patches = searchlight(data.shape, dist, spatial_radius=1,
                          temporal_radius=1)
    # Two folds, each containing one copy of every labelled item.
    dsms = dsm_array(data, patches, y=[1, 2, 3, 1, 2, 3], n_folds=2)
    assert len(dsms) == len(patches)
    # Three unique labels -> three pairwise entries per DSM.
    assert dsms.shape == (2, 1, 3)
    assert_allclose(list(dsms), [[0, 0, 0], [0, 0, 0]], atol=1E-15)