Example #1
    def collate_fn(self, items):
        batch = []
        # Each dataset item is an iterable of (uttid, data) pairs; flatten them.
        items = itertools.chain.from_iterable(items)
        for uttid, data in items:
            # Merge auxiliary per-utterance info, dropping its redundant length field.
            aux_info = self.aux_utt_info.get(uttid, {})
            aux_info.pop("length", None)
            data.update(aux_info)
            data["x"] = torch.from_numpy(data["x"]).float()
            if self.tokenizer is not None:
                data["labels"] = torch.tensor(
                    self.tokenizer.text2ids(data["text"]))
            data.pop("rate", None)
            data["uttid"] = uttid
            batch.append(data)
        return batch
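
A minimal usage sketch, assuming dataset is an instance of the class that defines this collate_fn and that its __getitem__ returns a list of (uttid, data) pairs; the names and batch size here are illustrative, not from the source:

from torch.utils.data import DataLoader

# Hypothetical wiring: hand the dataset's own collate_fn to a PyTorch DataLoader,
# so each yielded batch is the list of per-utterance dicts built above.
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
for batch in loader:
    print(batch[0]["uttid"], batch[0]["x"].shape)
    break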
Example #2
    def load_json(json_path,
                  x='feat_length',
                  y='token_length',
                  x_range=(1, 9999),
                  y_range=(1, 999),
                  rate=(1, 99)):
        # Load samples from a single json file, or from every *.json file
        # found under a directory.
        try:
            # json_path is a single file
            with open(json_path) as f:
                data = json.load(f)
        except OSError:
            # json_path is a directory (or otherwise not a readable file):
            # merge every *.json found beneath it.
            data = []
            for root, _, files in os.walk(json_path):  # walk all sub-directories
                for name in files:
                    if name.endswith('.json'):
                        filename = os.path.join(root, name)
                        print('loading json file :', filename)
                        with open(filename) as f:
                            add = json.load(f)
                            data.extend(add)
                        print('loaded {} samples'.format(len(add)))

        # Collect indices of samples whose lengths or length ratio fall
        # outside the allowed ranges.
        list_to_pop = []
        for i, sample in enumerate(data):
            len_x = sample[x]
            len_y = sample[y]
            if not (x_range[0] <= len_x <= x_range[1]) or \
               not (y_range[0] <= len_y <= y_range[1]) or \
               not (rate[0] <= (len_x / len_y) <= rate[1]):
                list_to_pop.append(i)

        # filter: pop from the back so earlier indices stay valid
        print('filtered {}/{} samples\n'.format(len(list_to_pop), len(data)))
        for i in reversed(list_to_pop):
            data.pop(i)

        return data
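
A hedged usage sketch for load_json, assuming each JSON file holds a list of sample dicts with 'feat_length' and 'token_length' keys (as the defaults imply); the directory path and range values are illustrative, and in the source the function appears to sit inside a class, so it may need to be called as a static method:

samples = load_json('data/train_jsons',    # hypothetical directory of *.json shards
                    x='feat_length',
                    y='token_length',
                    x_range=(1, 3000),     # drop over-long feature sequences
                    y_range=(1, 200),
                    rate=(4, 99))          # keep frames-per-token ratio in [4, 99]
print(len(samples), 'samples kept')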
Example #3
    def __init__(self,
                 json_path,
                 reverse=False,
                 feat_range=(1, 99999),
                 label_range=(1, 100),
                 rate_in_out=(4, 999)):
        try:
            # json_path is a single file
            with open(json_path) as f:
                data = json.load(f)
        except OSError:
            # json_path is a directory (or otherwise not a readable file):
            # merge every *.json found beneath it.
            data = []
            for root, _, files in os.walk(json_path):  # walk all sub-directories
                for name in files:
                    if name.endswith('.json'):
                        filename = os.path.join(root, name)
                        print('loading json file :', filename)
                        with open(filename) as f:
                            add = json.load(f)
                            data.extend(add)
                        print('loaded {} samples'.format(len(add)))

        # filter out samples whose lengths or input/output ratio fall
        # outside the allowed ranges
        list_to_pop = []
        for i, sample in enumerate(data):
            len_x = sample['feat_length']
            len_y = sample['token_length']
            if not (feat_range[0] <= len_x <= feat_range[1]) or \
               not (label_range[0] <= len_y <= label_range[1]) or \
               not (rate_in_out[0] <= (len_x / len_y) <= rate_in_out[1]):
                list_to_pop.append(i)
        print('filtered {}/{} samples\n'.format(len(list_to_pop), len(data)))
        for i in reversed(list_to_pop):
            data.pop(i)

        # sort by feature length (short to long), or long to short if reverse
        self.data = sorted(data, key=lambda s: float(s["feat_length"]))
        if reverse:
            self.data.reverse()
Example #4
    def load_flist(data_file, x='feat_length', x_range=(1, 9999)):
        # Each line of the flist is "<path> <duration>".
        data = []
        with open(data_file) as f:
            for i, line in enumerate(f):
                f_path, duration = line.strip().split()
                sample = {
                    'uttid': i,
                    'path': f_path,
                    'feat_length': int(duration)
                }
                data.append(sample)

        # Collect indices of samples whose length falls outside x_range.
        list_to_pop = []
        for i, sample in enumerate(data):
            len_x = sample[x]
            if not (x_range[0] <= len_x <= x_range[1]):
                list_to_pop.append(i)

        # filter: pop from the back so earlier indices stay valid
        print('filtered {}/{} samples\n'.format(len(list_to_pop), len(data)))
        for i in reversed(list_to_pop):
            data.pop(i)

        return data
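
The flist format load_flist expects can be read off the parser above: one whitespace-separated "<path> <duration>" pair per line. An illustrative file and call (paths and values are made up):

# wav.flist (hypothetical contents):
#   /data/audio/utt0001.ark 1523
#   /data/audio/utt0002.ark 8734
samples = load_flist('wav.flist', x='feat_length', x_range=(100, 5000))
print(samples[0])   # {'uttid': 0, 'path': '/data/audio/utt0001.ark', 'feat_length': 1523}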
Example #5
else:
    exit(1)

# load data
data = DL.DataLoader(m)
print(data.keys())
# split into training (from index 76 on) / validation (first 76) portions
if mode == 'train':
    for k in data.keys():
        data[k] = data[k][76:]
elif mode == 'valid':
    for k in data.keys():
        data[k] = data[k][:76]

# get time data
output = data.pop('output')
input = data.pop('input')
swa = data.pop('SWA')
print(np.array(swa).shape)

datasize = len(swa)
seq_len = len(swa[0])
# input_test = test.pop('input')

# delete trailing zeros after the simulation finishes #######
# search end of simulation
# ends = []
# for i, al in enumerate(data['SWA']):
#     for j, dat in enumerate(reversed(al)):
#         if dat != 0:
#             if j == 0:
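
The commented-out search for the end of each simulation could be vectorized with numpy; a hedged sketch, assuming trailing zeros in each SWA sequence are padding after the run finishes (this mirrors, but is not taken from, the original loop):

swa_arr = np.array(swa)                                  # shape: (datasize, seq_len)
nonzero = swa_arr != 0
last_from_end = np.argmax(nonzero[:, ::-1], axis=1)      # offset of last non-zero, counted from the end
has_signal = nonzero.any(axis=1)
ends = np.where(has_signal, seq_len - last_from_end, 0)  # index just past the last non-zero; 0 for all-zero rows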
Example #6
def tree_reduce(data):
    # Pairwise-merge the list of partial results with Ray remote tasks
    # until a single merged result remains.
    while len(data) > 1:
        a = data.pop(0)
        b = data.pop(0)
        data.append(merge_stats.remote(a, b))
    return ray.get(data)[0]
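
A self-contained sketch of driving tree_reduce, assuming merge_stats is a Ray remote task; the merge_stats body and the (count, total) statistics are illustrative, not from the source:

import ray

ray.init()

@ray.remote
def merge_stats(a, b):
    # Illustrative merge of per-partition (count, total) statistics.
    return (a[0] + b[0], a[1] + b[1])

partitions = [(10, 4.2), (7, 1.5), (12, 9.0), (3, 0.7)]
print(tree_reduce(list(partitions)))   # merged (count, total) over all partitions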
Example #7
    def __getitem__(self, item):
        data = super().__getitem__(item)
        # Expose the stored 'actions' field as observations.
        data.observations = data.pop('actions')
        return data
Example #8
    def __init__(self,
                 data_json_path,
                 batch_size,
                 max_length_in,
                 max_length_out,
                 num_batches=0,
                 batch_frames=0):
        # From: espnet/src/asr/asr_utils.py: make_batchset()
        """
        Args:
            data_json_path: espnet-format json file (or a directory of them).
            num_batches: for debug. only use num_batches minibatches, not all.
        """
        super().__init__()
        try:
            with open(data_json_path, 'rb') as f:
                data = json.load(f)['utts']
        except OSError:
            # data_json_path is not a readable file: merge every *.json
            # found under its directory instead.
            data = {}
            for root, _, files in os.walk(
                    os.path.dirname(data_json_path)):  # walk all sub-directories
                for name in files:
                    if name.endswith('.json'):
                        filename = os.path.join(root, name)
                        with open(filename, 'rb') as f:
                            data.update(json.load(f)['utts'])

        # Drop utterances whose input/output length ratio is below 5.
        list_to_pop = []
        for key, sample in data.items():
            len_x = int(sample['input'][0]['shape'][0])
            len_y = int(sample['output'][0]['shape'][0])
            if len_x / len_y < 5.0:
                list_to_pop.append(key)
        for key in list_to_pop:
            data.pop(key)

        # sort it by input lengths (long to short)
        sorted_data = sorted(
            data.items(),
            key=lambda item: int(item[1]['input'][0]['shape'][0]),
            reverse=True)
        # change batchsize depending on the input and output length
        minibatch = []
        # Method 1: Generate minibatches based on batch_size,
        # i.e. each batch contains #batch_size utterances.
        if batch_frames == 0:
            start = 0
            while True:
                ilen = int(sorted_data[start][1]['input'][0]['shape'][0])
                olen = int(sorted_data[start][1]['output'][0]['shape'][0])
                factor = max(int(ilen / max_length_in),
                             int(olen / max_length_out))
                # if ilen = 1000 and max_length_in = 800
                # then b = batchsize / 2
                # and max(1, .) avoids batchsize = 0
                b = max(1, int(batch_size / (1 + factor)))
                end = min(len(sorted_data), start + b)
                minibatch.append(sorted_data[start:end])
                # DEBUG
                # total = 0
                # for i in range(start, end):
                #     total += int(sorted_data[i][1]['input'][0]['shape'][0])
                # print(total, end - start)
                if end == len(sorted_data):
                    break
                start = end
        # Method 2: Generate minibatches based on batch_frames,
        # i.e. each batch contains approximately #batch_frames frames.
        else:  # batch_frames > 0
            print("NOTE: Generate minibatch based on batch_frames.")
            print(
                "i.e. each batch contains approximately #batch_frames frames")
            start = 0
            while True:
                total_frames = 0
                end = start
                while total_frames < batch_frames and end < len(sorted_data):
                    ilen = int(sorted_data[end][1]['input'][0]['shape'][0])
                    total_frames += ilen
                    end += 1
                # print(total_frames, end - start)
                minibatch.append(sorted_data[start:end])
                if end == len(sorted_data):
                    break
                start = end
        if num_batches > 0:
            minibatch = minibatch[:num_batches]
        self.minibatch = minibatch
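
A minimal consumption sketch, assuming the class above is named AudioDataset (the real name is not shown) and that an espnet-style data.json sits at the illustrative path; each stored minibatch is a list of (uttid, info) pairs already grouped by the logic above:

dataset = AudioDataset('dump/train/data.json',   # hypothetical class name and path
                       batch_size=32,
                       max_length_in=800,
                       max_length_out=150)
print(len(dataset.minibatch), 'minibatches')
first = dataset.minibatch[0]                     # longest utterances come first
print(first[0][0], len(first))                   # uttid of the longest utterance, batch size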
Example #9
    accu_test = mytest("target")
    print('============ Test ============= \n')
    print('Accuracy of the %s dataset: %f\n' % ('test', accu_test))
    return accu_test


acc = 0
for i in range(5):
    # Leave-one-subject-out split: subject i becomes the target domain,
    # every remaining subject goes into the source domain.
    with open('dataset/SEED/data.pkl', 'rb') as f:
        data = pickle.load(f)
    f1 = open('dataset/SEED/source.pkl', 'wb')
    f2 = open('dataset/SEED/target.pkl', 'wb')
    source = {}
    target = {}
    target["sub_" + str(i)] = data["sub_" + str(i)]
    data.pop("sub_" + str(i))
    flag = False
    for item in list(data.keys()):
        source[item] = data[item]
    pickle.dump(source, f1)
    pickle.dump(target, f2)
    f1.close()
    f2.close()

    dataset_source = GetLoader(
        data_root=os.path.join('dataset', 'SEED'),
        data_list='source.pkl',
    )

    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,