コード例 #1
0
def test_RandomSampler():
    """Check that RandomSampler visits every index exactly once, out of order."""
    expected = list(range(20))
    sampler = RandomSampler(ArrayDataset(copy.deepcopy(expected)))

    # First pass: the order produced must differ from the sorted order.
    # Each item yielded by the sampler is indexed at position 0.
    first_pass = [item[0] for item in sampler]
    assert expected != first_pass

    # Second pass: once sorted, the produced indices must be the full range.
    second_pass = sorted(item[0] for item in sampler)
    assert expected == second_pass
コード例 #2
0
def get_dataloader(instance_num=102400):
    """Build a DataLoader over random one-hot sequences and their reversals.

    For every instance a sequence of ``MAXLEN`` random character ids is drawn.
    The input is the one-hot encoding of that sequence, the label is the same
    ids in reverse order, and the third field is the module-level
    ``pos_to_query`` object.

    Args:
        instance_num: number of instances to generate.

    Returns:
        A ``DataLoader`` over the generated ``ArrayDataset``, sampled with a
        ``RandomSampler`` that uses the module-level ``batch_size``.
    """
    inputs, targets, queries = [], [], []
    for _ in range(instance_num):
        vocab_size = len(CHARSET)
        seq_len = MAXLEN
        one_hot = np.zeros((seq_len, vocab_size), dtype='int32')
        reversed_ids = np.zeros((seq_len, ), dtype='int32')

        for step in range(seq_len):
            # Per the original note, this range avoids generating '@' and '-'.
            char_id = np.random.randint(1, vocab_size - 1)
            one_hot[step][char_id] = 1
            # Mirror the position so the label is the reversed sequence.
            reversed_ids[seq_len - 1 - step] = char_id

        inputs.append(one_hot)
        targets.append(reversed_ids)
        queries.append(pos_to_query)

    reverse_dataset = ArrayDataset(inputs, targets, queries)
    sampler = RandomSampler(reverse_dataset, batch_size)
    return DataLoader(reverse_dataset, sampler)
コード例 #3
0
def get_dataloader():
    """Build a DataLoader over random one-hot sequences, reversals and masks.

    Each instance has ``MAXLEN + 1`` slots: ``cur_len`` random character ids
    followed by a terminator at index ``cur_len`` that uses the last CHARSET
    id. The label holds the ids in reverse order followed by the same
    terminator id, and the mask is 1 for the first ``cur_len + 1`` slots.

    Returns:
        A ``DataLoader`` over the generated ``ArrayDataset``, sampled with a
        ``RandomSampler`` that uses the module-level ``batch_size``.
    """
    instance_num = 102400
    inputs, targets, valid_masks = [], [], []
    for _ in range(instance_num):
        # NOTE(review): a random length is drawn and then immediately replaced
        # by MAXLEN, so every sequence is full length. The draw is kept so the
        # random stream is unchanged -- confirm which behaviour is intended.
        cur_len = np.random.randint(MINLEN, MAXLEN + 1)
        cur_len = MAXLEN

        one_hot = np.zeros((MAXLEN + 1, len(CHARSET)), dtype='int32')
        reversed_ids = np.zeros((MAXLEN + 1, ), dtype='int32')
        mask = np.zeros((MAXLEN + 1, ), dtype='int32')

        # Terminator slot right after the sequence, using the last CHARSET id.
        end_id = len(CHARSET) - 1
        one_hot[cur_len][end_id] = 1
        reversed_ids[cur_len] = end_id
        mask[:cur_len + 1] = 1

        for step in range(cur_len):
            # Per the original note, this range avoids generating '@' and '-'.
            char_id = np.random.randint(1, len(CHARSET) - 1)
            one_hot[step][char_id] = 1
            reversed_ids[cur_len - 1 - step] = char_id

        inputs.append(one_hot)
        targets.append(reversed_ids)
        valid_masks.append(mask)

    reverse_dataset = ArrayDataset(inputs, targets, valid_masks)
    sampler = RandomSampler(reverse_dataset, batch_size)
    return DataLoader(reverse_dataset, sampler)
コード例 #4
0
def test_sampler_drop_last_true():
    """Check that len(sampler) matches the number of items it yields with drop_last=True."""
    dataset = ArrayDataset(list(range(24)))
    sampler = SequentialSampler(dataset, batch_size=5, drop_last=True)

    # Count what iteration actually produces, then compare with the reported length.
    produced = sum(1 for _ in sampler)
    assert produced == len(sampler)
