def test_ROIEventSet():
    """
    Exercise the dict-like interface of dc_types.ROIEventSet.

    Three ROIEventChannels (under integer keys 0, 1, 2), each holding a
    randomly generated 'signal' and 'crosstalk' ROIEvents, are stored in
    the set. The test then checks item lookup, ``in``, ``keys()`` and
    ``pop()`` against the arrays that were stored.
    """
    event_set = dc_types.ROIEventSet()
    rng = np.random.RandomState(888812)
    true_trace_s = []
    true_event_s = []
    true_trace_c = []
    true_event_c = []
    for ii in range(3):
        t = rng.random_sample(13)
        e = rng.randint(0, 111, size=13)
        signal = dc_types.ROIEvents()
        signal['trace'] = t
        signal['events'] = e
        true_trace_s.append(t)
        true_event_s.append(e)

        t = rng.random_sample(13)
        e = rng.randint(0, 111, size=13)
        crosstalk = dc_types.ROIEvents()
        crosstalk['trace'] = t
        crosstalk['events'] = e
        true_trace_c.append(t)
        true_event_c.append(e)

        channels = dc_types.ROIEventChannels()
        channels['signal'] = signal
        channels['crosstalk'] = crosstalk
        event_set[ii] = channels

    for ii in range(3):
        np_almost(event_set[ii]['signal']['trace'],
                  true_trace_s[ii],
                  decimal=10)
        np_equal(event_set[ii]['signal']['events'], true_event_s[ii])
        np_almost(event_set[ii]['crosstalk']['trace'],
                  true_trace_c[ii],
                  decimal=10)
        np_equal(event_set[ii]['crosstalk']['events'], true_event_c[ii])
    assert 0 in event_set
    assert 1 in event_set
    assert 2 in event_set
    assert 3 not in event_set
    # sorted() works whether keys() returns a list or a dict-style view;
    # calling .sort() on the result would only work for a list
    assert sorted(event_set.keys()) == [0, 1, 2]

    # pop() must hand back the stored channels and remove the key
    channels = event_set.pop(1)
    np_almost(channels['signal']['trace'], true_trace_s[1], decimal=10)
    np_equal(channels['signal']['events'], true_event_s[1])
    np_almost(channels['crosstalk']['trace'], true_trace_c[1], decimal=10)
    np_equal(channels['crosstalk']['events'], true_event_c[1])
    assert 0 in event_set
    assert 2 in event_set
    assert 1 not in event_set
    assert sorted(event_set.keys()) == [0, 2]
def get_trace_events(
    trace_dict: dc_types.ROIDict,
    trace_threshold_params: dict = None
) -> dc_types.ROIEventSet:
    """
    Detect active events in the signal and crosstalk channels of many ROIs.

    Parameters
    ----------
    trace_dict -- a decrosstalk_types.ROIDict containing the trace data
                  for many ROIs to be analyzed; each entry is indexed
                  with 'signal' and 'crosstalk'

    trace_threshold_params -- a dict of kwargs that need to be
                              passed to active_traces.get_trace_events
                              default (used when None is passed):
                              {'len_ne': 20,
                               'th_ag': 14}
                              A fresh default dict is built on every
                              call, so the default cannot be mutated
                              across calls.

    Returns
    -------
    A decrosstalk_type.ROIEventSet containing the active trace data
    for the ROIs (empty if trace_dict contains no ROIs)
    """
    if trace_threshold_params is None:
        trace_threshold_params = {'len_ne': 20, 'th_ag': 14}

    output = dc_types.ROIEventSet()
    roi_id_list = list(trace_dict.keys())

    # Nothing to analyze; avoid handing a zero-length array to
    # active_traces.get_trace_events
    if len(roi_id_list) == 0:
        return output

    # Run event detection once per channel on all ROIs stacked together;
    # row i of each result corresponds to roi_id_list[i].
    channel_events = {}
    for channel in ('signal', 'crosstalk'):
        data_arr = np.array(
            [trace_dict[roi_id][channel] for roi_id in roi_id_list])
        channel_events[channel] = active_traces.get_trace_events(
            data_arr, trace_threshold_params)

    for i_roi, roi_id in enumerate(roi_id_list):
        local_channels = dc_types.ROIEventChannels()
        for channel in ('signal', 'crosstalk'):
            events = dc_types.ROIEvents()
            events['trace'] = channel_events[channel]['trace'][i_roi]
            events['events'] = channel_events[channel]['events'][i_roi]
            local_channels[channel] = events
        output[roi_id] = local_channels

    return output
