Ejemplo n.º 1
0
def test_run_iteration():
    """Check that one complete iteration bumps the model's iteration counter.

    A pickled model is loaded from the test data directory, a single call to
    ``run_complete_iteration`` is made, and ``model.iter`` must have grown by
    exactly one.
    """
    gclda_model = Model.load(join(get_test_data_path(), 'gclda_model.pkl'))
    iter_before = gclda_model.iter
    gclda_model.run_complete_iteration()
    expected_iter = iter_before + 1
    assert gclda_model.iter == expected_iter
Ejemplo n.º 2
0
def test_encode_from_list():
    """Acceptance test of text-to-image encoding with list input.

    The text is supplied as a list of words; the encoded image must have the
    same shape as the mask image of the model's dataset.
    """
    data_dir = get_test_data_path()
    pickle_path = join(data_dir, 'gclda_model.pkl')
    words = ['anterior', 'insula', 'was', 'analyzed']
    gclda_model = Model.load(pickle_path)
    encoded_img, _ = encode(gclda_model, words)
    expected_shape = gclda_model.dataset.mask_img.shape
    assert encoded_img.shape == expected_shape
Ejemplo n.º 3
0
def test_decode_continuous_from_file():
    """Acceptance test of continuous image-based decoding with str input.

    The image is passed as a file path; the decoded data frame must contain
    one row per word label of the model.
    """
    data_dir = get_test_data_path()
    gclda_model = Model.load(join(data_dir, 'gclda_model.pkl'))
    image_path = join(data_dir, 'continuous.nii.gz')
    decoded_df, _ = decode_continuous(gclda_model, image_path)
    n_rows = decoded_df.shape[0]
    assert n_rows == gclda_model.n_word_labels
Ejemplo n.º 4
0
def test_decode_roi_from_file():
    """Acceptance test of ROI-based decoding with str input.

    The ROI is passed as a file path; the decoded data frame must contain one
    row per word label of the model.
    """
    data_dir = get_test_data_path()
    gclda_model = Model.load(join(data_dir, 'gclda_model.pkl'))
    roi_path = join(data_dir, 'roi.nii.gz')
    decoded_df, _ = decode_roi(gclda_model, roi_path)
    n_rows = decoded_df.shape[0]
    assert n_rows == gclda_model.n_word_labels
Ejemplo n.º 5
0
def test_decode_roi_with_priors():
    """Acceptance test of ROI-based decoding with topic priors.

    A first decoding pass supplies the second return value, which is then fed
    back in as ``topic_priors`` for a second pass. The resulting data frame
    must contain one row per word label of the model.
    """
    data_dir = get_test_data_path()
    gclda_model = Model.load(join(data_dir, 'gclda_model.pkl'))
    roi_path = join(data_dir, 'roi.nii.gz')

    # First pass: only the second return value is needed.
    _, first_pass_priors = decode_roi(gclda_model, roi_path)

    # Second pass: reuse that value as the topic priors.
    decoded_df, _ = decode_roi(gclda_model, roi_path,
                               topic_priors=first_pass_priors)
    n_rows = decoded_df.shape[0]
    assert n_rows == gclda_model.n_word_labels
Ejemplo n.º 6
0
def test_encode_with_priors():
    """Acceptance test of text-to-image encoding with topic priors.

    A first encoding pass supplies the second return value, which is then fed
    back in as ``topic_priors`` for a second pass. The encoded image must have
    the same shape as the mask image of the model's dataset.
    """
    pickle_path = join(get_test_data_path(), 'gclda_model.pkl')
    sentence = 'anterior insula was analyzed'
    gclda_model = Model.load(pickle_path)

    # First pass: only the second return value is needed.
    _, first_pass_priors = encode(gclda_model, sentence)

    # Second pass: reuse that value as the topic priors.
    encoded_img, _ = encode(gclda_model, sentence,
                            topic_priors=first_pass_priors)
    expected_shape = gclda_model.dataset.mask_img.shape
    assert encoded_img.shape == expected_shape
