Ejemplo n.º 1
0
def data_batch(seq_pair_data,
               target_key,
               batch_size,
               set_normal=None,
               rnn=False):
    """Wrap an already-loaded sequence pair dataset in a batch loader.

    Parameters
    ----------
    seq_pair_data : sequence pair data object
        Pruned (and optionally normalized) through its own methods; the
        return values of those calls are ignored, so presumably they act
        in place -- confirm against the frame module.
    target_key : str
        The single target keyword to keep.
    batch_size : int
        Batch size, or the timestep length when ``rnn`` is True.
    set_normal : optional
        When not None, passed to ``seq_pair_data.normalize``; when None,
        no normalization is performed.
    rnn : bool
        Build a ``batch.TimestepPairDataLoader`` instead of a
        ``batch.FramePairDataLoader``.

    Returns
    -------
    tuple
        ``(loader, seq_pair_data.length_dict())``.
    """
    # Drop every column except the configured inputs and the one target.
    seq_pair_data.prune_keys(
        input_remain_keys=constant.INPUT_KEYWORDS,
        target_remain_keys=(target_key, ))

    if set_normal is not None:
        seq_pair_data.normalize(set_normal)

    # Slice the sequences into windows described by the shared config.
    win_cfg = constant.WINDOW_CONFIG
    src_cfg = win_cfg['input']
    dst_cfg = win_cfg['target']
    frame_pair_data = frame.FlightFramePairData(
        seq_pair_data,
        input_win_len=src_cfg['length'],
        target_win_len=dst_cfg['length'],
        input_win_offset=src_cfg['offset_length'],
        target_win_offset=dst_cfg['offset_length'],
        input_win_offset_rate=src_cfg['offset_rate'],
        target_win_offset_rate=dst_cfg['offset_rate'],
        input_pad=src_cfg['padding'],
        target_pad=dst_cfg['padding'])

    loader = (
        batch.TimestepPairDataLoader(frame_pair_data, timestep=batch_size)
        if rnn else
        batch.FramePairDataLoader(frame_pair_data,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  drop_last=False))

    return loader, seq_pair_data.length_dict()
Ejemplo n.º 2
0
    def data_batch(input_dir,
                   phone_type,
                   target_key,
                   set_normal,
                   batch_size=16,
                   shuffle=False,
                   window_config=constant.WINDOW_CONFIG):
        """Load one phone type's sequences and return a frame batch loader.

        Reads ``args`` (keys ``'--all'`` and ``'--stratux'``) and
        ``TEST_LIST`` from the enclosing scope, which is not shown here.
        Note that the ``window_config`` default is evaluated once, when this
        function is defined.

        Returns
        -------
        tuple
            ``(batch_loader, seq_pair_data.length_dict(), time_col)`` where
            ``time_col`` is a clone of the target metadata pruned to the
            ``'time'`` key (and to ``TEST_LIST`` unless ``'--all'`` is set).
        """
        sequence_path = '{}/{}_g1000.sequence'.format(input_dir, phone_type)
        seq_pair_data = frame.FlightSequencePairData(sequence_path)

        # Side copy of the target metadata that keeps only the time column.
        time_col = seq_pair_data.meta_target.clone()
        time_col.prune_keys(remain_keys=['time'])
        use_all = args['--all']
        if not use_all:
            time_col.prune_identifier(remain_identifier=TEST_LIST)

        # Pick which input keywords survive; the stratux level restricts them.
        level = args['--stratux']
        if level is None:
            kept_inputs = constant.INPUT_KEYWORDS
        else:
            # e.g. a target key 'x_roll' yields the input key 'roll'.
            paired_key = target_key.split('_')[1]
            if level == 0:
                kept_inputs = (paired_key, )
            elif level == 1:
                kept_inputs = ('alt', 'lat', 'long', paired_key)
            elif level == 2:
                kept_inputs = ('alt', 'lat', 'long', 'pitch', 'roll',
                               'heading')
            else:
                raise NotImplementedError
        seq_pair_data.prune_keys(input_remain_keys=kept_inputs,
                                 target_remain_keys=(target_key, ))

        if not use_all:
            seq_pair_data.prune_identifier(remain_identifier=TEST_LIST)

        # Time-domain normalization with the caller-provided parameters.
        seq_pair_data.normalize(set_normal)

        # Cut the sequences into input/target windows.
        src_cfg, dst_cfg = window_config['input'], window_config['target']
        frame_pair_data = frame.FlightFramePairData(
            seq_pair_data,
            input_win_len=src_cfg['length'],
            target_win_len=dst_cfg['length'],
            input_win_offset=src_cfg['offset_length'],
            target_win_offset=dst_cfg['offset_length'],
            input_win_offset_rate=src_cfg['offset_rate'],
            target_win_offset_rate=dst_cfg['offset_rate'],
            input_pad=src_cfg['padding'],
            target_pad=dst_cfg['padding'])

        loader = batch.FramePairDataLoader(frame_pair_data,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           drop_last=False)
        return loader, seq_pair_data.length_dict(), time_col
