def test_prepare_searchlight_mvpa_data():
    """Compare ``io.prepare_searchlight_mvpa_data`` output with stored data.

    The processed data is checked entry by entry (with ``np.allclose``)
    against ``data/expected_searchlight_processed_data.npy`` located next to
    this file, and the returned labels are checked for exact equality with
    ``expected_labels``.

    ``dir``, ``extension``, ``epoch_file`` and ``expected_labels`` are not
    defined in this function; they are expected to be provided by the
    enclosing module.
    """
    processed_data, labels = io.prepare_searchlight_mvpa_data(
        dir, extension, epoch_file
    )
    expected_path = os.path.join(
        os.path.dirname(__file__),
        "data/expected_searchlight_processed_data.npy",
    )
    expected_searchlight_processed_data = np.load(expected_path)

    # Guard against a vacuous pass: without this check an empty or truncated
    # result would skip some (or all) of the element-wise comparisons below.
    assert len(processed_data) == len(expected_searchlight_processed_data), (
        "raw data length does not match in test_prepare_searchlight_mvpa_data"
    )
    for actual, expected in zip(
        processed_data, expected_searchlight_processed_data
    ):
        assert np.allclose(actual, expected), (
            "raw data do not match in test_prepare_searchlight_mvpa_data"
        )

    assert np.array_equal(labels, expected_labels), (
        "the labels do not match in test_prepare_searchlight_mvpa_data"
    )
# Positional command-line arguments:
#   1: data directory   2: file extension   3: mask file
#   4: epoch file       5: number of subjects (read further below)
data_dir = sys.argv[1]
extension = sys.argv[2]
mask_file = sys.argv[3]
epoch_file = sys.argv[4]

# all MPI processes read the mask; the mask file is small
mask_img = nib.load(mask_file)
# ``img.get_data()`` is no longer available in recent nibabel releases;
# ``np.asanyarray(img.dataobj)`` is its documented equivalent.  The builtin
# ``bool`` replaces the ``np.bool`` alias that newer NumPy no longer provides.
mask = np.asanyarray(mask_img.dataobj).astype(bool)

# only rank 0 loads the data; the other ranks pass None to the selector
data = None
labels = None
if MPI.COMM_WORLD.Get_rank() == 0:
    logger.info('mask size: %d', np.sum(mask))
    data, labels = io.prepare_searchlight_mvpa_data(
        data_dir, extension, epoch_file
    )

    # the following line is an example of leaving a subject out
    #epoch_info = [x for x in epoch_info if x[1] != 0]

num_subjs = int(sys.argv[5])

# create a Searchlight object
sl = Searchlight(sl_rad=1)
mvs = MVPAVoxelSelector(data, mask, labels, num_subjs, sl)
clf = svm.SVC(kernel='linear', shrinking=False, C=1)

# only rank 0 has meaningful return values
score_volume, results = mvs.run(clf)

# this output is just for result checking
if MPI.COMM_WORLD.Get_rank() == 0:
    # builtin ``float``/``int`` replace the removed ``np.float``/``np.int``
    # aliases; the resulting dtypes are unchanged
    score_volume = np.nan_to_num(score_volume.astype(float))
    io.write_nifti_file(score_volume, mask_img.affine, 'result_score.nii.gz')
    # NOTE(review): kept inside the rank-0 branch because ``results`` is only
    # meaningful there -- confirm against the original indentation
    seq_volume = np.zeros(mask.shape, dtype=int)
    seq = np.zeros(len(results), dtype=int)