예제 #1
0
파일: test.py 프로젝트: PurdueMINDS/SAGA
    def concat_prediction(prediction1, prediction2):
        """Concatenate two predictions feature-wise into one prune-data object.

        Both predictions must describe the same records in the same order
        (equal ``identifier`` sequences) and must not share any key. Each
        pair of per-record arrays is joined along axis 1, and the keys of
        ``prediction2`` are appended after those of ``prediction1``.

        Args
        ----
        prediction1, prediction2
            objects exposing ``dataset`` (one array per record), ``keys``
            and ``identifier`` (one identifier per record)

        Returns
        -------
        flight.FlightPruneData
            built from the ``(dataset, keys, identifier)`` tuple

        Raises
        ------
        ValueError
            if the identifiers differ, the keys overlap, or the number of
            arrays does not match the number of identifiers

        """
        # link data
        dataset1 = prediction1.dataset
        keys1 = prediction1.keys
        identifier1 = prediction1.identifier

        dataset2 = prediction2.dataset
        keys2 = prediction2.keys
        identifier2 = prediction2.identifier

        # Explicit checks instead of `assert`: assertions are removed under
        # `python -O`, which would let mismatched predictions be merged.
        if list(identifier1) != list(identifier2):
            raise ValueError(
                'predictions must share the same identifiers in the same order')

        overlap = set(keys1) & set(keys2)
        if overlap:
            raise ValueError('prediction keys must not overlap: {}'.format(
                ', '.join(sorted(map(str, overlap)))))

        # zip() would silently truncate on a length mismatch, so check first.
        if not len(dataset1) == len(dataset2) == len(identifier1):
            raise ValueError(
                'dataset lengths ({}, {}) do not match identifier count ({})'
                .format(len(dataset1), len(dataset2), len(identifier1)))

        # concatenate each record's arrays along axis 1
        dataset = [np.concatenate([data1, data2], axis=1)
                   for data1, data2 in zip(dataset1, dataset2)]
        identifier = list(identifier1)
        keys = keys1 + keys2

        return flight.FlightPruneData((dataset, keys, identifier))