Ejemplo n.º 3
0
    def to_frame_pair(self):
        """Extract data in the class into a frame pair data class.

        Batches must be stored contiguously per flight and in order, with
        batch identifiers of the form ``'<flight>_<index>'``; both properties
        are checked with assertions. For every flight, its batches are
        concatenated along axis 0 into a single input and target array.

        Returns
        -------
        frame.FlightFramePairData
            Built from the tuple ``(frame_input, frame_target, input_keys,
            target_keys, frame_identifier, appendix['window_dict'])``.

        Raises
        ------
        NotImplementedError
            If ``self.drop_last`` is set (dropped batches cannot be recovered).
        AssertionError
            If batches are out of order or a flight's batches are not
            contiguous.
        """
        if self.drop_last:
            logging.warning('Reconstruct from drop last batches')
            raise NotImplementedError

        # initialize frame data buffer
        frame_input = []
        frame_target = []
        frame_identifier = []

        # traverse all batches to construct flight data
        ptr = 0
        while ptr < len(self):
            # fetch current flight information
            current_flight = self.flight_identifier(self.identifier[ptr])
            num_batches = self.num_batches_dict[current_flight]

            # initialize current flight data buffer
            input_buffer = []
            target_buffer = []

            for i in range(num_batches):
                # batches must be ordered
                assert self.identifier[ptr + i] == '{}_{}'.format(
                    current_flight, i), \
                    'batch {} of flight {} is out of order'.format(
                        i, current_flight)

                # append to flight buffer
                input_buffer.append(self.input_set[ptr + i])
                target_buffer.append(self.target_set[ptr + i])

            # move to next flight
            ptr += num_batches

            # Make sure that all batches of the current flight have been
            # fetched: the next batch (if any) must belong to another flight.
            # Compare flight identifiers exactly -- a substring test on the
            # raw batch id wrongly fails when one flight name is contained in
            # the next one (e.g. 'f1' inside 'f10_0').
            assert (ptr >= len(self)
                    or self.flight_identifier(self.identifier[ptr])
                    != current_flight), \
                'flight {} has more batches than num_batches_dict says'.format(
                    current_flight)

            # append to frame buffer
            frame_input.append(np.concatenate(input_buffer, axis=0))
            frame_target.append(np.concatenate(target_buffer, axis=0))
            frame_identifier.append(current_flight)

        return frame.FlightFramePairData((
            frame_input,
            frame_target,
            self.input_keys,
            self.target_keys,
            frame_identifier,
            self.appendix['window_dict'],
        ))
Ejemplo n.º 4
0
def _flight_slice_bounds(num_flights, select_rate):
    """Turn fractional ``(begin, end, mode)`` rates into integer slice bounds.

    ``'large'`` rounds outward (floor begin, ceil end); ``'small'`` rounds
    inward (ceil begin, floor end); any other mode raises
    NotImplementedError. When both bounds coincide, the range is widened by
    one flight (backwards at the tail, forwards otherwise).
    """
    begin, end = select_rate[0:2]
    mode = select_rate[2]
    if mode == 'large':
        begin = int(math.floor(num_flights * begin))
        end = int(math.ceil(num_flights * end))
    elif mode == 'small':
        begin = int(math.ceil(num_flights * begin))
        end = int(math.floor(num_flights * end))
    else:
        raise NotImplementedError
    # NOTE(review): in 'small' mode a narrow range can give begin > end,
    # which yields an empty selection; that case is left unchanged here.
    if begin == end:
        if end == num_flights:
            begin -= 1
        else:
            end += 1
    return begin, end


