from glob import glob

import numpy as np

def get_path_dict(data_dirs,
                  hparams,
                  config,
                  data_type,
                  n_test=None,
                  rng=np.random.RandomState(123)):  # default RNG is built once, at definition time

    # Load metadata:
    path_dict = {}
    for data_dir in data_dirs:  # ['datasets/moon\\data']
        paths = glob(
            "{}/*.npz".format(data_dir)
        )  # ['datasets/moon\\data\\001.0000.npz', 'datasets/moon\\data\\001.0001.npz', 'datasets/moon\\data\\001.0002.npz', ...]

        if data_type == 'train':
            rng.shuffle(
                paths
            )  # ['datasets/moon\\data\\012.0287.npz', 'datasets/moon\\data\\004.0215.npz', 'datasets/moon\\data\\003.0149.npz', ...]

        if not config.skip_path_filter:
            # items = parallel_run( get_frame, paths, desc="filter_by_min_max_frame_batch", parallel=True)  # [('datasets/moon\\data\\012.0287.npz', 130, 21), ('datasets/moon\\data\\003.0149.npz', 209, 37), ...]
            items = []
            for path in paths:
                item = get_frame(path)
                items.append(item)

            min_n_frame = hparams.min_n_frame  # 5*30
            max_n_frame = hparams.max_n_frame - 1  # 5*200 - 5

            # A lot of data drops out at this step: clips whose text is too short are eliminated.
            new_items = [
                (path, n) for path, n, n_tokens in items
                if min_n_frame <= n <= max_n_frame
                and n_tokens >= hparams.min_tokens
            ]  # [('datasets/moon\\data\\004.0383.npz', 297), ('datasets/moon\\data\\003.0533.npz', 394),...]

            new_paths = [path for path, n in new_items]
            new_n_frames = [n for path, n in new_items]

            hours = frames_to_hours(new_n_frames, hparams)

            log(' [{}] Loaded metadata for {} examples ({:.2f} hours)'.format(
                data_dir, len(new_n_frames), hours))
            log(' [{}] Max length: {}'.format(data_dir, max(new_n_frames)))
            log(' [{}] Min length: {}'.format(data_dir, min(new_n_frames)))
        else:
            new_paths = paths

        # Split the paths into train data and test data.
        # n_test must be provided here (e.g. batch_size); slicing with None would fail.
        if data_type == 'train':
            new_paths = new_paths[:-n_test]  # everything except the last n_test (batch_size) paths
        elif data_type == 'test':
            new_paths = new_paths[-n_test:]  # only the last n_test paths
        else:
            raise Exception(" [!] Unknown data_type: {}".format(data_type))

        path_dict[data_dir] = new_paths  # ['datasets/moon\\data\\001.0621.npz', 'datasets/moon\\data\\003.0229.npz', ...]

    return path_dict
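A minimal usage sketch, not part of the listing above: the stubbed hparams/config objects, the dataset directory, and n_test=32 are all illustrative. skip_path_filter is set to True so the call depends only on glob and not on the project's metadata helpers (get_frame, frames_to_hours, log).

from types import SimpleNamespace

# Hypothetical stand-ins for the project's real hparams/config objects;
# the frame/token limits are unused here because the filter is skipped.
hparams = SimpleNamespace(min_n_frame=5 * 30, max_n_frame=5 * 200, min_tokens=30)
config = SimpleNamespace(skip_path_filter=True)

# data_type selects which end of the path list becomes the split:
# 'train' drops the last n_test paths, 'test' keeps only them.
train_dict = get_path_dict(['datasets/moon/data'], hparams, config, data_type='train', n_test=32)
test_dict = get_path_dict(['datasets/moon/data'], hparams, config, data_type='test', n_test=32)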
Example #2
from glob import glob

import numpy as np

def get_path_dict(data_dirs, hparams, config, data_type, n_test=None, rng=np.random.RandomState(123)):

    # Load metadata:
    path_dict = {}
    for data_dir in data_dirs:  # ['datasets/moon\\data']
        paths = glob("{}/*.npz".format(data_dir)) # ['datasets/moon\\data\\001.0000.npz', 'datasets/moon\\data\\001.0001.npz', 'datasets/moon\\data\\001.0002.npz', ...]

        if data_type == 'train':
            rng.shuffle(paths)  # ['datasets/moon\\data\\012.0287.npz', 'datasets/moon\\data\\004.0215.npz', 'datasets/moon\\data\\003.0149.npz', ...]

        if not config.skip_path_filter:
            items = parallel_run( get_frame, paths, desc="filter_by_min_max_frame_batch", parallel=True)  # [('datasets/moon\\data\\012.0287.npz', 130, 21), ('datasets/moon\\data\\003.0149.npz', 209, 37), ...]

            min_n_frame = hparams.reduction_factor * hparams.min_iters   # 5*30
            max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor  # 5*200 - 5
            
            # A lot of data drops out at this step: clips whose text is too short are eliminated.
            new_items = [(path, n) for path, n, n_tokens in items if min_n_frame <= n <= max_n_frame and n_tokens >= hparams.min_tokens] # [('datasets/moon\\data\\004.0383.npz', 297), ('datasets/moon\\data\\003.0533.npz', 394),...]

            if any(check in data_dir for check in ["son", "yuinna"]):
                blacklists = [".0000.", ".0001.", "NB11479580.0001"]
                # Keep only items whose path matches none of the blacklist patterns.
                new_items = [item for item in new_items if all(check not in item[0] for check in blacklists)]

            new_paths = [path for path, n in new_items]
            new_n_frames = [n for path, n in new_items]

            hours = frames_to_hours(new_n_frames, hparams)

            log(' [{}] Loaded metadata for {} examples ({:.2f} hours)'.format(data_dir, len(new_n_frames), hours))
            log(' [{}] Max length: {}'.format(data_dir, max(new_n_frames)))
            log(' [{}] Min length: {}'.format(data_dir, min(new_n_frames)))
        else:
            new_paths = paths

        if data_type == 'train':
            new_paths = new_paths[:-n_test]
        elif data_type == 'test':
            new_paths = new_paths[-n_test:]
        else:
            raise Exception(" [!] Unknown data_type: {}".format(data_type))

        path_dict[data_dir] = new_paths  # ['datasets/moon\\data\\001.0621.npz', 'datasets/moon\\data\\003.0229.npz', ...]

    return path_dict
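get_frame, frames_to_hours, parallel_run, and log are project-local helpers that these listings only call. For orientation, minimal sketches of the first two under stated assumptions: each .npz is assumed to store a spectrogram under a "linear" key and the encoded text under a "tokens" key, and hparams is assumed to carry the hop size as frame_shift_ms; none of these names are confirmed by the listings.

import numpy as np

def get_frame(path):
    # Sketch: read one preprocessed .npz example and report how long it is,
    # in spectrogram frames and in text tokens. The keys 'linear' and
    # 'tokens' are assumptions about the preprocessing format.
    data = np.load(path)
    n_frame = data["linear"].shape[0]  # frames along the time axis
    n_token = len(data["tokens"])      # length of the encoded text
    return (path, n_frame, n_token)

def frames_to_hours(n_frames, hparams):
    # Sketch: convert a list of frame counts to total audio hours, assuming
    # a hypothetical hparams.frame_shift_ms hop size (e.g. 12.5 ms).
    total_ms = sum(n_frames) * hparams.frame_shift_ms
    return total_ms / (1000 * 3600)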