def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss):
    """Compile and train the reshape network for two epochs under the given
    parallel mode and shard strategies (8 devices, non-sink mode)."""
    lr, mom, epochs = 0.1, 0.9, 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    # Fixed synthetic batch: 32 NCHW feature maps plus integer class labels.
    features = Tensor(np.ones([32, 512, 7, 7]), dtype=ms.float32)
    targets = Tensor(np.ones([32]), dtype=ms.int32)
    data = Dataset(features, targets, 2)
    network = reshape_net(strategy0, strategy1, strategy2)
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    criterion.softmax_cross_entropy.shard(strategy_loss)
    criterion.one_hot.shard(((8, 1), (), ()))
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, criterion, optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
def no_ps_impl(self):
    """Train once with parameter-server mode disabled and return the
    prediction for ``self.input_np`` as a NumPy array.

    Re-enables ``enable_ps`` before returning so later cases see PS mode.
    """
    context.set_ps_context(enable_ps=False)
    network = Menet(self.in_channels, self.out_channels, self.kernel_size,
                    self.vocab_size, self.embedding_size,
                    self.output_channels, self.target, self.sparse)
    # Pin the convolution primitives to CPU execution.
    network.conv.conv2d.add_prim_attr('primitive_target', 'CPU')
    network.conv.bias_add.add_prim_attr('primitive_target', 'CPU')
    network.set_train()
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = Adam(
        params=filter(lambda x: x.requires_grad, network.get_parameters()))
    optimizer.target = 'CPU'
    trainer = Model(network, criterion, optimizer)
    trainer.train(self.epoch_size, self.dataset, dataset_sink_mode=False)
    prediction = trainer.predict(Tensor(self.input_np))
    context.set_ps_context(enable_ps=True)
    return prediction.asnumpy()
def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
    """Train ParallelReduceMeanNet (keep_dims=True) for two epochs on 8
    devices without forcing the algorithm to use every device."""
    batch = 16
    lr, mom, epochs = 0.1, 0.9, 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    network = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64,
                                    reducemean_keep_dims=True,
                                    strategy=((4, 1, 1, 1),))
    criterion = CrossEntropyLoss2()
    features = Tensor(np.ones([batch, 3, 32, 32]), dtype=ms.float32)
    targets = Tensor(np.ones([batch, 2048]), dtype=ms.float32)
    data = Dataset(features, targets, 2, input_num=2)
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, loss_fn=criterion, optimizer=optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
def test_compile_f16_model_train_fixed():
    """Compile and train an FP16 net for two epochs under a fixed
    loss-scale manager (default sink mode)."""
    types = (np.float32, np.float32)
    shapes = ((16, 16), (16, 16))
    data = MindDataSet(types, shapes)
    network = NetFP16(16, 16)
    network.set_train()
    manager = FixedLossScaleManager()
    criterion = MSELoss()
    optimizer = Momentum(network.trainable_params(), learning_rate=0.1,
                         momentum=0.9)
    trainer = Model(network, loss_fn=criterion, optimizer=optimizer,
                    metrics=None, loss_scale_manager=manager)
    trainer.train(2, data)