Ejemplo n.º 7
0
def test_symmetric():
    """Run one iteration of a freshly built model with ``symmetric=True``.

    The model is constructed from the pickled test dataset with two regions
    per topic, and a single complete iteration must advance ``model.iter`` by
    exactly one.
    """
    dset = Dataset.load(join(get_test_data_path(), 'gclda_dataset.pkl'))
    model_kwargs = {
        'n_topics': 50,
        'n_regions': 2,
        'symmetric': True,
        'alpha': .1,
        'beta': .01,
        'gamma': .01,
        'delta': 1.,
        'dobs': 25,
        'roi_size': 10.,
        'seed_init': 1,
    }
    gclda_model = Model(dset, **model_kwargs)
    iter_before = gclda_model.iter
    gclda_model.run_complete_iteration()
    assert gclda_model.iter == iter_before + 1
Ejemplo n.º 8
0
def test_decode_continuous_with_priors():
    """Acceptance test of continuous image-based decoding with topic priors.

    A first decoding pass supplies the second return value, which is then fed
    back in as ``topic_priors`` for a second pass. The resulting data frame
    must contain one row per word label of the model.
    """
    data_dir = get_test_data_path()
    gclda_model = Model.load(join(data_dir, 'gclda_model.pkl'))
    image_path = join(data_dir, 'continuous.nii.gz')

    # First pass: only the second return value is needed.
    _, first_pass_priors = decode_continuous(gclda_model, image_path)

    # Second pass: reuse that value as the topic priors.
    decoded_df, _ = decode_continuous(gclda_model, image_path,
                                      topic_priors=first_pass_priors)
    n_rows = decoded_df.shape[0]
    assert n_rows == gclda_model.n_word_labels
Ejemplo n.º 9
0
def test_display_model_summary():
    """Check that ``Model.display_model_summary`` writes to standard output.

    ``sys.stdout`` is temporarily replaced by an in-memory buffer while the
    summary is printed; the test passes when the buffer is non-empty.
    """
    model_file = join(get_test_data_path(), 'gclda_model.pkl')
    model = Model.load(model_file)

    captured_output = StringIO()
    # Remember whichever stream is currently installed (a test runner may
    # already have replaced the real stdout) so that exactly that stream is
    # put back afterwards, rather than unconditionally sys.__stdout__.
    previous_stdout = sys.stdout
    sys.stdout = captured_output
    try:
        model.display_model_summary()
    finally:
        # Restore even if the call raises, so later output is not swallowed.
        sys.stdout = previous_stdout

    assert len(captured_output.getvalue()) > 0
Ejemplo n.º 10
0
def test_save_model2():
    """Test gclda.model.Model.save with gzipped file.

    A model is loaded from ``gclda_model.pklz``, saved to ``temp.pklz`` in the
    test data directory, and the test passes when that file exists afterwards.
    The temporary file is removed whether or not the check succeeds.
    """
    model_file = join(get_test_data_path(), 'gclda_model.pklz')
    temp_file = join(get_test_data_path(), 'temp.pklz')
    model = Model.load(model_file)
    try:
        model.save(temp_file)
        assert isfile(temp_file)
    finally:
        # Clean up even on failure so a stale file cannot make a later run
        # pass spuriously; guard the removal so a failed save does not mask
        # the original error with a missing-file error.
        if isfile(temp_file):
            remove(temp_file)
Ejemplo n.º 11
0
def test_save_topic_figures():
    """Check that ``Model.save_topic_figures`` writes one PNG per topic.

    Figures are written to a ``temp`` directory inside the test data
    directory, and the number of ``*.png`` files found there must equal
    ``model.n_topics``. The directory is removed whether or not the check
    succeeds.
    """
    model_file = join(get_test_data_path(), 'gclda_model.pkl')
    temp_dir = join(get_test_data_path(), 'temp')

    model = Model.load(model_file)
    try:
        model.save_topic_figures(temp_dir, n_top_words=5)
        figures = glob(join(temp_dir, '*.png'))
        assert len(figures) == model.n_topics
    finally:
        # Always clean up: a leftover directory would otherwise leak files
        # into later runs. ignore_errors keeps a missing directory from
        # masking the real failure.
        rmtree(temp_dir, ignore_errors=True)
Ejemplo n.º 12
0
def test_init():
    """Smoke test: a ``Model`` can be constructed from the test dataset.

    The dataset is loaded from ``gclda_dataset.pkl`` and the constructor is
    called with an explicit set of hyperparameters; the result must be a
    ``Model`` instance.
    """
    dset = Dataset.load(join(get_test_data_path(), 'gclda_dataset.pkl'))
    model_kwargs = {
        'n_topics': 50,
        'n_regions': 1,
        'symmetric': False,
        'alpha': .1,
        'beta': .01,
        'gamma': .01,
        'delta': 1.,
        'dobs': 25,
        'roi_size': 10.,
        'seed_init': 1,
    }
    gclda_model = Model(dset, **model_kwargs)
    assert isinstance(gclda_model, Model)
