def test_ROIEventSet():
    """
    Populate an ROIEventSet with three ROIs and verify item access,
    membership tests, keys() and pop().
    """
    rng = np.random.RandomState(888812)
    channel_names = ('signal', 'crosstalk')
    n_rois = 3

    # ground truth: expected[channel][key] is a list with one entry per ROI
    expected = {name: {'trace': [], 'events': []} for name in channel_names}

    event_set = dc_types.ROIEventSet()
    for roi_id in range(n_rois):
        channels = dc_types.ROIEventChannels()
        # the draw order (signal trace, signal events, crosstalk trace,
        # crosstalk events) determines the random data for each ROI
        for name in channel_names:
            trace = rng.random_sample(13)
            events = rng.randint(0, 111, size=13)
            roi_events = dc_types.ROIEvents()
            roi_events['trace'] = trace
            roi_events['events'] = events
            expected[name]['trace'].append(trace)
            expected[name]['events'].append(events)
            channels[name] = roi_events
        event_set[roi_id] = channels

    def check_channels(channels, roi_id):
        # compare both channels of one ROI against the ground truth
        for name in channel_names:
            np_almost(channels[name]['trace'],
                      expected[name]['trace'][roi_id],
                      decimal=10)
            np_equal(channels[name]['events'],
                     expected[name]['events'][roi_id])

    for roi_id in range(n_rois):
        check_channels(event_set[roi_id], roi_id)

    for roi_id in range(n_rois):
        assert roi_id in event_set
    assert 3 not in event_set

    # sorted() accepts any iterable, so this check does not depend on
    # keys() returning a mutable list (a dict view has no .sort())
    assert sorted(event_set.keys()) == [0, 1, 2]

    # pop() must hand back the stored channels and remove the ROI
    popped = event_set.pop(1)
    check_channels(popped, 1)
    assert 0 in event_set
    assert 2 in event_set
    assert 1 not in event_set
    assert sorted(event_set.keys()) == [0, 2]
def get_trace_events(
    trace_dict: dc_types.ROIDict,
    trace_threshold_params: dict = {
        'len_ne': 20,
        'th_ag': 14
    }
) -> dc_types.ROIEventSet:
    """
    Find the active-trace events in the signal and crosstalk channels
    of every ROI in trace_dict.

    Parameters
    ----------
    trace_dict -- a decrosstalk_types.ROIDict containing the trace data
                  for many ROIs to be analyzed

    trace_threshold_params -- a dict of kwargs that need to be
                              passed to active_traces.get_trace_events
                              default: {'len_ne': 20,
                                        'th_ag': 14}
                              This dict is copied before use, so neither
                              the caller's dict nor the shared default is
                              modified by this function.

    Returns
    -------
    A decrosstalk_types.ROIEventSet containing the active trace data
    for the ROIs, keyed on the same ROI IDs as trace_dict
    """
    # The default value above is a single dict shared by every call.
    # Hand a private (shallow) copy to the downstream code so that an
    # in-place change there cannot leak into later calls.
    threshold_params = dict(trace_threshold_params)

    # Freeze the ROI order once; rows of the stacked arrays and the
    # per-ROI results below are matched up by this position.
    roi_id_list = list(trace_dict.keys())

    active_by_channel = {}
    for channel in ('signal', 'crosstalk'):
        data_arr = np.array(
            [trace_dict[roi_id][channel] for roi_id in roi_id_list])
        active_by_channel[channel] = active_traces.get_trace_events(
            data_arr, threshold_params)

    output = dc_types.ROIEventSet()
    for i_roi, roi_id in enumerate(roi_id_list):
        local_channels = dc_types.ROIEventChannels()
        for channel, active in active_by_channel.items():
            roi_events = dc_types.ROIEvents()
            roi_events['trace'] = active['trace'][i_roi]
            roi_events['events'] = active['events'][i_roi]
            local_channels[channel] = roi_events
        output[roi_id] = local_channels

    return output
def test_ROIEvents_exceptions():
    """
    Check that ROIEvents rejects unknown keys and improperly typed or
    shaped values, and accepts well-formed 'trace' and 'events' arrays.
    """
    float_arr = np.linspace(0, 2.0, 5)
    int_arr = np.arange(5, dtype=int)
    roi_events = dc_types.ROIEvents()

    # only known keys can be set
    with pytest.raises(KeyError):
        roi_events['boom'] = float_arr

    # (key, value) pairs that are each expected to raise a ValueError
    invalid_assignments = [
        ('trace', int_arr),
        ('events', float_arr),
        ('events', 5),
        ('trace', 6.7),
        ('trace', np.array([[1.1, 2.2], [3.4, 5.5]])),
    ]
    for key, value in invalid_assignments:
        with pytest.raises(ValueError):
            roi_events[key] = value

    # well-formed assignments and look-ups must not raise
    roi_events['trace'] = float_arr
    roi_events['events'] = int_arr
    for key in ('trace', 'events'):
        _ = roi_events[key]

    with pytest.raises(KeyError):
        _ = roi_events['boom']
