def __init__(self, dataset):
    """Set up the GE dataset iterator.

    Records the loop count, decides whether the device count must be used
    as the batch expansion factor, and builds the tensor list exactly once,
    so that ``self.op`` returns that same list object on every call.

    Args:
        dataset: The dataset being iterated. It is forwarded to the base
            class initializer and to ``get_loop_count``.
    """
    super(_DatasetIterGE, self).__init__(dataset)
    self.loop_count = self.get_loop_count(dataset)

    # Only semi-auto / auto parallel modes require the "full" expansion;
    # in that case the device count is passed as the batch expansion factor.
    self.need_to_full = _get_parallel_mode() in (
        ParallelMode.SEMI_AUTO_PARALLEL,
        ParallelMode.AUTO_PARALLEL,
    )
    expand_factor = _get_device_num() if self.need_to_full else 1

    # NOTE(review): dataset_types / dataset_shapes are not assigned here;
    # presumably the base-class __init__ populates them — confirm.
    tensors = _construct_tensor_list(
        self.dataset_types, self.dataset_shapes, expand_factor
    )

    def op():
        # Closure over the prebuilt list; no new tensors are built per call.
        return tensors

    self.op = op
def test_init_dataset_graph_one_dim():
    """A single dtype paired with a single shape builds without raising."""
    _construct_tensor_list((ms.float32,), ((1, 3, 224, 224),))
def test_init_dataset_graph_dim_error():
    """Two dtypes but only one shape must be rejected with ``ValueError``."""
    dtypes = (ms.float32,) * 2
    with pytest.raises(ValueError):
        _construct_tensor_list(dtypes, ((1, 3, 224, 224),))
def test_init_dataset_graph():
    """Two dtypes paired with two matching shapes build without raising."""
    dtypes = (ms.float32, ms.float32)
    first_shape = (1, 3, 224, 224)
    second_shape = (32,)
    _construct_tensor_list(dtypes, (first_shape, second_shape))