def test_ROIEventSet_exceptions():
    """
    Verify the error behavior of ROIEventSet item assignment.

    Assigning under the string key 'signal' must raise KeyError, and
    assigning a bare numpy array under an integer key must raise
    ValueError. Assigning a fully populated ROIEventChannels under an
    integer key must succeed.
    """
    signal_events = dc_types.ROIEvents()
    signal_events['trace'] = np.linspace(1, 3, 10)
    signal_events['events'] = np.arange(10, dtype=int)

    crosstalk_events = dc_types.ROIEvents()
    crosstalk_events['trace'] = np.linspace(1, 7, 10)
    crosstalk_events['events'] = np.arange(10, 20, dtype=int)

    populated_channels = dc_types.ROIEventChannels()
    populated_channels['signal'] = signal_events
    populated_channels['crosstalk'] = crosstalk_events

    roi_events = dc_types.ROIEventSet()

    # string key instead of an integer ROI ID
    with pytest.raises(KeyError):
        roi_events['signal'] = crosstalk_events

    # raw array instead of an ROIEventChannels
    with pytest.raises(ValueError):
        roi_events[9] = np.linspace(2, 7, 20)

    # well-formed assignment must not raise
    roi_events[9] = populated_channels
def create_dataset():
    """
    Create a test dataset exercising all possible permutations of
    validity flags.

    One ROI is generated for each of the 2**6 combinations of the six
    boolean flags (valid_raw_trace, valid_raw_active_trace,
    valid_unmixed_trace, valid_unmixed_active_trace, converged, ghost).
    ROI IDs are assigned sequentially from 0 in itertools.product order.
    The validity flags are made self-consistent: each stage is only
    valid if every earlier stage is valid as well.

    Returns
    -------
    dict
        'roi_flags': a dict mimicking the roi_flags returned by the pipeline

        'raw_traces': ROISetDict of valid traces
        'invalid_raw_traces': same as above for invalid raw traces

        'unmixed_traces': ROISetDict of valid unmixed traces
        'invalid_unmixed_traces': same as above for invalid unmixed traces

        'raw_events': ROIEventSet of activity in valid raw traces
        'invalid_raw_events': same as above for invalid raw traces

        'unmixed_events': ROIEventSet of activity in valid unmixed traces
        'invalid_unmixed_events': same as above for invalid unmixed traces

        'true_flags': ground truth values of validity flags for all ROIs
                      (a list whose index is the ROI ID)
    """
    rng = np.random.RandomState(172)
    n_t = 10

    # The helpers below draw from rng in exactly the same order as the
    # original inline code, so the generated data are reproducible.

    def _random_raw_channels():
        # ROIChannels with random 'signal' and 'crosstalk' traces
        channels = dc_types.ROIChannels()
        channels['signal'] = rng.random_sample(n_t)
        channels['crosstalk'] = rng.random_sample(n_t)
        return channels

    def _random_unmixed_channels(converged):
        # ROIChannels mimicking ICA output; non-converged ROIs also carry
        # the 'poorly_converged_*' entries
        channels = dc_types.ROIChannels()
        channels['signal'] = rng.random_sample(n_t)
        channels['crosstalk'] = rng.random_sample(n_t)
        channels['mixing_matrix'] = rng.random_sample((2, 2))
        channels['use_avg_mixing_matrix'] = not converged
        if not converged:
            channels['poorly_converged_signal'] = rng.random_sample(n_t)
            channels['poorly_converged_crosstalk'] = rng.random_sample(n_t)
            mm = rng.random_sample((2, 2))
            channels['poorly_converged_mixing_matrix'] = mm
        return channels

    def _random_event_channels():
        # ROIEventChannels with 3 random event indices/values per channel
        channels = dc_types.ROIEventChannels()
        for channel in ('signal', 'crosstalk'):
            events = dc_types.ROIEvents()
            events['events'] = rng.choice(np.arange(n_t, dtype=int), 3)
            events['trace'] = rng.random_sample(3)
            channels[channel] = events
        return channels

    raw_traces = dc_types.ROISetDict()
    invalid_raw_traces = dc_types.ROISetDict()
    unmixed_traces = dc_types.ROISetDict()
    invalid_unmixed_traces = dc_types.ROISetDict()
    raw_events = dc_types.ROIEventSet()
    invalid_raw_events = dc_types.ROIEventSet()
    unmixed_events = dc_types.ROIEventSet()
    invalid_unmixed_events = dc_types.ROIEventSet()
    roi_flags = {'decrosstalk_ghost': []}

    true_flags = []

    flag_combinations = itertools.product([True, False], repeat=6)

    for roi_id, _f in enumerate(flag_combinations):
        # a stage can only be valid if every upstream stage is valid
        valid_raw = _f[0]
        valid_raw_active = valid_raw and _f[1]
        valid_unmixed = valid_raw_active and _f[2]
        valid_unmixed_active = valid_unmixed and _f[3]
        converged = _f[4]

        flags = {
            'valid_raw_trace': valid_raw,
            'valid_raw_active_trace': valid_raw_active,
            'valid_unmixed_trace': valid_unmixed,
            'valid_unmixed_active_trace': valid_unmixed_active,
            'converged': converged,
            'ghost': _f[5]
        }
        true_flags.append(flags)

        # raw traces
        raw_roi = _random_raw_channels()
        raw_np = _random_raw_channels()

        if valid_raw_active:
            raw_traces['roi'][roi_id] = raw_roi
            raw_traces['neuropil'][roi_id] = raw_np
        else:
            invalid_raw_traces['roi'][roi_id] = raw_roi
            invalid_raw_traces['neuropil'][roi_id] = raw_np
            if not valid_raw:
                # the pipeline never gets as far as event detection
                continue

        # raw trace events
        raw_ee = _random_event_channels()
        if not valid_raw_active:
            invalid_raw_events[roi_id] = raw_ee
            continue
        raw_events[roi_id] = raw_ee

        # unmixed traces
        unmixed_roi = _random_unmixed_channels(converged)
        unmixed_np = _random_unmixed_channels(converged)

        if valid_unmixed_active:
            unmixed_traces['roi'][roi_id] = unmixed_roi
            unmixed_traces['neuropil'][roi_id] = unmixed_np
        else:
            invalid_unmixed_traces['roi'][roi_id] = unmixed_roi
            invalid_unmixed_traces['neuropil'][roi_id] = unmixed_np
            if not valid_unmixed:
                continue

        # unmixed trace events
        unmixed_ee = _random_event_channels()
        if not valid_unmixed_active:
            invalid_unmixed_events[roi_id] = unmixed_ee
            continue
        unmixed_events[roi_id] = unmixed_ee

        # ghost status is only recorded for ROIs that survive every stage
        if flags['ghost']:
            roi_flags['decrosstalk_ghost'].append(roi_id)

    return {
        'roi_flags': roi_flags,
        'raw_traces': raw_traces,
        'invalid_raw_traces': invalid_raw_traces,
        'unmixed_traces': unmixed_traces,
        'invalid_unmixed_traces': invalid_unmixed_traces,
        'raw_events': raw_events,
        'invalid_raw_events': invalid_raw_events,
        'unmixed_events': unmixed_events,
        'invalid_unmixed_events': invalid_unmixed_events,
        'true_flags': true_flags,
    }