コード例 #5
0
def test_ReplacementSampler():
    """Check that ReplacementSampler yields exactly ``num_samples`` items."""
    expected_count = 30
    population = list(range(20))
    sampler = ReplacementSampler(
        ArrayDataset(population),
        num_samples=expected_count,
        weights=list(range(20)),
    )

    drawn = [item[0] for item in sampler]
    assert len(drawn) == expected_count
コード例 #6
0
def test_random_sampler_seed():
    """Check that RandomSampler output is determined by its seed.

    Two samplers built with the same seed must produce the same order, and a
    sampler built with a different seed must produce a different one.
    """

    def first_items(sampler):
        # One full pass over the sampler, keeping element 0 of each yielded item.
        return [item[0] for item in sampler]

    def make_sampler(seed_value):
        return RandomSampler(ArrayDataset(copy.deepcopy(reference)), seed=seed_value)

    reference = list(range(20))
    seeds = [0, 1]
    same_seed_a = make_sampler(seeds[0])
    same_seed_b = make_sampler(seeds[0])
    other_seed = make_sampler(seeds[1])
    samplers = (same_seed_a, same_seed_b, other_seed)

    # NOTE(review): every sampler is iterated the same number of times before
    # the pairwise comparisons below. Presumably a seeded sampler advances its
    # own random state on each pass, so keep the pass counts aligned -- verify.
    for sampler in samplers:
        assert reference != first_items(sampler)
    for sampler in samplers:
        assert reference == sorted(first_items(sampler))

    assert first_items(same_seed_a) == first_items(same_seed_b)
    assert first_items(same_seed_a) != first_items(other_seed)
コード例 #7
0
def init_dataset():
    """Return an ArrayDataset of 100 random uint8 arrays with random integer labels.

    Data has shape ``(100, 1, 32, 32)`` with values drawn from ``[0, 255)``;
    labels have shape ``(100,)`` with values drawn from ``[0, 10)``.
    """
    sample_num = 100
    data_shape = (sample_num, 1, 32, 32)
    rand_data = np.random.randint(0, 255, size=data_shape, dtype=np.uint8)
    label = np.random.randint(0, 10, size=(sample_num, ), dtype=int)
    return ArrayDataset(rand_data, label)
コード例 #8
0
def test_array_dataset():
    """Check per-item shapes and the length reported by ArrayDataset."""
    num_items = 10
    data_shape = (3, 256, 256)
    label_shape = (1, )

    # The leading axis indexes the items; the remaining axes are the per-item shape.
    data = np.random.randint(0, 255, (num_items, ) + data_shape)
    label = np.random.randint(0, 9, (num_items, ) + label_shape)
    dataset = ArrayDataset(data, label)

    first_data = dataset[0][0]
    assert first_data.shape == data_shape
    first_label = dataset[0][1]
    assert first_label.shape == label_shape
    assert len(dataset) == num_items
コード例 #9
0
 def get_dataloader(self, examples, batch_size, is_random=False):
     """Convert examples to features and wrap them in a DataLoader.

     Args:
         examples: examples passed to ``convert_examples_to_features``.
         batch_size: batch size handed to the sampler.
         is_random: use ``RandomSampler`` when true, otherwise
             ``SequentialSampler``.

     Returns:
         A ``(dataloader, feature_count)`` tuple, where ``feature_count`` is
         ``len(features)``.
     """
     features = convert_examples_to_features(
         examples, self.label_list, self.args.max_seq_length, self.tokenizer
     )
     input_ids, input_mask, segment_ids, label_ids = self.to_inputs(features)
     dataset = ArrayDataset(input_ids, input_mask, segment_ids, label_ids)

     # Both samplers take the same arguments; only the class differs.
     sampler_cls = RandomSampler if is_random else SequentialSampler
     sampler = sampler_cls(dataset=dataset, batch_size=batch_size, drop_last=True)

     return DataLoader(dataset=dataset, sampler=sampler), len(features)
コード例 #10
0
def test_sequential_sampler():
    """Check that SequentialSampler yields the dataset indices in order."""
    expected = list(range(100))
    sampler = SequentialSampler(ArrayDataset(expected))
    observed = [item[0] for item in sampler]
    assert expected == observed
コード例 #11
0
def test_array_dataset_dim_error():
    """Check that ArrayDataset raises ValueError for the arrays built here.

    The data array has a leading dimension of 10 and the label array a leading
    dimension of 1; presumably that mismatch is what triggers the error.
    """
    images = np.random.randint(0, 255, (10, 3, 256, 256))
    mismatched_label = np.random.randint(0, 9, (1, ))
    with pytest.raises(ValueError):
        ArrayDataset(images, mismatched_label)