def test_generator_3():
    """
    Test 1D Generator + repeat(4)
    """
    logger.info("Test 1D Generator : 0 - 63 + Repeat(4)")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    data1 = data1.repeat(4)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
        if i == 64:
            i = 0
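These generator tests reference a generator_1d helper defined elsewhere in the test module. A minimal sketch consistent with the 0-63 values asserted above (name kept, body assumed) is:

import numpy as np

def generator_1d():
    # assumed helper: yields 64 rows, each a one-element array holding 0..63
    for i in range(64):
        yield (np.array([i]),)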
def test_generator_15():
    """
    Test 1D Generator MP + Python sampler
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    sampler = [x for x in range(256)]
    source = [(np.array([x]), ) for x in range(256)]
    ds1 = ds.GeneratorDataset(source, ["data"],
                              sampler=sampler,
                              num_parallel_workers=4).repeat(2)
    i = 0
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(data["data"], golden)
        i = i + 1
        if i == 256:
            i = 0
Example #3
def test_config(lookup_str, data_type=None):
    try:
        vocab = text.Vocab.from_list(["w1", "w2", "w3"],
                                     special_tokens=["<unk>"],
                                     special_first=True)
        data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"])
        # if data_type is None, test the default value of data_type
        op = text.Lookup(vocab, "<unk>") if data_type is None \
            else text.Lookup(vocab, "<unk>", data_type)
        data = data.map(operations=op, input_columns=["text"])
        res = []
        for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(d["text"])
        return res[0].dtype
    except (ValueError, RuntimeError, TypeError) as e:
        return str(e)
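The gen helper used by this Lookup test is not shown in the snippet. A plausible minimal version, assuming it yields each whitespace-separated token as a bytes array, is:

import numpy as np

def gen(texts):
    # assumed helper: one token per row, stored as bytes so text.Lookup can map it
    for word in texts.split(" "):
        yield (np.array(word, dtype='S'),)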
Example #4
def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert data["input"][0] == count
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)
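The sync_wait examples here and below rely on a gen generator and an Augment callback class defined elsewhere in the test module. A minimal sketch compatible with these tests (names kept, bodies assumed) might be:

import numpy as np

def gen():
    # assumed source: one-element arrays 0, 1, 2, ...
    for i in range(100):
        yield (np.array([i]),)

class Augment:
    # assumed callback object: sync_wait calls update() with the data passed to sync_update
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]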
Example #5
def test_case_16():
    """
    Test multi column generator MP + CPP sampler
    """
    logger.info("Test multi column generator")

    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"],
                                sampler=ds.SequentialSampler())

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["col0"], golden)
        golden = np.array([i + 1])
        assert np.array_equal(item["col1"], golden)
        i = i + 1
Example #6
def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=8):
    """Creatr YOLOv3 dataset with GeneratorDataset."""
    yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
    distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
    ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
    ds.set_dataset_size(len(distributed_sampler))
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
    hwc_to_chw = P.HWC2CHW()
    ds = ds.map(input_columns=["image", "annotation"],
                output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                operations=compose_map_func, num_parallel_workers=num_parallel_workers)
    ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
    ds = ds.shuffle(buffer_size=256)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds
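A hypothetical call site for the function above (paths, batch size, and rank are placeholders) could look like:

# assumed usage; the COCO-style paths are placeholders
train_ds = create_yolo_dataset("/data/coco/images", "/data/coco/annotation.json",
                               batch_size=32, device_num=1, rank=0, is_training=True)
print("batches per epoch:", train_ds.get_dataset_size())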
Example #7
    def test_config(input_line,
                    output_line,
                    n,
                    l_pad=("", 0),
                    r_pad=("", 0),
                    sep=" "):
        def gen(texts):
            yield (np.array(texts.split(" "), dtype='S'), )

        dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
        dataset = dataset.map(input_columns=["text"],
                              operations=text.Ngram(n,
                                                    l_pad,
                                                    r_pad,
                                                    separator=sep))
        for data in dataset.create_dict_iterator():
            assert [d.decode("utf8")
                    for d in data["text"]] == output_line, output_line
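As a usage illustration, a 2-gram joins adjacent tokens with the separator, so calls such as the following (hypothetical inputs) would be expected to pass:

# assumed invocations of the helper above
test_config("WildRose Country", ["WildRose Country"], 2)   # one bigram from two tokens
test_config("Canada", ["Canada"], 1)                       # unigram passes through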
Example #8
def test_generator_tuple_1():
    """
    test generator tuple 1
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        i = 0
        # BAD: do not create a new iterator inside the epoch loop;
        # create the iterator once, outside the loop, and reuse it.
        for item in data1.create_tuple_iterator():  # each data is a tuple
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64
Example #9
def test_filter_by_generator_with_map_part_col():
    dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
    dataset_map = dataset.map(operations=func_map_part,
                              input_columns=["col1"],
                              output_columns=["out1"])

    dataset_f = dataset_map.filter(input_columns=["out1", "col2"],
                                   predicate=filter_func_map,
                                   num_parallel_workers=4)
    num_iter = 0
    ret_data = []
    for item in dataset_f.create_dict_iterator(num_epochs=1,
                                               output_numpy=True):
        num_iter += 1
        ret_data.append(item["out1"])
    assert num_iter == 3
    assert ret_data[0] == 9
    assert ret_data[2] == 11
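This filter test depends on three helpers defined elsewhere in the test module. One sketch that satisfies the assertions above (three surviving rows 9, 10, 11); the exact bodies are assumptions:

import numpy as np

def generator_mc(maxid=20):
    # assumed two-column source: both columns carry the row index
    for i in range(maxid):
        yield (np.array([i]), np.array([i]))

def func_map_part(col1):
    # assumed map: pass col1 through unchanged as "out1"
    return col1

def filter_func_map(out1, col2):
    # assumed predicate: keep rows whose value is greater than 8
    return out1[0] > 8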
def test_generator_17():
    """
    Test multi column generator MP + Python sampler
    """
    logger.info("Test multi column generator")

    sampler = [x for x in range(256)]
    source = [(np.array([x]), np.array([x + 1])) for x in range(256)]
    # apply dataset operations
    data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["col0"], golden)
        golden = np.array([i + 1])
        np.testing.assert_array_equal(item["col1"], golden)
        i = i + 1
Example #11
def test_sync_exception_03():
    """
    Test sync: with wrong batch size
    """
    logger.info("test_sync_exception_03")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # try to create a sync_wait node with negative num_batch
    try:
        dataset = dataset.sync_wait(condition_name="every batch",
                                    num_batch=-1,
                                    callback=aug.update)
    except Exception as e:
        assert "num_batch" in str(e)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
def test_graphdata_distributed():
    """
    Test distributed
    """
    logger.info('test distributed.\n')

    server_port = random.randint(10000, 60000)

    p1 = Process(target=graphdata_startserver, args=(server_port,))
    p1.start()
    time.sleep(5)

    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                                      [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                                      [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    edges = g.get_all_edges(0)
    assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
                                    0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]

    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)
    itr = dataset.create_dict_iterator()
    i = 0
    for data in itr:
        assert data['neighbors'].shape == (2, 7)
        assert data['neg_neighbors'].shape == (6, 7)
        assert data['neighbors_features'].shape == (2, 7)
        assert data['neg_neighbors_features'].shape == (6, 7)
        i += 1
    assert i == 40
Example #13
def manual_test_keyborad_interrupt():
    """
    Test keyboard interrupt handling (manual test)
    """
    logger.info("Test 1D Generator MP : 0 - 63")

    class MyDS():
        def __getitem__(self, item):
            while True:
                pass

        def __len__(self):
            return 1024

    ds1 = ds.GeneratorDataset(MyDS(), ["data"],
                              num_parallel_workers=4).repeat(2)
    i = 0
    for data in ds1.create_dict_iterator():  # each data is a dictionary
        pass
Example #14
def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")
    batch_size = 6

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)

    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
    except Exception as e:
        assert "name" in str(e)
    dataset = dataset.batch(batch_size)
Example #15
def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    batch_size = 10

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])

    try:
        dataset = dataset.shuffle(shuffle_size)
    except Exception as e:
        assert "shuffle" in str(e)
    dataset = dataset.batch(batch_size)
Example #16
def create_icdar_train_dataset(img_path, gt_path, batch_size=32, repeat_num=10,
                               is_training=True, num_parallel_workers=1, length=512, scale=0.25):
    dataloader = ds.GeneratorDataset(source=custom_dataset(img_path, gt_path, scale=scale, length=length),
                                     column_names=["img", "score_map", "geo_map", "ignored_map"],
                                     shuffle=is_training,
                                     num_parallel_workers=num_parallel_workers)
    dataloader.set_dataset_size(1000)
    transform = py_transforms.ComposeOp([py_transforms.RandomColorAdjust(brightness=0.5, contrast=0.5,
                                                                         saturation=0.5, hue=0.25),
                                         py_transforms.ToTensor(),
                                         py_transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    dataloader = dataloader.map(input_columns="img", operations=transform,
                                num_parallel_workers=num_parallel_workers,
                                python_multiprocessing=is_training)
    dataloader = dataloader.batch(batch_size, drop_remainder=True)

    return dataloader
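A hypothetical call site for create_icdar_train_dataset (the ICDAR paths are placeholders):

# assumed usage; dataset paths are placeholders
train_ds = create_icdar_train_dataset("/data/icdar2015/train_img", "/data/icdar2015/train_gt",
                                      batch_size=8, is_training=True)
print("batches per epoch:", train_ds.get_dataset_size())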
Example #17
def test_case_09(add_remove_file):

    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_2d_0, ["data"], shuffle=False)

    d1.save(AUTO_FILE)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        if i < 5:
            golden = np.arange(5).reshape([1, 5])
        else:
            golden = np.arange(10).reshape([2, 5])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
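The source generator and the module-level names AUTO_FILE and num_readers come from the surrounding test module; a sketch consistent with the golden arrays checked above (five 1x5 rows followed by 2x5 rows), with assumed values for the constants:

import numpy as np

# assumed module-level constants
AUTO_FILE = "./auto_file.mindrecord"
num_readers = 4

def generator_dynamic_2d_0():
    # rows 0-4 are 1x5, the remaining rows are 2x5
    for i in range(10):
        if i < 5:
            yield (np.arange(5).reshape([1, 5]),)
        else:
            yield (np.arange(10).reshape([2, 5]),)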
Example #18
def GetDataLoader(per_batch_size,
                  max_epoch,
                  rank,
                  group_size,
                  config,
                  split='train'):
    """
    Centerface get data loader
    """
    centerface_gen = CenterfaceDataset(config=config, split=split)
    sampler = DistributedSampler(
        centerface_gen, rank, group_size,
        shuffle=(split == 'train'))  # user defined sampling strategy
    de_dataset = ds.GeneratorDataset(centerface_gen, ["image", "anns"],
                                     sampler=sampler,
                                     num_parallel_workers=16)

    if group_size > 1:
        num_parallel_workers = 24
    else:
        num_parallel_workers = 64
    if split == 'train':
        compose_map_func = (
            lambda image, anns: preprocess_train(image, anns, config=config))
        columns = [
            'image', "hm", 'reg_mask', 'ind', 'wh', 'wight_mask', 'hm_offset',
            'hps_mask', 'landmarks'
        ]
        de_dataset = de_dataset.map(input_columns=["image", "anns"],
                                    output_columns=columns,
                                    column_order=columns,
                                    operations=compose_map_func,
                                    num_parallel_workers=num_parallel_workers,
                                    python_multiprocessing=True)

    de_dataset = de_dataset.batch(per_batch_size,
                                  drop_remainder=True,
                                  num_parallel_workers=8)
    if split == 'train':
        # de_dataset = de_dataset.repeat(1)  # with repeat(1), an extra outer "for" loop over epochs is needed
        de_dataset = de_dataset.repeat(max_epoch)

    return de_dataset, de_dataset.get_dataset_size()
Example #19
def test_iterator_create_tuple_mstensor():
    """
    Test creating tuple iterator with output MSTensor
    """
    def generator():
        for i in range(64):
            yield (np.array([i], dtype=np.float32), )

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator, ["data"])

    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1):
        golden = np.array([i], dtype=np.float32)
        np.testing.assert_array_equal(item[0].asnumpy(), golden)
        assert isinstance(item[0], Tensor)
        assert item[0].dtype == mstype.float32
        i += 1
    assert i == 64
Example #20
def test_generator_tuple_infinite_repeat_repeat_4():
    """
    test generator tuple infinite repeat repeat 4
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(output_numpy=True)

    i = 0
    for item in iter1:  # each data is a tuple
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break
Example #21
def type_tester_with_type_check_2c_schema(t, c):
    logger.info("Test with Type {}".format(t.__name__))

    schema = ds.Schema()
    schema.add_column("data0", c[0])
    schema.add_column("data1", c[1])

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type_2c(t)),
                                schema=schema)

    data1 = data1.batch(4)

    i = 0
    for item in data1.create_dict_iterator(
            num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        np.testing.assert_array_equal(item["data0"], golden)
        i = i + 4
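generator_with_type_2c is assumed to yield two columns of the requested NumPy type; a minimal sketch plus a hypothetical invocation:

import numpy as np
import mindspore.common.dtype as mstype

def generator_with_type_2c(t):
    # assumed helper: 64 rows, two columns, both in the requested dtype
    for i in range(64):
        yield (np.array([i], dtype=t), np.array([i], dtype=t))

# assumed call pairing a NumPy type with matching MindSpore column types
type_tester_with_type_check_2c_schema(np.int32, [mstype.int32, mstype.int32])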
Example #22
def test_generator_reset_6():
    """
    Test Generator -> Repeat -> Take -> Repeat -> Skip with EpochCtrl (multiple epochs)
    """
    logger.info("test_generator_reset_6")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_10to12, ["data"])
    branch1 = data1.repeat(2).take(5).repeat(2).skip(2)
    iter1 = branch1.create_dict_iterator(num_epochs=3, output_numpy=True)

    output = np.array([0])
    for _ in range(2):
        for item in iter1:
            output = np.append(output, item["data"])

    golden = np.array(
        [0, 12, 10, 11, 10, 11, 12, 10, 11, 12, 10, 11, 10, 11, 12, 10, 11])

    np.testing.assert_array_equal(output, golden)
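generator_10to12 is assumed to produce exactly the values 10, 11, 12 that appear in the golden array; a minimal sketch:

import numpy as np

def generator_10to12():
    # assumed helper: yields 10, 11, 12 as one-element arrays
    for i in range(10, 13):
        yield (np.array([i]),)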
Example #23
def test_to_number_typical_case_non_integral():
    input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"],
                     ["-816256.234282", "162371864.243243"]]
    epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001]

    for ms_type, inputs in zip(ms_non_integral_types, input_strings):
        dataset = ds.GeneratorDataset(string_dataset_generator(inputs),
                                      "strings")
        dataset = dataset.map(input_columns=["strings"],
                              operations=text.ToNumber(ms_type))

        expected_output = [float(string) for string in inputs]
        output = []
        for data in dataset.create_dict_iterator():
            output.append(data["strings"])

        for expected, actual, epsilon in zip(expected_output, output,
                                             epsilons):
            assert abs(expected - actual) < epsilon
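string_dataset_generator and ms_non_integral_types come from the surrounding test module; plausible definitions (assumptions, chosen to match the three input pairs) are:

import numpy as np
import mindspore.common.dtype as mstype

# assumed: the three floating-point types exercised, in order of increasing precision
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]

def string_dataset_generator(strings):
    # assumed helper: returns a callable source yielding each string as a bytes scalar
    def generator():
        for string in strings:
            yield (np.array(string, dtype='S'),)
    return generator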
Example #24
def test_take_18():
    """
    Test take: take first, then do fiter, skip, batch and repeat operation
    """
    logger.info("test_take_18")
    data1 = ds.GeneratorDataset(generator_10, ["data"])

    data1 = data1.take(8)
    data1 = data1.filter(predicate=filter_func_ge, num_parallel_workers=4)
    data1 = data1.skip(2)

    data1 = data1.batch(2)
    data1 = data1.repeat(2)

    # d refers to the data element; the index from enumerate is unused
    for _, d in enumerate(data1):
        assert d[0].asnumpy()[0] == 2

    assert sum([1 for _ in data1]) == 2
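generator_10 and filter_func_ge are shared helpers of the take/filter tests; definitions consistent with the assertions above (only the batch [2, 3] survives, once per repeat) would be:

import numpy as np

def generator_10():
    # assumed helper: yields 0 through 9 as one-element arrays
    for i in range(10):
        yield (np.array([i]),)

def filter_func_ge(data):
    # assumed predicate: keep values less than or equal to 3
    return data[0] <= 3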
Example #25
def test_sync_exception_04():
    """
    Test sync: with negative batch size in update
    """
    logger.info("test_sync_exception_04")

    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    aug = Augment(0)
    # the negative num_batch is passed to sync_update inside the loop below
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(operations=[aug.preprocess], input_columns=["input"])
    count = 0
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1
            data = {"loss": count}
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    assert "Sync_update batch size can only be positive" in str(e.value)
def test_bucket_batch_single_bucket_no_padding():
    dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])

    column_names = ["col1"]
    bucket_boundaries = [1, 2, 3]
    bucket_batch_sizes = [1, 1, 5, 1]
    element_length_function = (lambda x: 2)

    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
                                             bucket_batch_sizes, element_length_function)

    expected_output = [[[0], [1], [2], [3], [4]],
                       [[5], [6], [7], [8], [9]]]

    output = []
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        output.append(data["col1"].tolist())

    assert output == expected_output
Example #27
def test_tensor_empty_map():
    def gen():
        for _ in range(4):
            (yield np.array([], dtype=np.int64), np.array([], dtype='S'), np.array([1], dtype=np.float64))

    data = ds.GeneratorDataset(gen, column_names=["col1", "col2", "col3"])

    def func(x, y, z):
        x = np.array([1], dtype=np.int64)
        y = np.array(["Hi"], dtype='S')
        z = np.array([], dtype=np.float64)
        return x, y, z

    data = data.map(input_columns=["col1", "col2", "col3"], operations=func)

    for d in data:
        np.testing.assert_array_equal(np.array([1], dtype=np.int64), d[0])
        np.testing.assert_array_equal(np.array(["Hi"], dtype='S'), d[1])
        np.testing.assert_array_equal(np.array([], dtype=np.float64), d[2])
Example #28
File: eval.py  Project: yrpang/mindspore
def main():
    """eval"""
    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    train_dataset = SRData(args,
                           name=args.data_test,
                           train=False,
                           benchmark=False)
    train_de_dataset = de.GeneratorDataset(train_dataset, ['LR', "HR"],
                                           shuffle=False)
    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
    train_loader = train_de_dataset.create_dict_iterator()

    net_m = ipt.IPT(args)
    print('load mindspore net successfully.')
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
    net_m.set_train(False)
    num_imgs = train_de_dataset.get_dataset_size()
    psnrs = np.zeros((num_imgs, 1))
    inference = ipt.IPT_post(net_m, args)
    for batch_idx, imgs in enumerate(train_loader):
        lr = imgs['LR']
        hr = imgs['HR']
        hr_np = np.float32(hr.asnumpy())
        pred = inference.forward(lr)
        pred_np = np.float32(pred.asnumpy())
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
        psnrs[batch_idx, 0] = psnr
    if args.denoise:
        print('Mean psnr of %s DN_%s is %.4f' %
              (args.data_test[0], args.sigma, psnrs.mean(axis=0)[0]))
    elif args.derain:
        print('Mean psnr of Derain is %.4f' % (psnrs.mean(axis=0)[0]))
    else:
        print('Mean psnr of %s x%s is %.4f' %
              (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
Example #29
def create_dataset(args,
                   data_url,
                   epoch_num=1,
                   batch_size=1,
                   usage="train",
                   shuffle=True):
    """
    Create Dataset for DeepLabV3.

    Args:
        args (dict): Train parameters.
        data_url (str): Dataset path.
        epoch_num (int): Epoch of dataset (default=1).
        batch_size (int): Batch size of dataset (default=1).
        usage (str): Whether is use to train or eval (default='train').

    Returns:
        Dataset.
    """
    # create iter dataset
    dataset = HwVocRawDataset(data_url, usage=usage)
    dataset_len = len(dataset)

    # wrapped with GeneratorDataset
    dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=None)
    dataset.set_dataset_size(dataset_len)
    dataset = dataset.map(input_columns=["image", "label"],
                          operations=DataTransform(args, usage=usage))

    channelswap_op = C.HWC2CHW()
    dataset = dataset.map(input_columns="image", operations=channelswap_op)

    # 1464 samples / batch_size 8 = 183 batches
    # epoch_num is num of steps
    # 3658 steps / 183 = 20 epochs
    if usage == "train" and shuffle:
        dataset = dataset.shuffle(1464)
    dataset = dataset.batch(batch_size, drop_remainder=(usage == "train"))
    dataset = dataset.repeat(count=epoch_num)
    dataset.map_model = 4

    return dataset
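A hypothetical call site (args is the parsed training configuration of the script; the VOC path is a placeholder):

# assumed usage
train_ds = create_dataset(args, "/data/voc2012", epoch_num=1, batch_size=8, usage="train")
print("steps per epoch:", train_ds.get_dataset_size())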
Example #30
def test_generator_tuple_2():
    """
    test generator tuple 2
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a tuple
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1