Example #1
def ReadPixelData(events, DIM_X, DIM_Y, spatial_scale):

    import numpy as np
    from utils import remove_annotations
    import neo

    # flatten (x, y) pixel coordinates into a single channel label
    PixelLabel = events.array_annotations['y_coords'] * DIM_Y \
               + events.array_annotations['x_coords']

    # sort transition times and their channel labels chronologically
    UpTrans = events.times
    Sorted_Idx = np.argsort(UpTrans)
    UpTrans = UpTrans[Sorted_Idx]
    PixelLabel = PixelLabel[Sorted_Idx]

    UpTrans_Evt = neo.Event(times=UpTrans,
                            name='UpTrans',
                            array_annotations={'channels': PixelLabel},
                            description='Transitions from down to up states. '
                                        'Annotated with the channel id ("channels")',
                            Dim_x=DIM_X,
                            Dim_y=DIM_Y,
                            spatial_scale=spatial_scale)

    # carry over the input annotations, minus the NIX-specific bookkeeping keys
    remove_annotations(UpTrans_Evt, del_keys=['nix_name', 'neo_name'])
    UpTrans_Evt.annotations.update(events.annotations)

    return UpTrans_Evt
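A minimal usage sketch with hypothetical data (assuming the same `utils.remove_annotations` helper is importable): build a coordinate-annotated input event and convert it to a channel-labelled one.

import numpy as np
import quantities as pq
import neo

events = neo.Event(times=np.array([0.1, 0.05, 0.2]) * pq.s,
                   array_annotations={'x_coords': np.array([0, 1, 2]),
                                      'y_coords': np.array([1, 0, 1])})
up_trans_evt = ReadPixelData(events, DIM_X=3, DIM_Y=2,
                             spatial_scale=0.05 * pq.mm)
print(up_trans_evt.times)                          # [0.05 0.1 0.2] s (sorted)
print(up_trans_evt.array_annotations['channels'])  # [1 2 4] = y*DIM_Y + x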
Example #2
import numpy as np
import neo
from sklearn.cluster import DBSCAN
from utils import remove_annotations


def cluster_triggers(event, metric, neighbour_distance, min_samples, time_dim):
    # select the UP transitions and build (x, y, scaled-time) triplets
    up_idx = np.where(event.labels == 'UP')[0]

    triggers = np.zeros((len(up_idx), 3))
    triggers[:, 0] = event.array_annotations['x_coords'][up_idx]
    triggers[:, 1] = event.array_annotations['y_coords'][up_idx]
    triggers[:, 2] = event.times[up_idx] * time_dim

    clustering = DBSCAN(eps=neighbour_distance,
                        min_samples=min_samples,
                        metric=metric)
    clustering.fit(triggers)

    if np.all(clustering.labels_ == -1):
        raise ValueError("No clusters found, please adapt the parameters!")

    # remove unclassified trigger points (label == -1)
    cluster_idx = np.where(clustering.labels_ != -1)[0]

    wave_idx = up_idx[cluster_idx]

    evt = neo.Event(times=event.times[wave_idx],
                    labels=clustering.labels_[cluster_idx],
                    name='Wavefronts',
                    array_annotations={'channels': event.array_annotations['channels'][wave_idx],
                                       'x_coords': triggers[:, 0][cluster_idx],
                                       'y_coords': triggers[:, 1][cluster_idx]},
                    description='Transitions from down to up states. '
                                'Labels are ids of wavefronts. '
                                'Annotated with the channel id ("channels") and '
                                'its position ("x_coords", "y_coords").',
                    cluster_algorithm='sklearn.cluster.DBSCAN',
                    cluster_eps=neighbour_distance,
                    cluster_metric=metric,
                    cluster_min_samples=min_samples)

    remove_annotations(event, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(event.annotations)
    return evt
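For intuition, a self-contained toy run of the clustering step: DBSCAN on (x, y, scaled-time) triplets groups nearby triggers into one wavefront and labels outliers -1.

import numpy as np
from sklearn.cluster import DBSCAN

triggers = np.array([[0., 0., 0.], [0., 1., 0.], [1., 0., 0.],   # wavefront 0
                     [10., 10., 5.], [10., 11., 5.],             # wavefront 1
                     [50., 50., 99.]])                           # noise
labels = DBSCAN(eps=2, min_samples=2).fit(triggers).labels_
print(labels)  # -> [ 0  0  0  1  1 -1]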
Example #3
import numpy as np
import neo
from utils import remove_annotations


def threshold(asig, threshold_array):
    dim_t, channel_num = asig.shape

    # subtract the per-channel threshold and binarize into UP/DOWN states
    th_signal = asig.as_array() \
              - np.repeat(threshold_array[np.newaxis, :], dim_t, axis=0)
    state_array = th_signal > 0
    # compare each sample with its predecessor (note: np.roll wraps the last
    # sample to index 0, so a spurious transition can be reported at t=0)
    rolled_state_array = np.roll(state_array, 1, axis=0)

    all_times = np.array([])
    all_channels = np.array([], dtype=int)
    all_labels = np.array([])

    # UP: state flips from below to above threshold; DOWN: the reverse
    for label, func in zip(['UP', 'DOWN'],
                           [lambda x: x, lambda x: np.bitwise_not(x)]):
        trans = np.where(func(np.bitwise_not(rolled_state_array))
                         * func(state_array))
        channels = trans[1]
        times = asig.times[trans[0]]

        if not len(times):
            raise ValueError("The chosen threshold does not lie within the "
                             "range of the signal values!")

        all_channels = np.append(all_channels, channels)
        all_times = np.append(all_times, times)
        all_labels = np.append(all_labels, np.array([label for _ in times]))

    sort_idx = np.argsort(all_times)

    evt = neo.Event(times=all_times[sort_idx] * asig.times.units,
                    labels=all_labels[sort_idx],
                    name='Transitions',
                    array_annotations={'channels': all_channels[sort_idx]},
                    threshold=threshold_array,
                    description='Transitions between down and up states with '
                                'labels "UP" and "DOWN". '
                                'Annotated with the channel id ("channels").')

    # map the channel-wise array annotations of the signal onto the events
    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][all_channels[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)
    return evt
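A minimal sketch with synthetic data (assuming numpy, neo, and the remove_annotations helper are importable as above): a zero threshold on two sinusoidal channels yields alternating 'UP'/'DOWN' events.

import numpy as np
import quantities as pq
import neo

t = np.linspace(0, 4 * np.pi, 200)
sig = np.stack([np.sin(t), np.cos(t)], axis=1)
asig = neo.AnalogSignal(sig, units='mV', sampling_rate=100 * pq.Hz)

evt = threshold(asig, threshold_array=np.zeros(2))
print(evt.labels[:4])
print(evt.array_annotations['channels'][:4])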
Example #4
import numpy as np
import neo
from scipy.signal import argrelmin
from utils import remove_annotations


def detect_minima(asig, order):
    signal = asig.as_array()

    # local minima per channel, compared against `order` neighbours on each side
    t_idx, channel_idx = argrelmin(signal, order=order, axis=0)

    sort_idx = np.argsort(t_idx)

    evt = neo.Event(times=asig.times[t_idx[sort_idx]],
                    labels=['UP'] * len(t_idx),
                    name='Transitions',
                    minima_order=order,
                    array_annotations={'channels': channel_idx[sort_idx]})

    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][channel_idx[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)
    return evt
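A toy check of the argrelmin step used above: with axis=0 it returns (time index, channel index) pairs of strict local minima per column.

import numpy as np
from scipy.signal import argrelmin

x = np.array([[3., 1.],
              [1., 3.],
              [3., 1.],
              [1., 3.],
              [3., 1.]])
t_idx, ch_idx = argrelmin(x, order=1, axis=0)
print(t_idx, ch_idx)  # -> [1 2 3] [0 1 0]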
Example #5
import os
import re
from xml.etree.ElementTree import parse

import utils


def _load_flickr30k(dataroot, img_id2idx, bbox, pos_boxes):
    """Load Flickr30kEntities entries.

    dataroot: root path of the dataset
    img_id2idx: dict {img_id -> idx}; idx retrieves the image's features
    bbox, pos_boxes: box coordinates and per-image [start, stop) ranges into them
    """
    pattern_phrase = r'\[(.*?)\]'
    pattern_no = r'\/EN\#(\d+)'

    missing_entity_count = dict()
    multibox_entity_count = 0

    entries = []
    for image_id, idx in img_id2idx.items():

        phrase_file = os.path.join(
            dataroot, 'Flickr30kEntities/Sentences/%d.txt' % image_id)
        anno_file = os.path.join(
            dataroot, 'Flickr30kEntities/Annotations/%d.xml' % image_id)

        with open(phrase_file, 'r', encoding='utf-8') as f:
            sents = [x.strip() for x in f]

        # Parse Annotation
        root = parse(anno_file).getroot()
        obj_elems = root.findall('./object')
        pos_box = pos_boxes[idx]
        bboxes = bbox[pos_box[0]:pos_box[1]]
        target_bboxes = {}

        for elem in obj_elems:
            if elem.find('bndbox') is None or len(elem.find('bndbox')) == 0:
                continue
            left = int(elem.findtext('./bndbox/xmin'))
            top = int(elem.findtext('./bndbox/ymin'))
            right = int(elem.findtext('./bndbox/xmax'))
            bottom = int(elem.findtext('./bndbox/ymax'))
            assert 0 < left and 0 < top

            for name in elem.findall('name'):
                entity_id = int(name.text)
                assert 0 < entity_id
                if entity_id not in target_bboxes:
                    target_bboxes[entity_id] = []
                else:
                    multibox_entity_count += 1
                target_bboxes[entity_id].append([left, top, right, bottom])

        # Parse Sentence
        for sent_id, sent in enumerate(sents):
            sentence = utils.remove_annotations(sent)
            entities = re.findall(pattern_phrase, sent)
            entity_indices = []
            target_indices = []
            entity_ids = []
            entity_types = []

            for entity_i, entity in enumerate(entities):
                info, phrase = entity.split(' ', 1)
                entity_id = int(re.findall(pattern_no, info)[0])
                entity_type = info.split('/')[2:]

                entity_idx = utils.find_sublist(sentence.split(' '),
                                                phrase.split(' '))
                assert 0 <= entity_idx

                if entity_id not in target_bboxes:
                    if entity_id >= 0:
                        missing_entity_count[
                            entity_type[0]] = missing_entity_count.get(
                                entity_type[0], 0) + 1
                    continue

                assert 0 < entity_id

                entity_ids.append(entity_id)
                entity_types.append(entity_type)

                target_idx = utils.get_match_index(target_bboxes[entity_id],
                                                   bboxes)
                entity_indices.append(entity_idx)
                target_indices.append(target_idx)

            if not entity_ids:
                continue

            entries.append(
                _create_flickr_entry(idx, sentence, entity_indices,
                                     target_indices, entity_ids, entity_types))

    if missing_entity_count:
        print('missing_entity_count=')
        print(missing_entity_count)
        print('multibox_entity_count=%d' % multibox_entity_count)

    return entries
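To make the regex step concrete, here is what the two patterns extract from a typical Flickr30kEntities-style sentence (the sentence itself is illustrative):

import re

sent = '[/EN#283585/people A woman] is slicing [/EN#283586/other a tomato] .'
pattern_phrase = r'\[(.*?)\]'
pattern_no = r'\/EN\#(\d+)'

for entity in re.findall(pattern_phrase, sent):
    info, phrase = entity.split(' ', 1)
    entity_id = int(re.findall(pattern_no, info)[0])
    entity_type = info.split('/')[2:]
    print(entity_id, entity_type, repr(phrase))
# -> 283585 ['people'] 'A woman'
# -> 283586 ['other'] 'a tomato'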
Example #6
import numpy as np
import neo
from scipy.signal import argrelmin, argrelmax
from utils import remove_annotations


def detect_minima(asig, order, interpolation_points, interpolation):
    signal = asig.as_array()
    sampling_time = asig.times[1] - asig.times[0]

    # local minima (candidate UP transitions) and local maxima per channel
    t_idx, channel_idx = argrelmin(signal, order=order, axis=0)
    t_idx_max, channel_idx_max = argrelmax(signal, order=order, axis=0)

    if interpolation:

        # minima: fit a parabola to `interpolation_points` samples around each,
        # clipping the fit window to the signal boundaries
        start_arr = t_idx - int(interpolation_points / 2)
        start_arr = np.where(start_arr > 0, start_arr, 0)
        stop_arr = start_arr + int(interpolation_points)
        start_arr = np.where(stop_arr < len(signal), start_arr,
                             len(signal) - interpolation_points - 1)
        stop_arr = np.where(stop_arr < len(signal), stop_arr, len(signal) - 1)

        signal_arr = np.empty((interpolation_points, len(start_arr)))
        signal_arr.fill(np.nan)

        for i, (start, stop,
                channel_i) in enumerate(zip(start_arr, stop_arr, channel_idx)):
            signal_arr[:, i] = signal[start:stop, channel_i]

        X_temp = np.arange(interpolation_points)
        params = np.polyfit(X_temp, signal_arr, 2)

        # vertex of the fitted parabola y = a*x^2 + b*x + c lies at x = -b/(2a)
        min_pos = -params[1, :] / (2 * params[0, :]) + start_arr
        min_pos = np.where(min_pos > 0, min_pos, 0)
        minimum_times = min_pos * sampling_time

        minimum_value = params[0, :] * (
            -params[1, :] / (2 * params[0, :]))**2 + params[1, :] * (
                -params[1, :] / (2 * params[0, :])) + params[2, :]

        # maxima: same parabolic fit around each local maximum
        start_arr = t_idx_max - int(interpolation_points / 2)
        start_arr = np.where(start_arr > 0, start_arr, 0)
        stop_arr = start_arr + int(interpolation_points)
        start_arr = np.where(stop_arr < len(signal), start_arr,
                             len(signal) - interpolation_points - 1)
        stop_arr = np.where(stop_arr < len(signal), stop_arr, len(signal) - 1)

        signal_arr = np.empty((interpolation_points, len(start_arr)))
        signal_arr.fill(np.nan)

        for i, (start, stop, channel_i) in enumerate(
                zip(start_arr, stop_arr, channel_idx_max)):
            signal_arr[:, i] = signal[start:stop, channel_i]

        X_temp = np.arange(interpolation_points)
        params = np.polyfit(X_temp, signal_arr, 2)

        max_pos = -params[1, :] / (2 * params[0, :]) + start_arr
        max_pos = np.where(max_pos > 0, max_pos, 0)

        maximum_times = max_pos * sampling_time
        maximum_value = params[0, :] * (
            -params[1, :] / (2 * params[0, :]))**2 + params[1, :] * (
                -params[1, :] / (2 * params[0, :])) + params[2, :]

        # pair each minimum with the next maximum on the same channel and
        # record the peak-to-trough amplitude
        amplitude = []
        ch_arr = []

        for i in range(len(min_pos)):  # for each transition
            ch = channel_idx[i]
            min_time = min_pos[i]
            min_value = minimum_value[i]

            ch_idx = np.where(channel_idx_max == ch)[0]
            max_time = max_pos[ch_idx]
            max_value = maximum_value[ch_idx]
            time_idx = np.where(max_time > min_time)[0]
            times = max_time[time_idx]
            max_value = max_value[time_idx]

            try:
                idx_min_ampl = np.argmin(times)
                amplitude.append(max_value[idx_min_ampl] - min_value)
            except (IndexError, ValueError):
                # no later maximum on this channel: amplitude is undefined
                amplitude.append(np.nan)

            ch_arr.append(ch)

        arr_dict = {
            'Amplitude': amplitude,
            'Ch': ch_arr,
            'times': minimum_times
        }

    else:
        minimum_times = asig.times[t_idx]
        maximum_times = asig.times[t_idx_max]
        # note: without interpolation, 'Amplitude' holds min-to-max time
        # differences rather than signal amplitudes
        amplitude = maximum_times - minimum_times
        ch_arr = channel_idx
        arr_dict = {
            'Amplitude': amplitude,
            'Ch': ch_arr,
            'times': minimum_times
        }

    sort_idx = np.argsort(minimum_times)

    evt = neo.Event(times=minimum_times[sort_idx],
                    labels=['UP'] * len(minimum_times),
                    name='Transitions',
                    minima_order=order,
                    use_quadratic_interpolation=interpolation,
                    num_interpolation_points=interpolation_points,
                    array_annotations={'channels': channel_idx[sort_idx]})

    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][channel_idx[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)
    return evt, arr_dict
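The interpolation relies on the parabola vertex: fitting y = a*x^2 + b*x + c to the samples around an extremum places it at x* = -b/(2a). A toy check with np.polyfit (one column per extremum window, as in the code above):

import numpy as np

y = np.array([[4.], [1.], [0.], [1.], [4.]])  # samples of (x - 2)^2
a, b, c = np.polyfit(np.arange(5), y, 2)
x_star = -b / (2 * a)
print(x_star)                                 # -> [2.]
print(a * x_star**2 + b * x_star + c)         # -> [~0.]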
Example #7
import numpy as np
import neo
from scipy.signal import argrelmin, find_peaks
from utils import remove_annotations


def detect_minima(asig, order, interpolation_points, interpolation,
                  threshold_fraction, Min_Peak_Distance):
    signal = asig.as_array()
    times = asig.times
    sampling_time = asig.times[1] - asig.times[0]
    min_idx, channel_idx_minima = argrelmin(signal, order=order, axis=0)

    # per-channel peak threshold: a fraction of the channel's amplitude span
    amplitude_span = np.max(signal, axis=0) - np.min(signal, axis=0)
    threshold = np.min(signal, axis=0) + threshold_fraction * amplitude_span

    min_time_idx = []
    channel_idx = []
    for ch in range(len(signal[0])):
        peaks, _ = find_peaks(signal.T[ch], height=threshold[ch],
                              distance=np.int32(Min_Peak_Distance / sampling_time))
        mins = min_idx[np.where(channel_idx_minima == ch)[0]]

        # keep only the minimum closest before each detected peak
        clean_mins = np.array([], dtype=int)
        for peak in peaks:
            distance_to_peak = times[peak] - times[mins]
            distance_to_peak = distance_to_peak[distance_to_peak > 0]
            if distance_to_peak.size:
                trans_idx = np.argmin(distance_to_peak)
                clean_mins = np.append(clean_mins, mins[trans_idx])

        min_time_idx.extend(clean_mins)
        channel_idx.extend(list(np.ones(len(clean_mins)) * ch))

    min_time_idx = np.array(min_time_idx, dtype=int)

    # compute local minima times
    if interpolation:
        # parabolic fit around the local minima
        start_arr = min_time_idx - 1
        start_arr = np.where(start_arr > 0, start_arr, 0)
        stop_arr = start_arr + int(interpolation_points)

        start_arr = np.where(stop_arr < len(signal), start_arr,
                             len(signal) - interpolation_points - 1)
        stop_arr = np.where(stop_arr < len(signal), stop_arr, len(signal) - 1)

        signal_arr = np.empty((interpolation_points, len(start_arr)))
        signal_arr.fill(np.nan)

        for i, (start, stop, channel_i) in enumerate(zip(start_arr, stop_arr,
                                                         channel_idx)):
            signal_arr[:, i] = signal[start:stop, int(channel_i)]

        X_temp = np.arange(interpolation_points)
        params = np.polyfit(X_temp, signal_arr, 2)

        # vertex of the fitted parabola: x = -b / (2a)
        min_pos = -params[1, :] / (2 * params[0, :]) + start_arr
        min_pos = np.where(min_pos > 0, min_pos, 0)
        minimum_times = min_pos * sampling_time
        minimum_times[np.where(minimum_times > asig.t_stop)[0]] = asig.t_stop
    else:
        minimum_times = asig.times[min_time_idx]

    sort_idx = np.argsort(minimum_times)
    channel_idx = np.int32(channel_idx)

    evt = neo.Event(times=minimum_times[sort_idx],
                    labels=['UP'] * len(minimum_times),
                    name='Transitions',
                    minima_order=order,
                    num_interpolation_points=interpolation_points,
                    array_annotations={'channels': channel_idx[sort_idx]})

    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][channel_idx[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)

    return evt
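A toy check of the peak gating used above: find_peaks keeps only peaks above `height` and at least `distance` samples apart; each surviving peak is then matched to its closest preceding minimum.

import numpy as np
from scipy.signal import find_peaks

x = np.array([0., 1., 0., 3., 0., 2., 0.])
peaks, _ = find_peaks(x, height=1.5, distance=2)
print(peaks)  # -> [3 5]; the peak at index 1 falls below the height threshold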
Example #8
    # (fragment: Wave, evts, UpTrans_Evt, DIM_Y, block, and args are defined
    # earlier in the enclosing script)
    # collect times, wave ids, and channels of all detected waves
    Times = []
    Label = []
    Pixels = []

    for i in range(len(Wave)):
        Times.extend(Wave[i]['times'].magnitude)
        Label.extend(np.ones([len(Wave[i]['ndx'])]) * i)
        Pixels.extend(Wave[i]['ch'])

    Label = [str(i) for i in Label]
    Times = Times * (Wave[0]['times'].units)
    waves = neo.Event(times=Times.rescale(pq.s),
                      labels=Label,
                      name='Wavefronts',
                      array_annotations={'channels': Pixels,
                                         'x_coords': [p % DIM_Y for p in Pixels],
                                         'y_coords': [p // DIM_Y for p in Pixels]},
                      description='Transitions from down to up states. '
                                  'Labels are ids of wavefronts. '
                                  'Annotated with the channel id ("channels") and '
                                  'its position ("x_coords", "y_coords").',
                      spatial_scale=UpTrans_Evt.annotations['spatial_scale'])

    waves.annotations.update(evts.annotations)
    remove_annotations(waves, del_keys=['nix_name', 'neo_name'])

    block.segments[0].events.append(waves)

    write_neo(args.output, block)
Example #9
        return X * signal.units
    elif isinstance(signal, np.ndarray):
        return X


if __name__ == '__main__':
    CLI = argparse.ArgumentParser()
    CLI.add_argument("--data", nargs='?', type=str)
    CLI.add_argument("--output", nargs='?', type=str)
    CLI.add_argument("--order", nargs='?', type=int)
    args = CLI.parse_args()

    # load images
    with neo.NixIO(args.data) as io:
        block = io.read_block()

    check_analogsignal_shape(block.segments[0].analogsignals)
    remove_annotations([block] + block.segments +
                       block.segments[0].analogsignals)

    asig = detrending(block.segments[0].analogsignals[0], args.order)

    # save processed data
    asig.description += "Detrended by order {} ({}). "\
                        .format(args.order, os.path.basename(__file__))
    block.segments[0].analogsignals[0] = asig

    with neo.NixIO(args.output) as io:
        io.write(block)
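A hypothetical invocation of the script above (the file name is assumed):

# python detrending.py --data recording.nix --output detrended.nix --order 2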
Example #10
import numpy as np
import neo
from scipy.signal import hilbert
from utils import remove_annotations


def detect_transitions(asig, transition_phase):
    # ToDo: replace with elephant function
    signal = asig.as_array()
    dim_t, channel_num = signal.shape

    hilbert_signal = hilbert(signal, axis=0)
    hilbert_phase = np.angle(hilbert_signal)

    def _detect_phase_crossings(phase):
        # detect phase crossings from below phase to above phase
        is_larger = hilbert_phase > phase
        positive_crossings = ~is_larger & np.roll(is_larger, -1, axis=0)
        positive_crossings = positive_crossings[:-1]

        # keep crossings where the analytic signal's real part exceeds its
        # imaginary part
        real_crossings = np.real(hilbert_signal[:-1]) > np.imag(
            hilbert_signal[:-1])
        crossings = real_crossings & positive_crossings

        # arrange transitions times per channel
        times = asig.times[:-1]
        crossings_list = [
            times[crossings[:, channel]].magnitude
            for channel in range(channel_num)
        ]
        return crossings_list

    # UP transitions: a change of the Hilbert phase from < transition_phase
    #                 to > transition_phase, followed by a peak (phase = 0).

    peaks = _detect_phase_crossings(0)
    transitions = _detect_phase_crossings(transition_phase)

    up_transitions = np.array([])
    channels = np.array([], dtype=int)

    for channel_id, (channel_peaks,
                     channel_transitions) in enumerate(zip(peaks,
                                                           transitions)):
        channel_up_transitions = np.array([])
        if channel_peaks is not None:
            for peak in channel_peaks:
                distance_to_peak = peak - np.array(channel_transitions)
                distance_to_peak = distance_to_peak[distance_to_peak > 0]
                if distance_to_peak.size:
                    trans_idx = np.argmin(distance_to_peak)
                    channel_up_transitions = np.append(
                        channel_up_transitions, channel_transitions[trans_idx])
        channel_up_transitions = np.unique(channel_up_transitions)
        up_transitions = np.append(up_transitions, channel_up_transitions)
        channels = np.append(
            channels,
            np.ones_like(channel_up_transitions, dtype=int) * channel_id)

    # save transitions as Event labels:'UP', array_annotations: channels
    sort_idx = np.argsort(up_transitions)

    evt = neo.Event(times=up_transitions[sort_idx]*asig.times.units,
                    labels=['UP'] * len(up_transitions),
                    name='Transitions',
                    array_annotations={'channels':channels[sort_idx]},
                    hilbert_transition_phase=transition_phase,
                    description='Transitions from down to up states. '
                                'Annotated with the channel id ("channels").')

    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][channels[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)
    return evt
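A toy check of the phase convention behind _detect_phase_crossings: for a sine wave the Hilbert phase rises through 0 exactly at the signal peaks, which is why phase-0 crossings serve as the 'peak' reference.

import numpy as np
from scipy.signal import hilbert

t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t)
phase = np.angle(hilbert(x))

rising_zero = np.where((phase[:-1] < 0) & (phase[1:] >= 0))[0]
print(t[rising_zero])  # ~ [0.05 0.25 0.45 0.65 0.85], the peaks of x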
Example #11
import numpy as np
import neo
from scipy.signal import hilbert
from utils import remove_annotations


def detect_transitions(asig, transition_phase):
    # ToDo: replace with elephant function
    signal = asig.as_array()
    dim_t, channel_num = signal.shape

    hilbert_signal = hilbert(signal, axis=0)
    hilbert_phase = np.angle(hilbert_signal)

    def _detect_phase_crossings(phase):
        t_idx, channel_idx = np.where(
            np.diff(np.signbit(hilbert_phase - phase), axis=0))
        crossings = [None] * channel_num
        for ti, channel in zip(t_idx, channel_idx):
            # select only crossings from negative to positive
            if (hilbert_phase-phase)[ti][channel] <= 0 \
            and np.real(hilbert_signal[ti][channel]) \
              > np.imag(hilbert_signal[ti][channel]):
                if crossings[channel] is None:
                    crossings[channel] = np.array([])
                if asig.times[ti].magnitude not in crossings[channel]:
                    crossings[channel] = np.append(crossings[channel],
                                                   asig.times[ti].magnitude)
        return crossings

    # UP transitions: a change of the Hilbert phase from < transition_phase
    #                 to > transition_phase, followed by a peak (phase = 0).

    peaks = _detect_phase_crossings(0)
    transitions = _detect_phase_crossings(transition_phase)
    up_transitions = np.array([])
    channels = np.array([], dtype=int)

    for channel_id, (channel_peaks,
                     channel_transitions) in enumerate(zip(peaks,
                                                           transitions)):
        channel_up_transitions = np.array([])
        if channel_peaks is not None:
            for peak in channel_peaks:
                distance_to_peak = peak - np.array(channel_transitions)
                distance_to_peak = distance_to_peak[distance_to_peak > 0]
                if distance_to_peak.size:
                    trans_idx = np.argmin(distance_to_peak)
                    channel_up_transitions = np.append(
                        channel_up_transitions, channel_transitions[trans_idx])
        channel_up_transitions = np.unique(channel_up_transitions)
        up_transitions = np.append(up_transitions, channel_up_transitions)
        channels = np.append(
            channels,
            np.ones_like(channel_up_transitions, dtype=int) * channel_id)

    # save transitions as Event labels:'UP', array_annotations: channels
    sort_idx = np.argsort(up_transitions)

    evt = neo.Event(times=up_transitions[sort_idx]*asig.times.units,
                    labels=['UP'] * len(up_transitions),
                    name='Transitions',
                    array_annotations={'channels': channels[sort_idx]},
                    hilbert_transition_phase=transition_phase,
                    description='Transitions from down to up states. '
                                'Annotated with the channel id ("channels").')

    for key in asig.array_annotations.keys():
        evt_ann = {key: asig.array_annotations[key][channels[sort_idx]]}
        evt.array_annotations.update(evt_ann)

    remove_annotations(asig, del_keys=['nix_name', 'neo_name'])
    evt.annotations.update(asig.annotations)
    return evt
Example #12
import re

import utils


def _load_kairos(dataset,
                 img_id2idx,
                 bbox,
                 pos_boxes,
                 topic_doc_json,
                 topic=None):
    """Load entries.

    dataset: dataset name, used to build the phrase-file paths
    img_id2idx: dict {img_id -> idx}; idx retrieves the image's features
    topic_doc_json: dict {topic -> list of phrase/document ids}
    bbox, pos_boxes: unused here; the bounding-box parsing is disabled
    """
    pattern_phrase = r'\[\/EN\#(.*?)\]'
    pattern_no = r'\/EN\#(\d+)'

    missing_entity_count = dict()
    multibox_entity_count = 0

    entries = []

    if topic is None:
        topic = dataset

    for image_id, idx in img_id2idx.items():

        for phrase_id in topic_doc_json[topic]:

            phrase_file = f'data/{dataset}/json_output/ent_sents/{phrase_id}.txt'

            with open(phrase_file, 'r', encoding='utf-8') as f:
                sents = [x.strip() for x in f]

            # (the bounding-box annotation parsing used in _load_flickr30k
            #  is disabled for this dataset)

            # Parse Sentence
            for sent_id, sent in enumerate(sents):
                sentence = utils.remove_annotations(sent)
                entities = re.findall(pattern_phrase, sent)
                entity_indices = []
                entity_ids = []
                entity_types = []

                for entity_i, entity in enumerate(entities):
                    entity = "/EN#" + entity
                    info, phrase = entity.split(' ', 1)
                    try:
                        entity_id = int(re.findall(pattern_no, info)[0])
                    except IndexError:
                        print(f"entity = {entity}\nsentence = {sentence}, "
                              f"info = {info}")
                        raise Exception("entry creation failed")
                    entity_type = info.split('/')[2:]

                    entity_idx = utils.find_sublist(sentence.split(' '),
                                                    phrase.split(' '))
                    # skip phrases that cannot be located in the sentence
                    if entity_idx < 0:
                        continue

                    assert 0 < entity_id

                    entity_ids.append(entity_id)
                    entity_types.append(entity_type)
                    entity_indices.append(entity_idx)

                if not entity_ids:
                    continue
                try:
                    entry = _create_kairos_entry(idx,
                                                 f"{phrase_id}-s{sent_id}",
                                                 sentence, entity_indices,
                                                 entity_ids, entity_types)
                except Exception:
                    print(idx, sent_id, sentence, sent)
                    raise Exception("entry creation failed")

                entries.append(entry)

    if missing_entity_count:
        print('missing_entity_count=')
        print(missing_entity_count)
        print('multibox_entity_count=%d' % multibox_entity_count)

    return entries