コード例 #1
0
ファイル: dataset_helper.py プロジェクト: lyc4614/mindspore
    def __init__(self, dataset):
        super(_DatasetIterGE, self).__init__(dataset)
        self.loop_count = self.get_loop_count(dataset)
        parallel_mode = _get_parallel_mode()
        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                              ParallelMode.AUTO_PARALLEL)
        batch_expand_num = 1
        if self.need_to_full:
            batch_expand_num = _get_device_num()
        tensor_list_run = _construct_tensor_list(self.dataset_types,
                                                 self.dataset_shapes,
                                                 batch_expand_num)

        def op():
            return tensor_list_run

        self.op = op
コード例 #2
0
ファイル: test_utils.py プロジェクト: zuoshou030/mindspore
def test_init_dataset_graph_one_dim():
    types = (ms.float32,)
    shapes = ((1, 3, 224, 224),)
    _construct_tensor_list(types, shapes)
コード例 #3
0
ファイル: test_utils.py プロジェクト: zuoshou030/mindspore
def test_init_dataset_graph_dim_error():
    types = (ms.float32, ms.float32)
    shapes = ((1, 3, 224, 224),)
    with pytest.raises(ValueError):
        _construct_tensor_list(types, shapes)
コード例 #4
0
ファイル: test_utils.py プロジェクト: zuoshou030/mindspore
def test_init_dataset_graph():
    types = (ms.float32, ms.float32)
    shapes = ((1, 3, 224, 224), (32,))
    _construct_tensor_list(types, shapes)