def run_decrosstalk(
    signal_plane: DecrosstalkingOphysPlane,
    ct_plane: DecrosstalkingOphysPlane,
    cache_dir: str = None,
    clobber: bool = False,
    new_style_output: bool = False
) -> Tuple[dict, Tuple[dc_types.ROISetDict, dc_types.ROISetDict], Tuple[
        dc_types.ROISetDict, dc_types.ROISetDict], Tuple[
            dc_types.ROIEventSet, dc_types.ROIEventSet], Tuple[
                dc_types.ROIEventSet, dc_types.ROIEventSet]]:
    """
    Actually run the decrosstalking pipeline, comparing two
    DecrosstalkingOphysPlanes

    Stages:
      1. extract raw traces (get_raw_traces) and set aside ROIs whose raw
         traces fail d_utils.validate_traces
      2. detect activity in the raw traces (get_trace_events) and set aside
         ROIs whose raw signal active trace is empty or contains NaNs
      3. unmix signal and crosstalk with ICA (unmix_all_ROIs) and clip dips
         in the signal channel (clean_negative_traces)
      4. set aside ROIs whose unmixed traces fail d_utils.validate_traces
      5. detect activity in the unmixed traces and set aside ROIs whose
         unmixed signal active trace is empty or contains NaNs
      6. flag remaining ROIs that d_utils.validate_cell_crosstalk does not
         consider cells as ghosts

    If a stage leaves no valid ROIs, or ICA does not converge, the function
    returns early with whatever has been accumulated up to that point.

    Parameters
    ----------
    signal_plane -- the DecrosstalkingOphysPlane characterizing the
                    signal plane

    ct_plane -- the DecrosstalkingOphysPlane characterizing the crosstalk plane

    cache_dir -- the directory in which to write the QC output
    (if None, the output does not get written)
    NOTE: not referenced in the body of this function.

    clobber -- a boolean indicating whether or not to overwrite
    pre-existing output files (default: False)
    NOTE: not referenced in the body of this function.

    new_style_output -- a boolean (default: False)
    NOTE: not referenced in the body of this function.

    Returns
    -------
    roi_flags -- a dict listing the ROI IDs of ROIs that were
    ruled invalid for different reasons, namely:

        'decrosstalk_ghost' -- ROIs that are ghosts

        'decrosstalk_invalid_raw' -- ROIs with invalid
                                     raw traces

        'decrosstalk_invalid_raw_active' -- ROIs with invalid
                                            raw active traces

        'decrosstalk_invalid_unmixed' -- ROIs with invalid
                                         unmixed traces (all ROIs are
                                         listed here if ICA did not
                                         converge)

        'decrosstalk_invalid_unmixed_active' -- ROIs with invalid
                                                unmixed active traces

    (raw_traces,
     invalid_raw_traces) -- two decrosstalk_types.ROISetDicts containing
                            the raw trace data for the ROIs and the
                            invalid raw traces

    (unmixed_traces,
     invalid_unmixed_traces) -- two decrosstalk_types.ROISetDicts containing
                                the unmixed trace data for the ROIs and the
                                invalid unmixed traces

    (raw_trace_events,
     invalid_raw_trace_events) -- two decrosstalk_types.ROIEventSets
                                  characterizing the active timestamps
                                  from the raw traces and the invalid
                                  active timestamps

    (unmixed_trace_events,
     invalid_unmixed_trace_events) -- two decrosstalk_types.ROIEventSets
                                      characterizing the active timestamps
                                      from the unmixed traces and the invalid
                                      unmixed trace events
    """
    raw_traces = dc_types.ROISetDict()
    unmixed_traces = dc_types.ROISetDict()
    raw_trace_events = dc_types.ROIEventSet()
    unmixed_trace_events = dc_types.ROIEventSet()

    invalid_raw_traces = dc_types.ROISetDict()
    invalid_unmixed_traces = dc_types.ROISetDict()
    invalid_raw_trace_events = dc_types.ROIEventSet()
    invalid_unmixed_trace_events = dc_types.ROIEventSet()

    ghost_key = 'decrosstalk_ghost'
    raw_key = 'decrosstalk_invalid_raw'
    raw_active_key = 'decrosstalk_invalid_raw_active'
    unmixed_key = 'decrosstalk_invalid_unmixed'
    unmixed_active_key = 'decrosstalk_invalid_unmixed_active'

    # one (initially empty) list of ROI IDs per flag, in this key order
    roi_flags: Dict[str, List[int]] = {
        key: [] for key in (ghost_key, raw_key, unmixed_key,
                            raw_active_key, unmixed_active_key)}

    def _outputs():
        # Build the return tuple from the *current* bindings of the
        # accumulators. As a closure it sees the later rebinding of
        # raw_traces, raw_trace_events, unmixed_traces and
        # unmixed_trace_events below.
        return (roi_flags, (raw_traces, invalid_raw_traces),
                (unmixed_traces, invalid_unmixed_traces),
                (raw_trace_events, invalid_raw_trace_events),
                (unmixed_trace_events, invalid_unmixed_trace_events))

    def _move_to_invalid(roi_id, traces, invalid_traces,
                         events=None, invalid_events=None):
        # Move one ROI's 'roi' and 'neuropil' traces (and, if given, its
        # events) from the valid containers into the invalid ones.
        if events is not None:
            _events = events.pop(roi_id)
        _roi = traces['roi'].pop(roi_id)
        _neuropil = traces['neuropil'].pop(roi_id)
        invalid_traces['roi'][roi_id] = _roi
        invalid_traces['neuropil'][roi_id] = _neuropil
        if events is not None:
            invalid_events[roi_id] = _events

    # If there are no ROIs in the signal plane,
    # just return a set of empty outputs
    if len(signal_plane.roi_list) == 0:
        return _outputs()

    ###############################
    # extract raw traces

    raw_traces = get_raw_traces(signal_plane, ct_plane)
    raw_trace_validation = d_utils.validate_traces(raw_traces)

    # remove invalid raw traces
    invalid_raw_trace_roi_id = [roi_id for roi_id in raw_trace_validation
                                if not raw_trace_validation[roi_id]]
    for roi_id in invalid_raw_trace_roi_id:
        _move_to_invalid(roi_id, raw_traces, invalid_raw_traces)

    roi_flags[raw_key] += invalid_raw_trace_roi_id

    if len(raw_traces['roi']) == 0:
        msg = 'No raw traces were valid when applying '
        msg += 'decrosstalk to ophys_experiment_id: '
        msg += '%d (%d)' % (signal_plane.experiment_id, ct_plane.experiment_id)
        logger.error(msg)
        return _outputs()

    #########################################
    # detect activity in raw traces

    raw_trace_events = get_trace_events(raw_traces['roi'])

    # For each ROI, calculate a random seed based on the flux
    # in all timestamps *not* chosen as events (presumably,
    # random noise)
    roi_to_seed = {}
    two_to_32 = 2**32
    for roi_id in raw_trace_events.keys():
        signal_trace = raw_traces['roi'][roi_id]['signal']
        flux_mask = np.ones(len(signal_trace), dtype=bool)
        event_idx = raw_trace_events[roi_id]['signal']['events']
        if len(event_idx) > 0:
            flux_mask[event_idx] = False
        _flux = np.abs(signal_trace[flux_mask])
        flux_sum = np.round(_flux.sum()).astype(int)
        # reduce into the range accepted as a numpy RandomState seed
        roi_to_seed[roi_id] = flux_sum % two_to_32

    # remove ROIs whose raw signal active trace is empty or contains NaNs;
    # iterate over a snapshot because raw_trace_events is mutated
    for roi_id in list(raw_trace_events.keys()):
        signal = raw_trace_events[roi_id]['signal']['trace']
        if len(signal) == 0 or np.isnan(signal).any():
            roi_flags[raw_active_key].append(roi_id)
            _move_to_invalid(roi_id, raw_traces, invalid_raw_traces,
                             raw_trace_events, invalid_raw_trace_events)

    # every remaining ROI failed the active-trace check; nothing is left
    # to unmix
    if len(raw_traces['roi']) == 0:
        return _outputs()

    ###########################################################
    # use Independent Component Analysis to separate out signal
    # and crosstalk

    (ica_converged, unmixed_traces) = unmix_all_ROIs(raw_traces, roi_to_seed)

    # clip dips in signal channel
    clipped_traces = clean_negative_traces(unmixed_traces)

    # keep the pre-clipping signal under 'unclipped_signal' and replace
    # 'signal' with the clipped version
    for obj in ('roi', 'neuropil'):
        for roi_id in unmixed_traces[obj].keys():
            roi_channels = unmixed_traces[obj][roi_id]
            roi_channels['unclipped_signal'] = roi_channels['signal']
            roi_channels['signal'] = clipped_traces[obj][roi_id]['signal']

    if not ica_converged:
        # every ROI is flagged, but its traces are left in unmixed_traces
        for roi_id in unmixed_traces['roi'].keys():
            roi_flags[unmixed_key].append(roi_id)

        msg = 'ICA did not converge for any ROIs when '
        msg += 'applying decrosstalk to ophys_experiment_id: '
        msg += '%d (%d)' % (signal_plane.experiment_id, ct_plane.experiment_id)
        logger.error(msg)
        return _outputs()

    unmixed_trace_validation = d_utils.validate_traces(unmixed_traces)

    # remove invalid unmixed traces
    invalid_unmixed_trace_roi_id = [
        roi_id for roi_id in unmixed_trace_validation
        if not unmixed_trace_validation[roi_id]]
    for roi_id in invalid_unmixed_trace_roi_id:
        _move_to_invalid(roi_id, unmixed_traces, invalid_unmixed_traces)

    roi_flags[unmixed_key] += invalid_unmixed_trace_roi_id

    if len(unmixed_traces['roi']) == 0:
        msg = 'No unmixed traces were valid when applying '
        msg += 'decrosstalk to ophys_experiment_id: '
        msg += '%d (%d)' % (signal_plane.experiment_id, ct_plane.experiment_id)
        logger.error(msg)
        return _outputs()

    ###################################################
    # Detect activity in unmixed traces

    unmixed_trace_events = get_trace_events(unmixed_traces['roi'])

    # Sometimes the unmixed active traces come back full of NaNs.
    # Until that behavior is debugged, log it, record the affected ROIs
    # under unmixed_active_key ('decrosstalk_invalid_unmixed_active') and
    # cull them from the data. Only failures in the 'signal' channel cause
    # an ROI to be culled. 'crosstalk' failures are collected but not
    # otherwise acted on, although a NaN in either channel triggers the
    # error log.
    invalid_active_trace: Dict[str, List[int]] = {'signal': [],
                                                   'crosstalk': []}
    active_trace_had_NaNs = False
    for roi_id in unmixed_trace_events.keys():
        local_traces = unmixed_trace_events[roi_id]
        for channel in ('signal', 'crosstalk'):
            channel_trace = local_traces[channel]['trace']
            if len(channel_trace) == 0:
                invalid_active_trace[channel].append(roi_id)
                continue
            nan_trace = np.isnan(channel_trace).any()
            nan_events = np.isnan(local_traces[channel]['events']).any()
            if nan_trace or nan_events:
                invalid_active_trace[channel].append(roi_id)
                active_trace_had_NaNs = True

    if active_trace_had_NaNs:
        msg = 'ophys_experiment_id: %d (%d) ' % (signal_plane.experiment_id,
                                                 ct_plane.experiment_id)
        msg += 'had ROIs with active event channels that contained NaNs'
        logger.error(msg)

    # remove ROIs with empty or NaN active trace signal channels
    # from the data being processed
    for roi_id in invalid_active_trace['signal']:
        roi_flags[unmixed_active_key].append(roi_id)
        _move_to_invalid(roi_id, unmixed_traces, invalid_unmixed_traces,
                         unmixed_trace_events, invalid_unmixed_trace_events)

    ########################################################
    # For each ROI, assess whether or not it is a "ghost"
    # (i.e. whether any of its activity is due to the signal,
    # independent of the crosstalk; if not, it is a ghost)

    ghost_roi_id = []
    for roi_id in unmixed_trace_events.keys():
        # the independent-event details returned alongside is_a_cell
        # are not needed here
        is_a_cell, _ = d_utils.validate_cell_crosstalk(
            unmixed_trace_events[roi_id]['signal'],
            unmixed_trace_events[roi_id]['crosstalk'])
        if not is_a_cell:
            ghost_roi_id.append(roi_id)
    roi_flags[ghost_key] += ghost_roi_id

    return _outputs()