def test_ROIEventSet_exceptions():
    """
    Check that ROIEventSet raises a KeyError for a str key, raises a
    ValueError for a value that is a bare numpy array, and accepts an
    ROIEventChannels stored under an int key.
    """
    signal = dc_types.ROIEvents()
    signal['trace'] = np.linspace(1, 3, 10)
    signal['events'] = np.arange(10, dtype=int)

    crosstalk = dc_types.ROIEvents()
    crosstalk['trace'] = np.linspace(1, 7, 10)
    crosstalk['events'] = np.arange(10, 20, dtype=int)

    channel = dc_types.ROIEventChannels()
    channel['signal'] = signal
    channel['crosstalk'] = crosstalk

    event_set = dc_types.ROIEventSet()

    # a str key is rejected
    with pytest.raises(KeyError):
        event_set['signal'] = crosstalk

    # a bare array is rejected as a value
    with pytest.raises(ValueError):
        event_set[9] = np.linspace(2, 7, 20)

    # the well-formed assignment must not raise
    event_set[9] = channel
def test_ROIEvents():
    """
    Check that 'trace' and 'events' arrays stored in an ROIEvents
    are returned unchanged.
    """
    roi_events = dc_types.ROIEvents()
    roi_events['trace'] = np.linspace(0, 2.0, 5)
    roi_events['events'] = np.arange(5, dtype=int)

    # compare against freshly built arrays rather than the stored objects
    expected_trace = np.linspace(0, 2.0, 5)
    expected_events = np.arange(5, dtype=int)

    np_almost(roi_events['trace'], expected_trace, decimal=10)
    np_equal(roi_events['events'], expected_events)
def test_ROIEventChannels():
    """
    Check that ROIEvents stored under 'signal' and 'crosstalk' in an
    ROIEventChannels are returned unchanged.
    """
    rng = np.random.RandomState(1245)
    channel_names = ('signal', 'crosstalk')

    # draw both traces first, then both event arrays
    traces = [rng.random_sample(10) for _ in channel_names]
    events = [rng.randint(0, 20, size=10) for _ in channel_names]

    channels = dc_types.ROIEventChannels()
    for name, trace, event_idx in zip(channel_names, traces, events):
        roi_events = dc_types.ROIEvents()
        roi_events['trace'] = trace
        roi_events['events'] = event_idx
        channels[name] = roi_events

    for name, trace, event_idx in zip(channel_names, traces, events):
        np_almost(channels[name]['trace'], trace, decimal=10)
        np_equal(channels[name]['events'], event_idx)
def test_ROIEventChannels_exceptions():
    """
    Check that ROIEventChannels raises a KeyError for unknown keys (on
    both assignment and look-up) and a ValueError for a value that is
    a bare numpy array.
    """
    roi_events = dc_types.ROIEvents()
    roi_events['trace'] = np.linspace(0, 1, 7)
    roi_events['events'] = np.arange(7, dtype=int)

    channels = dc_types.ROIEventChannels()

    # unknown channel name on assignment
    with pytest.raises(KeyError):
        channels['a'] = roi_events

    # a bare array is rejected as a value
    with pytest.raises(ValueError):
        channels['signal'] = np.linspace(1, 9, 12)

    # the well-formed assignment must not raise
    channels['signal'] = roi_events

    # unknown channel name on look-up
    with pytest.raises(KeyError):
        _ = channels['b']