def data_batch(input_dir, output_dir, phone_type,
               target_key, rnn=False, batch_size=16, shuffle=False,
               select_rate=(0.0, 1.0, 'large'),
               window_config=None,
               set_normal=None,
               return_len_dict=False, args=None):
    """Generate batch loader

    Args
    ----
    input_dir : str
        root directory to load data
    output_dir : str
        root directory to save data (currently unused by this function)
    phone_type : str
        phone type as input
    target_key : str
        target keyword
    rnn : bool
        build a ``batch.TimestepPairDataLoader`` (with ``batch_size`` used as
        the timestep) instead of a ``batch.FramePairDataLoader``
    batch_size : int
        batch size
    shuffle : bool
        if should shuffle batch loader (ignored when ``rnn`` is True)
    select_rate : tuple
        proportion to select from original data
        large mode will extend select range on both head and tail
        small mode will truncate select range on both head and tail
    window_config : dict or None
        configuration of frame window; None falls back to
        ``constant.WINDOW_CONFIG``
    set_normal : None or frame.BasePairData.Normalization
        normalize with given argument or return new normalization
    return_len_dict : bool
        if return length dict for future conversion back to sequence
    args : dict or None
        global arguments; reads '--limit', '--keyword', '--threshold',
        '--stratux', '--freq' and '--no-normal'. None means no options.

    Returns
    -------
    batch_loader : batch.FramePairDataLoader or batch.TimestepPairDataLoader
        batch loader
    normal : frame.BasePairData.Normalization or None
        normalization parameters; None when '--no-normal' disables
        normalization
    len_dict : dict
        dict of length of each sequence data (only when ``return_len_dict``)

    It will load sequence pair data, and convert into batches.
    It will discard useless keywords, and can truncate flights to generate different data loader.

    """
    # Both None defaults used to crash on the first subscript; fall back to
    # the module-wide window configuration and to an empty option set.
    if window_config is None:
        window_config = constant.WINDOW_CONFIG
    opts = {} if args is None else args

    # load sequence data
    seq_pair_data = frame.FlightSequencePairData(
        '{}/{}_g1000.sequence'.format(input_dir, phone_type))

    limit = opts.get('--limit')
    if limit is not None:
        seq_pair_data.prune_identifier(seq_pair_data.identifier[:limit])

    # extend hazardous state for g1000
    # It must locate after alignment and interpolation
    keyword = opts.get('--keyword')
    if keyword == 'hazard':
        meta_target = seq_pair_data.meta_target
        # '--threshold' is required whenever the hazard keyword is requested.
        meta_target.append_hazard(threshold=opts['--threshold'], roll_key='roll')
        seq_pair_data.update_target(meta_target)

    # remain only necessary keywords
    lv = opts.get('--stratux')
    if lv is None:
        input_remain_keys = constant.INPUT_KEYWORDS
    else:
        logging.info("Stratux Level %s Batch", lv)
        if keyword == 'hazard':
            input_key = 'roll'
        else:
            input_key = target_key.split('_')[1]
        if lv == 0:
            input_remain_keys = (input_key,)
        elif lv == 1:
            input_remain_keys = ('alt', 'lat', 'long', input_key)
        elif lv == 2:
            input_remain_keys = ('alt', 'lat', 'long', 'pitch', 'roll', 'heading')
        else:
            raise NotImplementedError
    seq_pair_data.prune_keys(
        input_remain_keys=input_remain_keys,
        target_remain_keys=(target_key,))

    # remain only selective range of flights
    begin, end = _flight_slice_bounds(len(seq_pair_data), select_rate)
    seq_pair_data.prune_identifier(
        remain_identifier=seq_pair_data.identifier[begin:end])

    freq = opts.get('--freq')
    no_normal = opts.get('--no-normal')

    # normalize sequence data on time domain (frequency mode normalizes later)
    if not freq and not no_normal:
        if set_normal is not None:
            seq_pair_data.normalize(set_normal)
        else:
            seq_pair_data.normalize()
        normal = seq_pair_data.normal
    else:
        normal = None

    # divide sequence data into frames
    frame_pair_data = frame.FlightFramePairData(
        seq_pair_data,
        input_win_len=window_config['input']['length'],
        target_win_len=window_config['target']['length'],
        input_win_offset=window_config['input']['offset_length'],
        target_win_offset=window_config['target']['offset_length'],
        input_win_offset_rate=window_config['input']['offset_rate'],
        target_win_offset_rate=window_config['target']['offset_rate'],
        input_pad=window_config['input']['padding'],
        target_pad=window_config['target']['padding'])

    # transform to frequency domain and normalize
    if freq:
        if freq != 'haar':
            raise NotImplementedError
        frame_pair_data = frame.FlightFramePairData.time_to_haar(
            frame_pair_data, concat=True)
        if not no_normal:
            if set_normal is not None:
                frame_pair_data.normalize(set_normal)
            else:
                frame_pair_data.normalize()
            normal = frame_pair_data.normal
        else:
            normal = None

    # generate batch loader
    if rnn:
        batch_loader = batch.TimestepPairDataLoader(frame_pair_data, timestep=batch_size)
    else:
        batch_loader = batch.FramePairDataLoader(
            frame_pair_data, batch_size=batch_size, shuffle=shuffle, drop_last=False)

    if not return_len_dict:
        return batch_loader, normal
    return batch_loader, normal, seq_pair_data.length_dict()