Ejemplo n.º 13
0
def test_save_model_params():
    """Check that ``Model.save_model_params`` creates the expected CSV files.

    Parameters are written to a ``temp`` directory inside the test data
    directory, and three files must exist there afterwards:
    ``Topic_X_Word_Probs.csv``, ``Topic_X_Word_CountMatrix.csv`` and
    ``ActivationAssignments.csv``. The directory is removed whether or not
    the check succeeds.
    """
    model_file = join(get_test_data_path(), 'gclda_model.pkl')
    temp_dir = join(get_test_data_path(), 'temp')
    expected_names = (
        'Topic_X_Word_Probs.csv',
        'Topic_X_Word_CountMatrix.csv',
        'ActivationAssignments.csv',
    )

    model = Model.load(model_file)
    try:
        model.save_model_params(temp_dir, n_top_words=2)
        files_found = [isfile(join(temp_dir, name)) for name in expected_names]
        assert all(files_found)
    finally:
        # Always clean up: a leftover directory would otherwise leak files
        # into later runs. ignore_errors keeps a missing directory from
        # masking the real failure.
        rmtree(temp_dir, ignore_errors=True)
Ejemplo n.º 14
0
from os.path import join
import matplotlib.pyplot as plt

from nilearn import plotting
from nltools.mask import create_sphere

from gclda.model import Model
from gclda.decode import decode_roi
from gclda.utils import get_resource_path

###############################################################################
# Load model and initialize decoder
# ----------------------------------
# The model is read from a gzipped pickle located under the package's
# resource directory.
model_file = join(get_resource_path(), 'models/Neurosynth2015Filtered2',
                  'model_200topics_2015Filtered2_10000iters.pklz')
model = Model.load(model_file)

###############################################################################
# Create region of interest (ROI) image
# --------------------------------------
# A single sphere is built around one coordinate, using the mask image of the
# model's dataset as the reference space, and then plotted.
coords = [[-40, -52, -20]]  # one sphere center (presumably mm -- confirm)
radii = [6] * len(coords)  # same radius for every center in ``coords``

roi_img = create_sphere(coords, radius=radii, mask=model.dataset.mask_img)
# The cut coordinates repeat the sphere center so the views pass through it.
fig = plotting.plot_roi(roi_img,
                        display_mode='ortho',
                        cut_coords=[-40, -52, -20],
                        draw_cross=False)

###############################################################################
# Decode ROI
Ejemplo n.º 15
0
def test_load_model():
    """Check that ``Model.load`` returns a ``Model`` for the pickled fixture.
    """
    pickle_path = join(get_test_data_path(), 'gclda_model.pkl')
    loaded = Model.load(pickle_path)
    assert isinstance(loaded, Model)
Ejemplo n.º 16
0
# Initialize dataset
# ----------------------------------
# The dataset is built from a label and the ``datasets`` folder of the
# package's resource directory, then summarized on the console.
dataset_label = 'Neurosynth2015Filtered2_1000docs'
dataset_dir = join(get_resource_path(), 'datasets')
dataset = Dataset(dataset_label, dataset_dir)
dataset.display_dataset_summary()

###############################################################################
# Initialize model
# ----------------------
# Build a model on the dataset with explicit hyperparameters and a fixed
# ``seed_init``, then print its summary.
model = Model(dataset,
              n_topics=200,
              n_regions=2,
              alpha=.1,
              beta=.01,
              gamma=.01,
              delta=1.0,
              dobs=25,
              roi_size=50,
              symmetric=True,
              seed_init=1)
model.display_model_summary()

###############################################################################
# Run model (10 iterations)
# -------------------------
# The loop starts at the model's current iteration counter, so it performs
# only as many iterations as are needed to reach ``n_iterations``.
n_iterations = 10
for i in range(model.iter, n_iterations):
    # NOTE(review): ``loglikely_freq`` presumably sets how often the
    # log-likelihood is computed -- confirm against Model.
    model.run_complete_iteration(loglikely_freq=10, verbose=1)
model.display_model_summary()