def setup_module(module):
    """Pytest module-level setup hook.

    Turns on all-reduce fusion in the auto-parallel context, puts the global
    context into graph mode targeting Ascend with graph saving disabled, and
    finally calls ``_reset_op_id()``.

    Args:
        module: The test module object supplied by pytest (unused here).
    """
    parallel_ctx = auto_parallel_context()
    parallel_ctx.set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        save_graphs=False,
    )
    _reset_op_id()
def reset_test_context():
    """Restore the global test context to a known baseline.

    Steps, in order:
      1. reset the auto-parallel context;
      2. re-enable all-reduce fusion on the (freshly reset) auto-parallel context;
      3. select graph mode on Ascend with graph saving disabled;
      4. reset the cost-model context and algorithm parameters;
      5. call ``_reset_op_id()``.
    """
    context.reset_auto_parallel_context()
    # Fetch the auto-parallel context only after the reset above.
    auto_parallel_context().set_enable_all_reduce_fusion(
        enable_all_reduce_fusion=True
    )
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        save_graphs=False,
    )
    reset_cost_model_context()
    reset_algo_parameters()
    _reset_op_id()
def setup_module():
    """Pytest module-level setup hook.

    Enables all-reduce fusion, selects graph mode on Ascend without saving
    graphs, and sets ``GlobalComm.INITED = True``. It then resets the
    cost-model context, the algorithm parameters and the op ids.
    """
    fusion_ctx = auto_parallel_context()
    fusion_ctx.set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        save_graphs=False,
    )
    # NOTE(review): presumably marks collective communication as already
    # initialized so the tests can run without a real cluster -- confirm.
    GlobalComm.INITED = True
    reset_cost_model_context()
    reset_algo_parameters()
    _reset_op_id()
def test_one_dev():
    """Check the strategies produced by ``all_to_all_common()`` with defaults.

    Each returned key is matched against the patterns below in order. Only
    the first matching pattern is checked, and keys matching none are
    ignored.
    """
    expected_by_pattern = (
        ('SoftmaxCrossEntropyWithLogits-op', [[1, 1], [1, 1]]),
        ('Transpose-op', [[1, 1]]),
        ('MatMul-op', [[1, 1], [1, 1]]),
    )
    _reset_op_id()
    produced = all_to_all_common()
    for op_key, op_strategy in produced.items():
        for pattern, wanted in expected_by_pattern:
            if re.search(pattern, op_key) is not None:
                assert op_strategy == wanted
                break
def test_one_dev():
    """Check that ``all_to_all_common()`` with default arguments yields
    exactly the strategy dictionary listed below (keys are full op paths)."""
    _reset_op_id()
    actual = all_to_all_common()
    expected = {
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits'
        '/SoftmaxCrossEntropyWithLogits-op9': [[1, 1], [1, 1]],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits'
        '/OneHot-op10': [[1, 1], [], []],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/Transpose-op11':
            [[1, 1]],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/MatMul-op12':
            [[1, 1], [1, 1]],
    }
    assert actual == expected
def test_all_to_all():
    """Run ``all_to_all_common`` with an (8, 1) strategy in graph mode.

    The returned strategy dictionary is printed and must equal the one
    listed below exactly. Graph saving is set to False both before and
    after the check.
    """
    split_strategy = ((8, 1),)
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    _reset_op_id()
    actual = all_to_all_common(split_strategy)
    print(actual)
    expected = {
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits'
        '/SoftmaxCrossEntropyWithLogits-op3': [[8, 1], [8, 1]],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op4':
            [[8, 1], [], []],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/Transpose-op1':
            [[8, 1]],
        'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/MatMul-op0':
            [[1, 1], [1, 8]],
    }
    assert actual == expected
    context.set_context(save_graphs=False)
def test_all_to_all():
    """Run ``all_to_all_common`` with an (8, 1) strategy and check each op.

    Every returned key is tested against the ordered pattern table. The
    first matching pattern decides the expected strategy, and keys matching
    no pattern are skipped. The returned dictionary is printed, and graph
    saving is set to False before and after.
    """
    checks = (
        ('SoftmaxCrossEntropyWithLogits-op', [[8, 1], [8, 1]]),
        ('OneHot-op', [[8, 1], [], []]),
        ('Transpose-op', [[8, 1]]),
        ('MatMul-op', [[1, 1], [1, 8]]),
    )
    split_strategy = ((8, 1),)
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    _reset_op_id()
    produced = all_to_all_common(split_strategy)
    print(produced)
    for op_key, op_strategy in produced.items():
        for pattern, wanted in checks:
            if re.search(pattern, op_key) is not None:
                assert op_strategy == wanted
                break
    context.set_context(save_graphs=False)
def test_data_parallel_mode():
    """Training the all-to-all net under DATA_PARALLEL mode must fail.

    Sets up graph mode and a fresh auto-parallel context in DATA_PARALLEL
    mode with ``full_batch=True``, builds a Momentum-optimized model over
    constant inputs, and expects ``model.train`` to raise ``RuntimeError``.
    """
    _reset_op_id()
    lr, mom, num_epochs = 0.1, 0.9, 2

    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.DATA_PARALLEL,
        full_batch=True,
    )

    inputs = Tensor(np.ones([256, 128]), dtype=ms.float32)
    targets = Tensor(np.ones([256]), dtype=ms.int32)
    train_data = Dataset(inputs, targets, 2)

    network = all_to_all_net(None)
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = Momentum(network.trainable_params(), lr, mom)
    trainer = Model(network, criterion, optimizer)

    with pytest.raises(RuntimeError):
        trainer.train(num_epochs, train_data, dataset_sink_mode=False)
def teardown_module():
    """Pytest module-level teardown hook.

    Resets the auto-parallel context, the cost-model context, the algorithm
    parameters and the op ids, in that order.
    """
    context.reset_auto_parallel_context()
    reset_cost_model_context()
    reset_algo_parameters()
    _reset_op_id()
def test_all_to_all():
    """Smoke test: ``all_to_all_common`` runs with an (8, 1) strategy.

    The return value is not inspected; the test passes as long as no
    exception is raised.
    """
    _reset_op_id()
    all_to_all_common(((8, 1),))
def teardown_module():
    """Pytest module-level teardown hook: reset the auto-parallel context,
    then reset op ids via ``_reset_op_id()``."""
    context.reset_auto_parallel_context()
    _reset_op_id()
def teardown_module():
    """Pytest module-level teardown hook.

    Resets the auto-parallel context and sets ``GlobalComm.INITED`` back to
    False. It then resets the cost-model context, the algorithm parameters
    and the op ids.
    """
    context.reset_auto_parallel_context()
    # Counterpart of a setup that set GlobalComm.INITED = True.
    GlobalComm.INITED = False
    reset_cost_model_context()
    reset_algo_parameters()
    _reset_op_id()
def test_model_callback():
    """Run ``all_to_all_common`` with an (8, 1) strategy after resetting op
    ids; passes if no exception is raised (the result is not inspected)."""
    _reset_op_id()
    all_to_all_common(((8, 1),))
def setup_module(module):
    """Pytest module-level setup hook.

    Selects graph mode on Ascend with graph saving disabled, then calls
    ``_reset_op_id()``.

    Args:
        module: The test module object supplied by pytest (unused here).
    """
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
        save_graphs=False,
    )
    _reset_op_id()