def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
    """Train ResNet-50 under auto-parallel on 8 devices (GPU target) with a 4096-class head.

    Switches to graph mode on the GPU target, trains for 5 steps without
    dataset sinking, then inspects the shard strategies picked for the
    training network:

    * every ``Conv2D-op`` must split the first dimension of its first input
      across all devices;
    * every ``MatMul-op`` must use ``[[8, 1], [1, 1]]``;
    * every ``ReduceSum-op`` must use ``[[8, 1]]``.
    """
    device_count = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=device_count)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    resset_op_id()

    # Deterministic synthetic data: all-ones images, labels cycling over the classes.
    np.random.seed(6)
    images = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    labels = np.array([idx % num_classes for idx in range(batch_size)], dtype=np.int32)
    dataset = DatasetLenet(Tensor(images), Tensor(labels), 1)

    network = resnet50(num_classes)
    loss_fn = SoftmaxCrossEntropyExpand(sparse=True)
    trainable = filter(lambda param: param.requires_grad, network.get_parameters())
    optimizer = Momentum(trainable, 0.01, 0.9)
    model = Model(network, loss_fn=loss_fn, optimizer=optimizer)
    model.train(5, dataset, dataset_sink_mode=False)

    # Ordered (marker, predicate) pairs; the first marker found in a key decides
    # which expectation applies, exactly like an if/elif chain.
    checks = (
        ('Conv2D-op', lambda strategy: strategy[0][0] == device_count),
        ('MatMul-op', lambda strategy: strategy == [[device_count, 1], [1, 1]]),
        ('ReduceSum-op', lambda strategy: strategy == [[device_count, 1]]),
    )
    for op_key, strategy in _executor._get_shard_strategy(model._train_network).items():
        for marker, is_expected in checks:
            if re.search(marker, op_key) is not None:
                assert is_expected(strategy)
                break
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536):
    """Train ResNet-50 under auto-parallel on 8 devices with a 65536-class head.

    Sets cost-model gamma/beta, trains for 5 steps without dataset sinking,
    then checks the shard strategies chosen for the training network:

    * every ``Conv2D-op`` must split the first dimension of its first input
      across all devices;
    * every ``MatMul-op`` must use ``[[1, 1], [8, 1]]``;
    * every ``ReduceSum-op`` must use ``[[1, 8]]``.

    The cost-model context is always reset afterwards, even when an assertion
    fails, so the settings do not leak into other tests.

    Note: ``epoch_size`` is not used; training is hard-coded to
    ``model.train(5, ...)``. Other class counts that appear in an old comment
    on this test: 1048576, 131072, 32768, 8192.
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    try:
        cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        resset_op_id()
        np.random.seed(6)
        input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
        label_np = np.zeros([batch_size]).astype(np.int32)
        for i in range(0, batch_size):
            label_np[i] = i % num_classes
        dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
        net = resnet50(num_classes)
        loss = SoftmaxCrossEntropyExpand(sparse=True)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
        model = Model(net, loss_fn=loss, optimizer=opt)
        model.train(5, dataset, dataset_sink_mode=False)
        strategies = _executor._get_strategy(model._train_network)
        for (k, v) in strategies.items():
            # re.search(pattern, string): the op marker is the pattern and the
            # strategy key is the text searched. With the arguments swapped, the
            # key was compiled as a regex and the assertions were effectively
            # never reached.
            if re.search('Conv2D-op', k) is not None:
                assert v[0][0] == dev_num
            elif re.search('MatMul-op', k) is not None:
                assert v == [[1, 1], [dev_num, 1]]
            elif re.search('ReduceSum-op', k) is not None:
                assert v == [[1, dev_num]]
    finally:
        # Undo the gamma/beta override regardless of the test outcome.
        cost_model_context.reset_cost_model_context()
def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
    """Train ResNet-50 under auto-parallel on 8 devices with a 32768-class head.

    Configures the cost model (gamma/beta plus allreduce-fusion algorithm 1,
    2 fusion groups, tail percent 0.5) and trains for 5 steps without dataset
    sinking. It then checks:

    * every ``Conv2D-op`` splits the first dimension of its first input
      across all devices;
    * every ``MatMul-op`` uses ``[[8, 1], [1, 1]]``;
    * every ``ReduceSum-op`` uses ``[[8, 1]]``;
    * the allreduce fusion-group assignment per parameter equals
      ``expect_dict``.

    The cost-model context is reset in ``finally`` so a failed assertion does
    not leave these settings applied for later tests.

    Note: ``epoch_size`` is not used; training is hard-coded to
    ``model.train(5, ...)``. Other class counts that appear in an old comment
    on this test: 1048576, 131072, 32768, 8192.
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    try:
        cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
        cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        resset_op_id()
        np.random.seed(6)
        input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
        label_np = np.zeros([batch_size]).astype(np.int32)
        for i in range(0, batch_size):
            label_np[i] = i % num_classes
        dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
        net = resnet50(num_classes)
        loss = SoftmaxCrossEntropyExpand(sparse=True)
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
        model = Model(net, loss_fn=loss, optimizer=opt)
        model.train(5, dataset, dataset_sink_mode=False)
        strategies = _executor._get_strategy(model._train_network)
        for (k, v) in strategies.items():
            if re.search('Conv2D-op', k) is not None:
                assert v[0][0] == dev_num
            elif re.search('MatMul-op', k) is not None:
                assert v == [[dev_num, 1], [1, 1]]
            elif re.search('ReduceSum-op', k) is not None:
                assert v == [[dev_num, 1]]
        allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)
        print(allreduce_fusion_dict)
        # Expected fusion-group index for each trainable parameter.
        expect_dict = {
            'end_point.bias': 2, 'end_point.weight': 2,
            'layer4.2.bn3.beta': 2, 'layer4.2.bn3.gamma': 2, 'layer4.2.conv3.weight': 2,
            'layer4.2.bn2.beta': 2, 'layer4.2.bn2.gamma': 2, 'layer4.2.conv2.weight': 2,
            'layer4.2.bn1.beta': 2, 'layer4.2.bn1.gamma': 2, 'layer4.2.conv1.weight': 2,
            'layer4.1.bn3.beta': 2, 'layer4.1.bn3.gamma': 2, 'layer4.1.conv3.weight': 2,
            'layer4.1.bn2.beta': 2, 'layer4.1.bn2.gamma': 2, 'layer4.1.conv2.weight': 2,
            'layer4.1.bn1.beta': 2, 'layer4.1.bn1.gamma': 2, 'layer4.1.conv1.weight': 2,
            'layer4.0.bn_down_sample.beta': 2, 'layer4.0.bn_down_sample.gamma': 2,
            'layer4.0.conv_down_sample.weight': 2,
            'layer4.0.bn3.beta': 2, 'layer4.0.bn3.gamma': 2, 'layer4.0.conv3.weight': 2,
            'layer4.0.bn2.beta': 2, 'layer4.0.bn2.gamma': 2, 'layer4.0.conv2.weight': 2,
            'layer4.0.bn1.beta': 2, 'layer4.0.bn1.gamma': 2, 'layer4.0.conv1.weight': 2,
            'layer3.5.bn3.beta': 2, 'layer3.5.bn3.gamma': 2, 'layer3.5.conv3.weight': 2,
            'layer3.5.bn2.beta': 2, 'layer3.5.bn2.gamma': 2, 'layer3.5.conv2.weight': 2,
            'layer3.5.bn1.beta': 2, 'layer3.5.bn1.gamma': 2, 'layer3.5.conv1.weight': 2,
            'layer3.4.bn3.beta': 2, 'layer3.4.bn3.gamma': 2, 'layer3.4.conv3.weight': 2,
            'layer3.4.bn2.beta': 2, 'layer3.4.bn2.gamma': 2, 'layer3.4.conv2.weight': 2,
            'layer3.4.bn1.beta': 2, 'layer3.4.bn1.gamma': 2, 'layer3.4.conv1.weight': 2,
            'layer3.3.bn3.beta': 2, 'layer3.3.bn3.gamma': 2, 'layer3.3.conv3.weight': 2,
            'layer3.3.bn2.beta': 2, 'layer3.3.bn2.gamma': 2, 'layer3.3.conv2.weight': 2,
            'layer3.3.bn1.beta': 2, 'layer3.3.bn1.gamma': 2, 'layer3.3.conv1.weight': 2,
            'layer3.2.bn3.beta': 2, 'layer3.2.bn3.gamma': 2, 'layer3.2.conv3.weight': 2,
            'layer3.2.bn2.beta': 2, 'layer3.2.bn2.gamma': 2, 'layer3.2.conv2.weight': 2,
            'layer3.2.bn1.beta': 2, 'layer3.2.bn1.gamma': 2, 'layer3.2.conv1.weight': 2,
            'layer3.1.bn3.beta': 2, 'layer3.1.bn3.gamma': 2, 'layer3.1.conv3.weight': 2,
            'layer3.1.bn2.beta': 2, 'layer3.1.bn2.gamma': 2, 'layer3.1.conv2.weight': 2,
            'layer3.1.bn1.beta': 2, 'layer3.1.bn1.gamma': 2, 'layer3.1.conv1.weight': 2,
            'layer3.0.bn_down_sample.beta': 1, 'layer3.0.bn_down_sample.gamma': 1,
            'layer3.0.conv_down_sample.weight': 2,
            'layer3.0.bn3.beta': 1, 'layer3.0.bn3.gamma': 1, 'layer3.0.conv3.weight': 2,
            'layer3.0.bn2.beta': 2, 'layer3.0.bn2.gamma': 2, 'layer3.0.conv2.weight': 2,
            'layer3.0.bn1.beta': 2, 'layer3.0.bn1.gamma': 2, 'layer3.0.conv1.weight': 2,
            'layer2.3.bn3.beta': 2, 'layer2.3.bn3.gamma': 2, 'layer2.3.conv3.weight': 2,
            'layer2.3.bn2.beta': 2, 'layer2.3.bn2.gamma': 2, 'layer2.3.conv2.weight': 2,
            'layer2.3.bn1.beta': 2, 'layer2.3.bn1.gamma': 2, 'layer2.3.conv1.weight': 2,
            'layer2.2.bn3.beta': 2, 'layer2.2.bn3.gamma': 2, 'layer2.2.conv3.weight': 2,
            'layer2.2.bn2.beta': 2, 'layer2.2.bn2.gamma': 2, 'layer2.2.conv2.weight': 2,
            'layer2.2.bn1.beta': 2, 'layer2.2.bn1.gamma': 2, 'layer2.2.conv1.weight': 2,
            'layer2.1.bn3.beta': 1, 'layer2.1.bn3.gamma': 1, 'layer2.1.conv3.weight': 2,
            'layer2.1.bn2.beta': 2, 'layer2.1.bn2.gamma': 2, 'layer2.1.conv2.weight': 2,
            'layer2.1.bn1.beta': 2, 'layer2.1.bn1.gamma': 2, 'layer2.1.conv1.weight': 2,
            'layer2.0.bn_down_sample.beta': 1, 'layer2.0.bn_down_sample.gamma': 1,
            'layer2.0.conv_down_sample.weight': 2,
            'layer2.0.bn3.beta': 1, 'layer2.0.bn3.gamma': 1, 'layer2.0.conv3.weight': 2,
            'layer2.0.bn2.beta': 2, 'layer2.0.bn2.gamma': 2, 'layer2.0.conv2.weight': 2,
            'layer2.0.bn1.beta': 2, 'layer2.0.bn1.gamma': 2, 'layer2.0.conv1.weight': 2,
            'layer1.2.bn3.beta': 2, 'layer1.2.bn3.gamma': 2, 'layer1.2.conv3.weight': 2,
            'layer1.2.bn2.beta': 2, 'layer1.2.bn2.gamma': 2, 'layer1.2.conv2.weight': 2,
            'layer1.2.bn1.beta': 2, 'layer1.2.bn1.gamma': 2, 'layer1.2.conv1.weight': 2,
            'layer1.1.bn3.beta': 1, 'layer1.1.bn3.gamma': 1, 'layer1.1.conv3.weight': 2,
            'layer1.1.bn2.beta': 2, 'layer1.1.bn2.gamma': 2, 'layer1.1.conv2.weight': 2,
            'layer1.1.bn1.beta': 2, 'layer1.1.bn1.gamma': 2, 'layer1.1.conv1.weight': 2,
            'layer1.0.bn_down_sample.beta': 1, 'layer1.0.bn_down_sample.gamma': 1,
            'layer1.0.conv_down_sample.weight': 2,
            'layer1.0.bn3.beta': 1, 'layer1.0.bn3.gamma': 1, 'layer1.0.conv3.weight': 2,
            'layer1.0.bn2.beta': 2, 'layer1.0.bn2.gamma': 2, 'layer1.0.conv2.weight': 2,
            'layer1.0.bn1.beta': 2, 'layer1.0.bn1.gamma': 2, 'layer1.0.conv1.weight': 2,
            'bn1.beta': 1, 'bn1.gamma': 1, 'conv1.weight': 2,
        }
        assert allreduce_fusion_dict == expect_dict
    finally:
        # Always undo the cost-model overrides so later tests start from defaults.
        cost_model_context.reset_cost_model_context()