def find_independent_events(signal_events: dc_types.ROIEvents,
                            crosstalk_events: dc_types.ROIEvents,
                            window: int = 2) -> dc_types.ROIEvents:
    """
    Select the signal events that have no counterpart in the crosstalk
    channel.

    Every crosstalk event is "blurred" into the 2*window+1 time points
    from t-window to t+window. A signal event that falls on any of those
    time points is considered matched and is dropped; all other signal
    events are considered independent. With window=0 only exact
    coincidences (same time point in both channels) count as matches.

    Parameters
    ----------
    signal_events -- a decrosstalk_types.ROIEvents containing the active
                     traces from the signal channel

    crosstalk_events -- a decrosstalk_types.ROIEvents containing the active
                        traces from the crosstalk channel

    window -- an int specifying the amount of blurring to use (default=2)

    Returns
    -------
    independent_events -- a decrosstalk_types.ROIEvents containing
                          the 'trace' and 'events' entries of signal_events
                          that are not within +/- window of an event in
                          crosstalk_events
    """
    offsets = np.arange(-window, window + 1)

    # every time point within +/- window of a crosstalk event
    shifted_copies = [crosstalk_events['events'] + shift for shift in offsets]
    blurred_crosstalk = np.unique(np.concatenate(shifted_copies))

    # True where a signal event is *not* among the blurred time points
    is_independent = np.isin(signal_events['events'],
                             blurred_crosstalk,
                             invert=True)
    keep_idx = np.nonzero(is_independent)

    independent_events = dc_types.ROIEvents()
    independent_events['trace'] = signal_events['trace'][keep_idx]
    independent_events['events'] = signal_events['events'][keep_idx]
    return independent_events