예제 #2
0
파일: anime.py 프로젝트: PurdueMINDS/SAGA
def process_data(dir, phone, date):
    """Build aligned phone/G1000 sequence-pair data for one date.

    Loads raw G1000 and phone recordings for ``date`` and appends derived
    feature channels. It keeps only dates and records present in both
    sources and drops flights listed in ``constant.HIZARD_FLIGHTS``. It then
    detects and prunes parking data, and aligns/interpolates the phone data
    (input) against the G1000 data (target).

    Args
    ----
    dir : str
        root directory containing the ``g1000`` and ``<phone>`` raw folders
    phone : str
        phone type; also the name of its raw folder under ``dir``
    date : str
        date passed to the raw loaders as ``want=[date]``

    Returns
    -------
    seq_pair_data : frame.FlightSequencePairData
        sequence pair data after alignment, interpolation and ``distribute()``

    """
    def process_g1000():
        """Load G1000 data for ``date`` and append derived feature channels."""
        raw_g1000_data = raw.RawG1000Data('{}/g1000'.format(dir), want=[date])
        ext_g1000_data = flight.FlightExtensionData(raw_g1000_data)

        # 5-step differences of position -> spd_alt / spd_lat / spd_long,
        # then a ground speed combined from the lat/long components.
        for key in ('alt', 'lat', 'long'):
            ext_g1000_data.append_num_diff(key=key,
                                           new_key='spd_{}'.format(key),
                                           step=5,
                                           pad='repeat_base')
        ext_g1000_data.append_ground_speed(spd_lat_key='spd_lat',
                                           spd_long_key='spd_long',
                                           new_key='spd_gd')

        # 5-step differences of attitude; heading goes through
        # append_deg_diff (presumably wrap-aware for degrees -- TODO confirm).
        for key in ('pitch', 'roll'):
            ext_g1000_data.append_num_diff(key=key,
                                           new_key='spd_{}'.format(key),
                                           step=5,
                                           pad='repeat_base')
        ext_g1000_data.append_deg_diff(key='heading',
                                       new_key='spd_heading',
                                       step=5,
                                       pad='repeat_base')

        # sine then cosine of each attitude angle (same channel order as
        # appending all sines first, then all cosines)
        angles = ('pitch', 'roll', 'heading')
        for key in angles:
            ext_g1000_data.append_deg_sin(key=key, new_key='sin_{}'.format(key))
        for key in angles:
            ext_g1000_data.append_deg_cos(key=key, new_key='cos_{}'.format(key))

        return ext_g1000_data

    def process_phone():
        """Load phone data for ``date`` and append derived feature channels."""
        raw_phone_data = raw.RawPhoneData('{}/{}'.format(dir, phone),
                                          want=[date])
        ext_phone_data = flight.FlightExtensionData(raw_phone_data)

        # 5-step differences of position -> spd_* channels, plus ground speed
        for key in ('alt', 'lat', 'long'):
            ext_phone_data.append_num_diff(key=key,
                                           new_key='spd_{}'.format(key),
                                           step=5,
                                           pad='repeat_base')
        ext_phone_data.append_ground_speed(spd_lat_key='spd_lat',
                                           spd_long_key='spd_long',
                                           new_key='spd_gd')

        # 1-step differences of the spd_* channels -> acc_* channels
        for key in ('alt', 'lat', 'long'):
            ext_phone_data.append_num_diff(key='spd_{}'.format(key),
                                           new_key='acc_{}'.format(key),
                                           step=1,
                                           pad='repeat_base')

        return ext_phone_data

    ext_g1000_data = process_g1000()
    ext_phone_data = process_phone()

    prn_g1000_data = flight.FlightPruneData(ext_g1000_data)
    prn_phone_data = flight.FlightPruneData(ext_phone_data)

    # The part of a G1000 identifier before the first '_' is compared
    # against phone identifiers as its date.
    g1000_date = {itr.split('_')[0] for itr in prn_g1000_data.identifier}
    phone_date = set(prn_phone_data.identifier)
    share_date = sorted(g1000_date & phone_date)
    union_date = sorted(g1000_date | phone_date)

    # Loop variable is `day`, not `date`, so the `date` argument captured by
    # the nested loaders is never rebound.
    for day in union_date:
        if day in share_date:
            logging.info("Detect Date - \033[32;1m{}\033[0m".format(day))
        elif day in g1000_date:
            logging.warning(
                "Detect Date - \033[31;1m{}\033[0m (G1000)".format(day))
        elif day in phone_date:
            # label phone-only dates with this function's `phone` type
            logging.warning("Detect Date - \033[31;1m{}\033[0m ({})".format(
                day, phone))
        else:
            raise NotImplementedError

    # discard data not in the intersection
    g1000_discard = [
        itr for itr in prn_g1000_data.identifier
        if itr.split('_')[0] not in share_date
    ]
    prn_g1000_data.prune_identifier(discard_identifier=g1000_discard)
    prn_phone_data.prune_identifier(remain_identifier=share_date)

    # drop flights listed in constant.HIZARD_FLIGHTS, then detect parking
    prn_g1000_data.prune_identifier(discard_identifier=constant.HIZARD_FLIGHTS)
    prn_g1000_data.detect_parking(method='time')

    # phone parking detection uses the G1000 flight times (keyed like the
    # phone identifiers, presumably by date)
    phone_requirement = prn_g1000_data.time_date_flights()
    prn_phone_data.prune_identifier(remain_identifier=phone_requirement.keys())
    prn_phone_data.detect_parking(method='time', time_flights=phone_requirement)

    prn_g1000_data.prune_parking()
    prn_phone_data.prune_parking()

    # report records present in only one source
    g1000_idt = set(prn_g1000_data.identifier)
    phone_idt = set(prn_phone_data.identifier)
    share_idt = sorted(g1000_idt & phone_idt)
    union_idt = sorted(g1000_idt | phone_idt)

    for idt in union_idt:
        if idt in share_idt:
            logging.info("Valid Record: \033[32;1m{}\033[0m".format(idt))
        elif idt in g1000_idt:
            logging.warning(
                "Redundant Record: \033[31;1m{}\033[0m (G1000)".format(idt))
        elif idt in phone_idt:
            logging.warning(
                "Redundant Record: \033[31;1m{}\033[0m ({})".format(
                    idt, phone))
        else:
            raise NotImplementedError

    # the phone may have recorded fewer flights than the G1000 on a date
    prn_g1000_data.prune_identifier(
        remain_identifier=prn_phone_data.identifier)

    seq_pair_data = frame.FlightSequencePairData(entity_input=prn_phone_data,
                                                 entity_target=prn_g1000_data)
    seq_pair_data.align_and_interpolate(match_keys=('alt', 'lat', 'long'))
    seq_pair_data.distribute()

    return seq_pair_data