def test_train_and_eval_lenet():
    """Train LeNet-5 on MNIST for one epoch on GPU, then report accuracy on
    the test split (both phases in sink mode)."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    lenet = LeNet5(10)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.Momentum(lenet.trainable_params(), 0.01, 0.9)
    model = Model(lenet, criterion, optimizer, metrics={"Accuracy": Accuracy()})
    print("============== Starting Training ==============")
    mnist_root = '/home/workspace/mindspore_dataset/mnist'
    train_set = create_dataset(os.path.join(mnist_root, "train"), 32, 1)
    model.train(1, train_set, callbacks=[LossMonitor()], dataset_sink_mode=True)
    print("============== Starting Testing ==============")
    test_set = create_dataset(os.path.join(mnist_root, "test"), 32, 1)
    acc = model.eval(test_set, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
def test_semi_one_hot_net_model():
    """Compile/train SemiAutoOneHotNet (no external loss function) under
    semi-auto parallel across 16 devices for two epochs."""
    batch = 16
    lr, mom, epochs = 0.1, 0.9, 2
    features = Tensor(np.ones([batch, 512]), dtype=ms.float32)
    targets = Tensor(np.ones([batch]), dtype=ms.int32)
    data = Dataset(features, targets, 2, input_num=2)
    network = SemiAutoOneHotNet(args=Args(), strategy=StrategyModel())
    optimizer = Momentum(network.trainable_params(), lr, mom)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=16)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(network, optimizer=optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
def test_data_parallel_mode():
    """Training with ``full_batch=True`` in DATA_PARALLEL mode is expected to
    raise RuntimeError (presumably an unsupported combination — the test only
    pins the exception type)."""
    _reset_op_id()
    lr, mom, epochs = 0.1, 0.9, 2
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      full_batch=True)
    features = Tensor(np.ones([256, 128]), dtype=ms.float32)
    targets = Tensor(np.ones([256]), dtype=ms.int32)
    data = Dataset(features, targets, 2)
    network = all_to_all_net(None)
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, criterion, optimizer)
    with pytest.raises(RuntimeError):
        model.train(epochs, data, dataset_sink_mode=False)
def test_membership_inference_object_train():
    """Train a MembershipInference attacker with a KNN attack config against
    generator-backed train/test datasets."""
    network = Net()
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    optimizer = nn.Momentum(params=network.trainable_params(),
                            learning_rate=0.1, momentum=0.9)
    target_model = Model(network=network, loss_fn=criterion, optimizer=optimizer)
    attacker = MembershipInference(target_model, 2)
    assert isinstance(attacker, MembershipInference)
    config = [{
        "method": "KNN",
        "params": {
            "n_neighbors": [3, 5, 7],
        }
    }]
    train_set = ds.GeneratorDataset(dataset_generator, ["image", "label"])
    test_set = ds.GeneratorDataset(dataset_generator, ["image", "label"])
    attacker.train(train_set, test_set, config)
def train(network, net_opt, ds_train, prefix, directory, print_times):
    """Train ``network`` with loss monitoring and checkpointing, returning the
    fitted Model.

    NOTE(review): reads module-level ``cfg`` and ``epoch_size`` — confirm
    both are defined at module scope in the enclosing file.
    """
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    model = Model(network, loss_fn=criterion, optimizer=net_opt, metrics={"acc"})
    monitor = LossMonitor(per_print_times=print_times)
    ckpt_cfg = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    saver = ModelCheckpoint(prefix=prefix, directory=directory, config=ckpt_cfg)
    print("============== Starting Training ==============")
    model.train(epoch_size, ds_train, callbacks=[saver, monitor],
                dataset_sink_mode=False)
    return model
def test_membership_inference_eval():
    """Run MembershipInference.eval on one small generator-backed batch with
    precision/accuracy/recall metrics."""
    network = Net()
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    optimizer = nn.Momentum(params=network.trainable_params(),
                            learning_rate=0.1, momentum=0.9)
    target_model = Model(network=network, loss_fn=criterion, optimizer=optimizer)
    attacker = MembershipInference(target_model, -1)
    assert isinstance(attacker, MembershipInference)
    batch_size, batches = 16, 1
    eval_train = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                                     ["image", "label"])
    eval_test = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                                    ["image", "label"])
    wanted_metrics = ["precision", "accuracy", "recall"]
    attacker.eval(eval_train, eval_test, wanted_metrics)
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3):
    """Train LeNet-5 on the dataset at ``data_dir`` and print eval metrics.

    NOTE(review): ``is_grad`` is an older MindSpore loss argument — confirm
    the API version this file targets.
    """
    train_set = create_dataset(data_dir)
    eval_set = create_dataset(data_dir, training=False)
    network = LeNet5()
    criterion = nn.loss.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                      reduction='mean')
    optimizer = nn.Momentum(network.trainable_params(), lr, momentum)
    # Print the loss once per epoch (one print per full pass over train_set).
    monitor = LossMonitor(per_print_times=train_set.get_dataset_size())
    model = Model(network, criterion, optimizer, metrics={'acc', 'loss'})
    # dataset_sink_mode can be True when using Ascend
    model.train(num_epochs, train_set, callbacks=[monitor], dataset_sink_mode=False)
    metrics = model.eval(eval_set, dataset_sink_mode=False)
    print('Metrics:', metrics)
def all_to_all_common():
    """Train the all-to-all net under AUTO_PARALLEL on a single device and
    return the strategies chosen for the compiled train network."""
    lr, mom, epochs = 0.1, 0.9, 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=1, global_rank=0)
    features = Tensor(np.ones([32, 128]), dtype=ms.float32)
    targets = Tensor(np.ones([32]), dtype=ms.int32)
    data = Dataset(features, targets, 2)
    network = all_to_all_net()
    criterion = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, criterion, optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
    return _executor._get_strategy(model._train_network)
def infer(data_dir):
    """Predict on one batch from the eval split of ``data_dir`` using the
    checkpointed LeNet-5, then plot four predictions (blue title = correct,
    red = wrong)."""
    # Renamed local from ``ds`` — that name shadows the mindspore.dataset
    # module alias used elsewhere in this file.
    data_iter = create_dataset(data_dir, training=False).create_dict_iterator()
    data = data_iter.get_next()
    images = data['image']
    labels = data['label']
    net = LeNet5()
    load_checkpoint(CKPT_2, net=net)
    model = Model(net)
    # Reuse ``images`` rather than re-indexing the batch dict a second time.
    output = model.predict(Tensor(images))
    preds = np.argmax(output.asnumpy(), axis=1)
    for i in range(1, 5):
        plt.subplot(2, 2, i)
        plt.imshow(np.squeeze(images[i]))
        color = 'blue' if preds[i] == labels[i] else 'red'
        plt.title("prediction: {}, truth: {}".format(preds[i], labels[i]),
                  color=color)
        plt.xticks([])
    plt.show()
def train_lenet_quant():
    """Quantization-aware training of LeNet-5, starting from a non-quantized
    checkpoint and saving quantized checkpoints once per run."""
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)
    # convert fusion network to quantization aware network
    quantizer = QuantizationAwareTraining(quant_delay=900,
                                          bn_fold=False,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    network = quantizer.quantize(network)
    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    # call back and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant", config=config_ckpt)
    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    print("============== Starting Training ==============")
    # Attribute access for consistency with every other ``cfg`` read in this
    # function (was ``cfg['epoch_size']``; same object, same value).
    model.train(cfg.epoch_size, ds_train,
                callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=True)
    print("============== End Training ==============")
def test_callbacks_non_sink_mismatch_size():
    """A dataset callback waiting on steps that never arrive must hit the
    1-second callback timeout and surface a RuntimeError from model.train;
    the original timeout is restored afterwards."""
    logger.info("test_callbacks_non_sink_mismatch_size")
    saved_timeout = ds.config.get_callback_timeout()
    ds.config.set_callback_timeout(1)
    events = []
    waited_cb = MyWaitedCallback(events, 2)
    ms_cb = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"],
                                 shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=waited_cb)
    data = data.batch(3)
    model = Model(Net())
    with pytest.raises(Exception) as err:
        model.train(2, data, dataset_sink_mode=False, callbacks=[ms_cb, waited_cb])
    assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
    ds.config.set_callback_timeout(saved_timeout)
def _model_train_and_save_ckpt(self, net, dataset, epoch):
    """Train ``net`` for ``epoch`` epochs, keeping only the newest checkpoint
    in a rank-specific directory, and return its loaded parameter dict."""
    self.opt = Momentum(learning_rate=0.01, momentum=0.9,
                        params=net.get_parameters())
    self.loss_fn = SoftmaxCrossEntropyWithLogits(reduction='mean')
    self.model = Model(network=net, loss_fn=self.loss_fn, optimizer=self.opt)
    ckpt_dir = './rank_{}_ckpt'.format(self.global_rank_id)
    saver = ModelCheckpoint(prefix='parallel', directory=ckpt_dir,
                            config=CheckpointConfig(keep_checkpoint_max=1))
    # Start from an empty directory so the newest-file lookup is unambiguous.
    clean_all_ckpt_files(ckpt_dir)
    self.model.train(epoch=epoch, train_dataset=dataset, callbacks=[saver],
                     dataset_sink_mode=False)
    return load_checkpoint(find_newest_ckpt_file(ckpt_dir))
def train(Net):
    """Instantiate the network class ``Net``, train it for 30 epochs, print
    its eval metric on the test split, and return the trained Model."""
    ds_train, ds_test = create_dataset()
    # Build the network.
    network = Net(cfg.num_classes)
    # Define the model's loss function and optimizer.
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.Adam(network.trainable_params(), cfg.lr)
    # Train the model.
    model = Model(network, loss_fn=criterion, optimizer=optimizer,
                  metrics={'acc': Accuracy()})
    monitor = LossMonitor()
    print("============== Starting Training ==============")
    model.train(30, ds_train, callbacks=[monitor], dataset_sink_mode=True)
    # Evaluate on the held-out split.
    metric = model.eval(ds_test)
    print(metric)
    return model
def compile_net(net):
    """Train ``net`` with checkpointing and assert that exactly four
    parameter-merge sub-networks were built for checkpoint saving."""
    context.set_context(save_graphs=True)
    lr, mom, epochs = 0.1, 0.9, 2
    data = Dataset(_x, _b)
    optimizer = Momentum(net.trainable_params(), lr, mom)
    model = Model(net, optimizer=optimizer)
    ckpt_path = "./parallel_ckpt"
    saver = ModelCheckpoint(prefix="parallel", directory=ckpt_path,
                            config=CheckpointConfig(keep_checkpoint_max=1))
    model.train(epochs, data, dataset_sink_mode=False, callbacks=[saver])
    assert len(model._train_network.parallel_parameter_merge_net_dict) == 4
    clean_all_ckpt_files(ckpt_path)
    context.reset_auto_parallel_context()
def main(data_path, device_target='Ascend', summary_dir='./summary_dir', learning_rate=0.01):
    """Train LeNet-5 for one epoch while a SummaryCollector records summary
    data to ``summary_dir`` every 10 steps."""
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    momentum = 0.9
    epoch_size = 1
    batch_size = 32
    network = LeNet5()
    network.set_train()
    criterion = CrossEntropyLoss()
    optimizer = nn.Momentum(network.trainable_params(), learning_rate, momentum)
    model = Model(network, criterion, optimizer)
    # Init SummaryCollector callback to record summary data in model.train or model.eval
    collector = SummaryCollector(summary_dir=summary_dir, collect_freq=10)
    train_set = create_dataset(os.path.join(data_path, "train"), batch_size=batch_size)
    print("============== Starting Training ==============")
    model.train(epoch_size, train_set, callbacks=[collector], dataset_sink_mode=False)
    print("============== Train End =====================")
def eval_lenet():
    """Load checkpointed parameters into LeNet-5 and print its accuracy on
    the test split.

    NOTE(review): reads module-level ``config`` and ``ckpt_path`` — confirm
    both are defined at module scope.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    network = LeNet5(config.num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = nn.Momentum(network.trainable_params(), config.lr, config.momentum)
    model = Model(network, criterion, optimizer, metrics={"Accuracy": Accuracy()})
    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    eval_set = create_dataset(os.path.join(config.data_path, "test"),
                              config.batch_size, 1)
    if eval_set.get_dataset_size() == 0:
        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
    acc = model.eval(eval_set)
    print("============== {} ==============".format(acc))
def _model_train_and_save_ckpt(self, net, dataset, epoch):
    """Train ``net`` with Adam (optionally CPU-targeted, optionally with
    sparse mode enabled), keep only the newest checkpoint for this rank, and
    return its loaded parameter dict."""
    self.opt = Adam(params=net.get_parameters())
    if self.target == 'CPU':
        self.opt.target = self.target
    if self.sparse:
        context.set_context(enable_sparse=True)
    self.model = Model(network=net, loss_fn=self.loss_fn, optimizer=self.opt)
    ckpt_dir = './rank_{}_ckpt'.format(self.global_rank_id)
    saver = ModelCheckpoint(prefix='parallel', directory=ckpt_dir,
                            config=CheckpointConfig(keep_checkpoint_max=1))
    # Start from an empty directory so the newest-file lookup is unambiguous.
    clean_all_ckpt_files(ckpt_dir)
    self.model.train(epoch=epoch, train_dataset=dataset, callbacks=[saver],
                     dataset_sink_mode=False)
    return load_checkpoint(find_newest_ckpt_file(ckpt_dir))
def quant_resnet50(network, dataset, loss, input_data):
    """Post-training quantization of ResNet-50: write a quantization config,
    build the calibration network, run one eval pass to calibrate
    activations, and export the quantized AIR file."""
    # Step 2: create the quantization config JSON file.
    create_quant_config('./config.json', network, input_data)
    # Step 3: apply the quantization modifications and get the modified net.
    calib_net = quantize_model('./config.json', network, input_data)
    calib_net.set_train(False)
    # Step 4: evaluate once so activation calibration can run.
    evaluator = Model(calib_net, loss_fn=loss,
                      metrics={'top_1_accuracy', 'top_5_accuracy'})
    _ = evaluator.eval(dataset, dataset_sink_mode=False)
    # Step 5: export the AIR file.
    save_model('results/resnet50_quant', calib_net, input_data)
    print("[INFO] the quantized AIR file has been stored at: \n {}".format(
        'results/resnet50_quant.air'))
def test_loss_scale_fp16_model_train_overflow():
    """Train an FP16 net for two epochs under a dynamic loss-scale manager
    configured with a small initial scale and a window of 2."""
    types = (np.float32, np.float32)
    shapes = ((16, 16), (16, 16))
    data = MindDataSet(types, shapes)
    network = NetFP16(16, 16)
    network.set_train()
    criterion = MSELoss()
    optimizer = Momentum(network.trainable_params(), learning_rate=0.1,
                         momentum=0.9)
    manager = DynamicLossScaleManager(init_loss_scale=16, scale_factor=2,
                                      scale_window=2)
    trainer = Model(network, loss_fn=criterion, optimizer=optimizer,
                    metrics=None, loss_scale_manager=manager)
    trainer.train(2, data, dataset_sink_mode=False)
def transpose_common(strategy1, strategy2):
    """Compile/train the transpose network for two epochs in semi-auto
    parallel mode (8 devices, parameter broadcast disabled)."""
    batch_size = 32
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      device_num=8, parameter_broadcast=False)
    predict = Tensor(np.ones([32, 128]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = transpose_net(strategy1, strategy2)
    # NOTE(review): ``is_grad`` only exists in older MindSpore releases —
    # confirm the version this file targets.
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    # ``shard`` is the current name for the old ``set_strategy`` method; use
    # it for consistency with the other strategy-setting calls in this file.
    loss.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net, loss, opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
def train_lenet():
    """Train the non-quantized LeNet-5 on MNIST, monitoring time/loss and
    saving checkpoints per the config."""
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = nonquant_cfg
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size)
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    print("============== Starting Training Lenet==============")
    # Attribute access for consistency with the other ``cfg`` reads above
    # (was ``cfg['epoch_size']``; same object, same value).
    model.train(cfg.epoch_size, ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=True)
def test_flatten_reshape3(parallel_mode="auto_parallel"):
    """Train ParallelReshapeNet for two epochs under auto parallel on 8
    devices without forcing the algorithm to use every device."""
    batch = 16
    lr, mom, epochs = 0.1, 0.9, 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    network = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000,
                                 shape=(128, 1000), strategy=((16, 1),))
    criterion = CrossEntropyLoss()
    features = Tensor(np.ones([batch, 1, 2, 1024]), dtype=ms.float32)
    targets = Tensor(np.ones([batch, 1000]), dtype=ms.float32)
    data = Dataset(features, targets, 2, input_num=2)
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, loss_fn=criterion, optimizer=optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
def test_compile_model_train_O2():
    """Train at amp_level O2 for two epochs, then check that the eval path
    compiles by expecting its metrics step to raise ValueError."""
    types = (np.float32, np.float32)
    shapes = ((16, 16), (16, 16))
    data = MindDataSet(types, shapes)
    network = NetNoLoss(16, 16)
    criterion = nn.MSELoss()
    optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1,
                            momentum=0.9)
    model = Model(network, loss_fn=criterion, optimizer=optimizer,
                  metrics={"acc"}, amp_level="O2")
    model.train(2, data, dataset_sink_mode=False)
    with pytest.raises(ValueError):
        # not actual run, the metrics step will fail, check if compile ok.
        model.eval(data)