def create_dataset():
    """
    Create a test dataset exercising all possible permutations of
    validity flags.

    One ROI is generated for each of the 2**6 = 64 combinations of the
    six boolean flags; ROI IDs run from 0 to 63 in the order produced by
    itertools.product. All data are drawn from a single seeded random
    number generator, so the dataset is reproducible.

    Returns
    -------
    dict
        'roi_flags': a dict mimicking the roi_flags returned by the pipeline

        'raw_traces': ROISetDict of valid traces
        'invalid_raw_traces': same as above for invalid raw traces

        'unmixed_traces': ROISetDict of valid unmixed traces
        'invalid_unmixed_traces': same as above for invalid unmixed traces

        'raw_events': ROIEventSet of activity in valid raw traces
        'invalid_raw_events': same as above for invalid raw traces

        'unmixed_events': ROIEventSet of activity in valid unmixed traces
        'invalid_unmixed_events': same as above for invalid unmixed traces

        'true_flags': ground truth values of validity flags for all ROIs
                      (a list of dicts, indexed by ROI ID)
    """
    rng = np.random.RandomState(172)
    n_t = 10  # number of time points in each simulated trace

    raw_traces = dc_types.ROISetDict()
    invalid_raw_traces = dc_types.ROISetDict()
    unmixed_traces = dc_types.ROISetDict()
    invalid_unmixed_traces = dc_types.ROISetDict()
    raw_events = dc_types.ROIEventSet()
    invalid_raw_events = dc_types.ROIEventSet()
    unmixed_events = dc_types.ROIEventSet()
    invalid_unmixed_events = dc_types.ROIEventSet()
    roi_flags = {}
    roi_flags['decrosstalk_ghost'] = []

    true_flags = []

    # every combination of the six boolean flags
    iterator = itertools.product([True, False], [True, False], [True, False],
                                 [True, False], [True, False], [True, False])

    roi_id = -1
    for _f in iterator:
        roi_id += 1
        flags = {
            'valid_raw_trace': _f[0],
            'valid_raw_active_trace': _f[1],
            'valid_unmixed_trace': _f[2],
            'valid_unmixed_active_trace': _f[3],
            'converged': _f[4],
            'ghost': _f[5]
        }

        # Validity is hierarchical: a failure at one stage invalidates
        # every later stage. 'converged' and 'ghost' are left untouched.
        if not flags['valid_raw_trace']:
            flags['valid_raw_active_trace'] = False
            flags['valid_unmixed_trace'] = False
            flags['valid_unmixed_active_trace'] = False
        if not flags['valid_raw_active_trace']:
            flags['valid_unmixed_trace'] = False
            flags['valid_unmixed_active_trace'] = False
        if not flags['valid_unmixed_trace']:
            flags['valid_unmixed_active_trace'] = False

        true_flags.append(flags)

        # raw traces
        raw_roi = dc_types.ROIChannels()
        raw_roi['signal'] = rng.random_sample(n_t)
        raw_roi['crosstalk'] = rng.random_sample(n_t)

        raw_np = dc_types.ROIChannels()
        raw_np['signal'] = rng.random_sample(n_t)
        raw_np['crosstalk'] = rng.random_sample(n_t)

        if flags['valid_raw_trace'] and flags['valid_raw_active_trace']:
            raw_traces['roi'][roi_id] = raw_roi
            raw_traces['neuropil'][roi_id] = raw_np
        else:
            invalid_raw_traces['roi'][roi_id] = raw_roi
            invalid_raw_traces['neuropil'][roi_id] = raw_np
            # an ROI whose raw trace is invalid gets no events and no
            # unmixed data; an ROI that only fails the active-trace check
            # still gets (invalid) raw events below
            if not flags['valid_raw_trace']:
                continue

        # raw trace events
        ee = dc_types.ROIEventChannels()
        e = dc_types.ROIEvents()
        e['events'] = rng.choice(np.arange(n_t, dtype=int), 3)
        e['trace'] = rng.random_sample(3)
        ee['signal'] = e
        e = dc_types.ROIEvents()
        e['events'] = rng.choice(np.arange(n_t, dtype=int), 3)
        e['trace'] = rng.random_sample(3)
        ee['crosstalk'] = e

        if flags['valid_raw_active_trace']:
            raw_events[roi_id] = ee
        else:
            # nothing downstream of invalid raw events is generated
            invalid_raw_events[roi_id] = ee
            continue

        # unmixed traces
        unmixed_roi = dc_types.ROIChannels()
        unmixed_roi['signal'] = rng.random_sample(n_t)
        unmixed_roi['crosstalk'] = rng.random_sample(n_t)
        unmixed_roi['mixing_matrix'] = rng.random_sample((2, 2))
        unmixed_roi['use_avg_mixing_matrix'] = not flags['converged']
        # non-converged ROIs additionally carry 'poorly_converged_*' entries
        if not flags['converged']:
            unmixed_roi['poorly_converged_signal'] = rng.random_sample(n_t)
            unmixed_roi['poorly_converged_crosstalk'] = rng.random_sample(n_t)
            mm = rng.random_sample((2, 2))
            unmixed_roi['poorly_converged_mixing_matrix'] = mm

        unmixed_np = dc_types.ROIChannels()
        unmixed_np['signal'] = rng.random_sample(n_t)
        unmixed_np['crosstalk'] = rng.random_sample(n_t)
        unmixed_np['mixing_matrix'] = rng.random_sample((2, 2))
        unmixed_np['use_avg_mixing_matrix'] = not flags['converged']
        if not flags['converged']:
            unmixed_np['poorly_converged_signal'] = rng.random_sample(n_t)
            unmixed_np['poorly_converged_crosstalk'] = rng.random_sample(n_t)
            mm = rng.random_sample((2, 2))
            unmixed_np['poorly_converged_mixing_matrix'] = mm

        is_valid = (flags['valid_unmixed_trace']
                    and flags['valid_unmixed_active_trace'])

        if is_valid:
            unmixed_traces['roi'][roi_id] = unmixed_roi
            unmixed_traces['neuropil'][roi_id] = unmixed_np
        else:
            invalid_unmixed_traces['roi'][roi_id] = unmixed_roi
            invalid_unmixed_traces['neuropil'][roi_id] = unmixed_np
            # same pattern as for the raw traces: only an invalid unmixed
            # trace stops here; a failed active-trace check still gets
            # (invalid) unmixed events below
            if not flags['valid_unmixed_trace']:
                continue

        # unmixed trace events
        ee = dc_types.ROIEventChannels()
        e = dc_types.ROIEvents()
        e['events'] = rng.choice(np.arange(n_t, dtype=int), 3)
        e['trace'] = rng.random_sample(3)
        ee['signal'] = e
        e = dc_types.ROIEvents()
        e['events'] = rng.choice(np.arange(n_t, dtype=int), 3)
        e['trace'] = rng.random_sample(3)
        ee['crosstalk'] = e

        if flags['valid_unmixed_active_trace']:
            unmixed_events[roi_id] = ee
        else:
            invalid_unmixed_events[roi_id] = ee
            continue

        # Only ROIs that passed every validity check reach this point, so
        # only those can be listed as ghosts in roi_flags, even though
        # true_flags keeps the 'ghost' value of every ROI.
        if flags['ghost']:
            roi_flags['decrosstalk_ghost'].append(roi_id)

    output = {}
    output['roi_flags'] = roi_flags
    output['raw_traces'] = raw_traces
    output['invalid_raw_traces'] = invalid_raw_traces
    output['unmixed_traces'] = unmixed_traces
    output['invalid_unmixed_traces'] = invalid_unmixed_traces
    output['raw_events'] = raw_events
    output['invalid_raw_events'] = invalid_raw_events
    output['unmixed_events'] = unmixed_events
    output['invalid_unmixed_events'] = invalid_unmixed_events
    output['true_flags'] = true_flags

    return output