def test_double_subgraphs():
    """Verify the auto-parallel strategies chosen for a double-subgraph network.

    Turns on multi-subgraph handling, configures an 8-device auto-parallel
    context on rank 0, compiles ``TrainStepWarp(NetWithLoss(Net()))`` in the
    'train' phase, and asserts that the per-operator strategy table returned
    by ``_executor._get_strategy`` equals a hard-coded expectation.
    """
    # --- global configuration -------------------------------------------
    set_multi_subgraphs()
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")

    # --- network under test ----------------------------------------------
    train_net = TrainStepWarp(NetWithLoss(Net()))
    train_net.set_auto_parallel()

    inputs = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)

    # The expected keys embed operator ids ("-op0" ... "-op4"); the id
    # counter is reset immediately before compiling, presumably so those
    # ids start from 0 regardless of earlier tests.
    reset_op_id()
    _executor.compile(train_net, inputs, phase='train')
    actual = _executor._get_strategy(train_net)

    expected = {
        'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
        'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
        'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]],
    }
    assert actual == expected
config=ckptconfig) context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) callback_list = [ TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback ] if not host_device_mix: callback_list.append(ckpoint_cb) model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix)) if __name__ == "__main__": wide_deep_config = WideDeepConfig() wide_deep_config.argparse_init() context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) context.set_context(variable_memory_max_size="24GB") context.set_context(enable_sparse=True) set_multi_subgraphs() init() if wide_deep_config.host_device_mix == 1: context.set_auto_parallel_context( parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) else: context.set_auto_parallel_context( parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True) train_and_eval(wide_deep_config)