def test_fuzzing_ascend():
    """Smoke-test the fuzzing pipeline on Ascend with random data: measure
    KMNC neuron coverage from random training images, then fuzz random seed
    samples with image mutations plus an FGSM attack."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # load network
    model = Model(Net())
    batch_size = 8
    num_classe = 10
    mutate_config = [{'method': 'Blur', 'params': {'auto_param': [True]}},
                     {'method': 'Contrast', 'params': {'factor': [2, 1]}},
                     {'method': 'Translate', 'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}},
                     {'method': 'FGSM', 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
                     ]
    # initialize fuzz test with training dataset
    train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
    coverage = ModelCoverageMetrics(model, 10, 1000, train_images)
    # fuzz test with original test data: build random one-hot labelled seeds
    test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
    test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
    test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
    initial_seeds = [[img, label] for img, label in zip(test_images, test_labels)]
    initial_seeds = initial_seeds[:100]
    coverage.calculate_coverage(
        np.array(test_images[:100]).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of this test is : %s', coverage.get_kmnc())
    fuzzer = Fuzzer(model, train_images, 10, 1000)
    _, _, _, _, metrics = fuzzer.fuzzing(mutate_config, initial_seeds)
    print(metrics)
def test_callbacks_non_sink_batch_size2():
    """Two non-sink epochs over 4 rows batched by 2 must interleave dataset
    and MindSpore callback events in the expected order."""
    logger.info("test_callbacks_non_sink_batch_size2")
    events = []
    waited_cb = MyWaitedCallback(events, 2)
    ms_cb = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"],
                                 shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=waited_cb)
    data = data.batch(2)
    model = Model(Net())
    model.train(2, data, dataset_sink_mode=False, callbacks=[ms_cb, waited_cb])
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
                              'ms_step_end_1_2', 'ms_epoch_end_1_2',
                              'ds_epoch_begin_2_4', 'ds_step_begin_2_5',
                              'ms_step_end_2_3', 'ds_step_begin_2_7',
                              'ms_step_end_2_4', 'ms_epoch_end_2_4']
    assert events[:10] == expected_synced_events
def all_to_all_common(strategy1):
    """Train the all-to-all net under SEMI_AUTO_PARALLEL on 8 devices and
    return the shard strategies of the compiled train network."""
    lr, mom, epochs = 0.1, 0.9, 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    features = Tensor(np.ones([32, 128]), dtype=ms.float32)
    targets = Tensor(np.ones([32]), dtype=ms.int32)
    data = Dataset(features, targets, 2)
    network = all_to_all_net(strategy1)
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True)
    criterion.softmax_cross_entropy.shard(((8, 1), (8, 1)))
    criterion.one_hot.shard(((8, 1), (), ()))
    optimizer = Momentum(network.trainable_params(), lr, mom)
    model = Model(network, criterion, optimizer)
    model.train(epochs, data, dataset_sink_mode=False)
    return _executor._get_shard_strategy(model._train_network)