예제 #3
0
파일: main.py 프로젝트: PurdueMINDS/SAGA
def data_flight(input_dir, output_dir, phone_type, args=None):
    """Generate aligned phone/G1000 flight sequence data and save it.

    Loads the ``g1000`` and ``<phone_type>`` extension data from
    ``input_dir`` and keeps only dates and records present in both sources.
    It drops flights listed in ``constant.HIZARD_FLIGHTS`` and prunes parking
    data. It then aligns/interpolates the phone data (input) against the
    G1000 data (target), optionally plots diagnostics, and saves the result
    to ``<output_dir>/<phone_type>_g1000.sequence``.

    Args
    ----
    input_dir : str
        root directory to load data
    output_dir : str
        root directory to save data and plots
    phone_type : str
        phone type as input
    args : dict, optional
        global arguments; only the ``'--preview-plot'`` and ``'--no-plot'``
        flags are read. ``None`` or a missing flag counts as ``False``
        (no preview plot, diagnostic plots enabled).

    Returns
    -------
    seq_pair_data : frame.FlightSequencePairData
        sequence pair data after all data processing

    """
    # The default `args=None` used to crash on the first flag lookup.
    args = {} if args is None else args
    preview_plot = args.get('--preview-plot', False)
    no_plot = args.get('--no-plot', False)

    # load extension data
    ext_g1000_data = flight.FlightExtensionData('{}/g1000.extension'.format(input_dir))
    ext_phone_data = flight.FlightExtensionData('{}/{}.extension'.format(input_dir, phone_type))

    # generate prune data
    prn_g1000_data = flight.FlightPruneData(ext_g1000_data)
    prn_phone_data = flight.FlightPruneData(ext_phone_data)

    # Only keep dates present in both sources. The part of a G1000
    # identifier before the first '_' is compared with phone identifiers.
    g1000_date = {itr.split('_')[0] for itr in prn_g1000_data.identifier}
    phone_date = set(prn_phone_data.identifier)
    share_date = sorted(g1000_date & phone_date)
    union_date = sorted(g1000_date | phone_date)

    for date in union_date:
        if date in share_date:
            logging.info("Detect Date - \033[32;1m{}\033[0m".format(date))
        elif date in g1000_date:
            logging.warning("Detect Date - \033[31;1m{}\033[0m (G1000)".format(date))
        elif date in phone_date:
            logging.warning("Detect Date - \033[31;1m{}\033[0m ({})".format(date, phone_type))
        else:
            raise NotImplementedError

    # discard data not in the intersection
    g1000_discard = [itr for itr in prn_g1000_data.identifier if itr.split('_')[0] not in share_date]
    prn_g1000_data.prune_identifier(discard_identifier=g1000_discard)
    prn_phone_data.prune_identifier(remain_identifier=share_date)

    # plot preview (on clones, so the working data is untouched)
    if preview_plot:
        prev_g1000_data = prn_g1000_data.clone()
        prev_phone_data = prn_phone_data.clone()
        prev_g1000_data.prune_keys(remain_keys=['alt', 'lat', 'long', 'time'])
        prev_phone_data.prune_keys(remain_keys=['alt', 'lat', 'long', 'time'])
        prev_g1000_data.plot('{}/preview/g1000'.format(output_dir))
        prev_phone_data.plot('{}/preview/{}'.format(output_dir, phone_type))

    # G1000: drop flights listed in constant.HIZARD_FLIGHTS, then detect parking
    prn_g1000_data.prune_identifier(discard_identifier=constant.HIZARD_FLIGHTS)
    prn_g1000_data.detect_parking(method='time')

    # phone: detect parking using the G1000 flight times as the requirement
    phone_requirement = prn_g1000_data.time_date_flights()
    prn_phone_data.prune_identifier(remain_identifier=phone_requirement.keys())
    prn_phone_data.detect_parking(method='time', time_flights=phone_requirement)

    # plot parking criterion
    if not no_plot:
        prn_g1000_data.plot_parking_criterion('{}/park/g1000'.format(output_dir))
        prn_phone_data.plot_parking_criterion('{}/park/{}'.format(output_dir, phone_type))

    # prune parking data for both phone and g1000
    prn_g1000_data.prune_parking()
    prn_phone_data.prune_parking()

    # report records that are present in only one of the two sources
    g1000_idt = set(prn_g1000_data.identifier)
    phone_idt = set(prn_phone_data.identifier)
    share_idt = sorted(g1000_idt & phone_idt)
    union_idt = sorted(g1000_idt | phone_idt)

    for idt in union_idt:
        if idt in share_idt:
            logging.info("Valid Record: \033[32;1m{}\033[0m".format(idt))
        elif idt in g1000_idt:
            logging.warning("Redundant Record: \033[31;1m{}\033[0m (G1000)".format(idt))
        elif idt in phone_idt:
            logging.warning("Redundant Record: \033[31;1m{}\033[0m ({})".format(idt, phone_type))
        else:
            raise NotImplementedError

    # The phone may record fewer flights than the G1000 on the same date
    # (e.g. not enough battery).
    prn_g1000_data.prune_identifier(remain_identifier=prn_phone_data.identifier)

    # align prune data
    seq_pair_data = frame.FlightSequencePairData(entity_input=prn_phone_data, entity_target=prn_g1000_data)
    seq_pair_data.align_and_interpolate(match_keys=('alt', 'lat', 'long'))
    seq_pair_data.distribute()

    # plot alignment criterion
    if not no_plot:
        seq_pair_data.plot_match_criterion('{}/wrap/{}_g1000'.format(output_dir, phone_type))

    # save sequence data
    seq_pair_data.save('{}/{}_g1000.sequence'.format(output_dir, phone_type))

    return seq_pair_data