def test_get_optimizer_failed(self):
    """Verify that a failed optimizer lookup returns None and is recorded as 'Failed'."""

    class Net(Cell):
        """Pass-through network whose only attribute is a TensorAdd cell."""

        def __init__(self):
            super(Net, self).__init__()
            self.add = TensorAdd()

        def construct(self, data):
            return data

    params = _InternalCallbackParam()
    params.optimizer = None
    params.train_network = Net()
    params.mode = ModeEnum.TRAIN.value
    collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))

    # The lookup is performed twice: the second pass checks that asking
    # again gives the same outcome.
    for _ in range(2):
        found = collector._get_optimizer(params)
        assert found is None
        assert collector._temp_optimizer == 'Failed'
def test_collect_input_data_with_train_dataset_element_none(self):
    """Verify 'collect_input_data' is falsy after collecting with a None dataset element."""
    params = _InternalCallbackParam()
    params.train_dataset_element = None

    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    collector = SummaryCollector(log_dir)
    collector._collect_input_data(params)

    input_data_enabled = collector._collect_specified_data['collect_input_data']
    assert not input_data_enabled
def test_collect_input_data_with_train_dataset_element_invalid(self):
    """Verify 'collect_input_data' is falsy for each invalid 'train_dataset_element'."""
    params = _InternalCallbackParam()
    invalid_elements = ((), [], None)

    for element in invalid_elements:
        params.train_dataset_element = element
        # A fresh collector (and directory) is used for every invalid value.
        log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        collector = SummaryCollector(log_dir)
        collector._collect_input_data(params)
        input_data_enabled = collector._collect_specified_data['collect_input_data']
        assert not input_data_enabled
def test_check_callback_with_multi_instances(self):
    """Use multi SummaryCollector instances to test check_callback function.

    The callback list holds two SummaryCollector instances, so
    `_check_callbacks` is expected to raise a ValueError whose message
    mentions that more than one instance is present.
    """
    cb_params = _InternalCallbackParam()
    cb_params.list_callback = [
        SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)),
        SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
    ]
    with pytest.raises(ValueError) as exc:
        SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))._check_callbacks(cb_params)
    # Plain literal: the text has no placeholders, so no f-string is needed.
    assert "more than one SummaryCollector instance in callback list" in str(exc.value)
def test_get_optimizer_from_cb_params_success(self):
    """Verify the optimizer stored in cb_params is returned, on the first and on a repeated call."""
    params = _InternalCallbackParam()
    weight = Parameter(Tensor(1), 'weight')
    params.optimizer = Optimizer(learning_rate=0.1, parameters=[weight])

    collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))

    first_lookup = collector._get_optimizer(params)
    assert first_lookup == params.optimizer

    # Ask again to check the repeated lookup gives the same optimizer.
    second_lookup = collector._get_optimizer(params)
    assert second_lookup == params.optimizer
def test_params_with_summary_dir_value_error(self, summary_dir):
    """Verify the errors raised for an unusable `summary_dir` argument.

    Non-string values are expected to raise TypeError; string values are
    expected to raise ValueError with the empty-string message.
    """
    if not isinstance(summary_dir, str):
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir=summary_dir)
        assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
        return

    with pytest.raises(ValueError) as exc:
        SummaryCollector(summary_dir=summary_dir)
    expected_msg = ('For `summary_dir` the value should be a valid string of path, '
                    'but got empty string.')
    assert str(exc.value) == expected_msg
def test_get_loss(self, net_output, expected_loss):
    """Verify `_get_loss` for parseable and unparseable network outputs."""
    params = _InternalCallbackParam()
    params.net_outputs = net_output

    collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))

    # The success flag is expected to be truthy before any loss is parsed.
    assert collector._is_parse_loss_success

    loss = collector._get_loss(params)
    assert loss == expected_loss

    if expected_loss is None:
        # A None result is expected to go along with a cleared success flag.
        assert not collector._is_parse_loss_success
def test_get_optimizer_from_network(self, mode):
    """Verify an Optimizer is found on the network when cb_params carries none."""
    params = _InternalCallbackParam()
    params.optimizer = None
    params.mode = mode

    # The network is attached under the attribute that matches the mode.
    is_train = mode == ModeEnum.TRAIN.value
    network_attr = 'train_network' if is_train else 'eval_network'
    setattr(params, network_attr, CustomNet())

    collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
    found = collector._get_optimizer(params)
    assert isinstance(found, Optimizer)
def test_params_with_collect_freq_exception(self, collect_freq):
    """Verify the error type and exact message for an invalid `collect_freq`.

    Integer values are expected to raise ValueError, anything else
    TypeError. The expected text is compared for equality, so it has to
    match the raised message character for character.
    """
    summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    if isinstance(collect_freq, int):
        error_type = ValueError
        expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
    else:
        error_type = TypeError
        expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
                       f'bug got {type(collect_freq).__name__}.'

    with pytest.raises(error_type) as exc:
        SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
    assert expected_msg == str(exc.value)
def _run_network(self, dataset_sink_mode=False, num_samples=2, **kwargs):
    """Train LeNet5 for one epoch, evaluate it, and return the summary directory.

    Args:
        dataset_sink_mode: Forwarded to both `model.train` and `model.eval`.
        num_samples: Forwarded to `create_dataset` for the training set.
        **kwargs: Extra keyword arguments forwarded to `SummaryCollector`.

    Returns:
        The temporary directory handed to the SummaryCollector.
    """
    network = LeNet5()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optimizer = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={'loss': Loss()})

    summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)

    train_path = os.path.join(self.mnist_path, "train")
    train_set = create_dataset(train_path, num_samples=num_samples)
    model.train(1, train_set, callbacks=[collector], dataset_sink_mode=dataset_sink_mode)

    eval_path = os.path.join(self.mnist_path, "test")
    eval_set = create_dataset(eval_path)
    model.eval(eval_set, dataset_sink_mode=dataset_sink_mode, callbacks=[collector])

    return summary_dir
def test_collect_histogram_from_regular(self, mock_add_value, histogram_regular, expected_names, expected_values):
    """Verify the histogram names and values collected for a given `histogram_regular`."""
    mock_add_value.side_effect = add_value
    params = _InternalCallbackParam()

    # (tensor value, parameter name) pairs handed to the optimizer.
    named_values = (
        (1, 'conv1.weight1'),
        (2, 'conv2.weight2'),
        (3, 'conv1.bias1'),
        (4, 'conv3.bias'),
        (5, 'conv5.bias'),
        (6, 'conv6.bias'),
    )
    parameters = [Parameter(Tensor(value), name) for value, name in named_values]
    params.optimizer = Optimizer(learning_rate=0.1, parameters=parameters)

    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    with SummaryCollector(log_dir) as collector:
        collector._collect_specified_data['histogram_regular'] = histogram_regular
        collector._collect_histogram(params)

    recorded = get_value()
    assert PluginEnum.HISTOGRAM.value == recorded[0][0]

    recorded_names = [entry[1] for entry in recorded]
    recorded_values = [entry[2] for entry in recorded]
    assert expected_names == recorded_names
    assert expected_values == recorded_values
def main(data_path, device_target='Ascend', summary_dir='./summary_dir', learning_rate=0.01):
    """Train LeNet5 for one epoch in graph mode while recording summary data.

    Args:
        data_path: Directory that contains the "train" sub-directory.
        device_target: Device target passed to `context.set_context`.
        summary_dir: Directory passed to the SummaryCollector.
        learning_rate: Learning rate passed to the Momentum optimizer.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    # Fixed training hyper-parameters.
    batch_size = 32
    epoch_size = 1
    momentum = 0.9

    network = LeNet5()
    network.set_train()
    loss_fn = CrossEntropyLoss()
    optimizer = nn.Momentum(network.trainable_params(), learning_rate, momentum)
    model = Model(network, loss_fn, optimizer)

    # The SummaryCollector callback records summary data during model.train or model.eval.
    collector = SummaryCollector(summary_dir=summary_dir, collect_freq=10)

    train_set = create_dataset(os.path.join(data_path, "train"), batch_size=batch_size)

    print("============== Starting Training ==============")
    model.train(epoch_size, train_set, callbacks=[collector], dataset_sink_mode=False)
    print("============== Train End =====================")
def _run_network(self, dataset_sink_mode=True):
    """Train LeNet5 for one epoch, evaluate it, then check the written summary.

    Args:
        dataset_sink_mode: Forwarded to both `model.train` and `model.eval`.
    """
    network = LeNet5()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    optimizer = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics={'acc': Accuracy()})

    summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    collector = SummaryCollector(summary_dir=summary_dir, collect_freq=1)

    train_set = create_dataset(os.path.join(self.mnist_path, "train"))
    model.train(1, train_set, callbacks=[collector], dataset_sink_mode=dataset_sink_mode)

    eval_set = create_dataset(os.path.join(self.mnist_path, "test"))
    model.eval(eval_set, dataset_sink_mode=dataset_sink_mode, callbacks=[collector])

    self._check_summary_result(summary_dir)
def test_params_with_action_exception(self, action):
    """Verify a TypeError with the exact message for an invalid `keep_default_action`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    # Compared for equality below, so the text must match the raised message exactly.
    expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
                   f"bug got {type(action).__name__}."

    with pytest.raises(TypeError) as exc:
        SummaryCollector(summary_dir=log_dir, keep_default_action=action)

    actual_msg = str(exc.value)
    assert expected_msg == actual_msg
def test_params_with_collect_specified_data_unexpected_key(self):
    """Verify a ValueError naming the unsupported key of `collect_specified_data`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    specified = {'unexpected_key': True}
    expected_msg = f"For `collect_specified_data` the keys {set(specified)} are unsupported"

    with pytest.raises(ValueError) as exc:
        SummaryCollector(log_dir, collect_specified_data=specified)

    assert expected_msg in str(exc.value)
def test_collect_input_data_success(self, mock_add_value):
    """Collect input data from a mocked single-pixel image tensor without error."""
    mock_add_value.side_effect = add_value
    params = _InternalCallbackParam()

    pixels = np.random.randint(0, 255, size=(1, 1, 1, 1)).astype(np.uint8)
    params.train_dataset_element = Tensor(pixels)

    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    with SummaryCollector(log_dir) as collector:
        collector._collect_input_data(params)
def test_params_with_export_options_unexpected_key(self):
    """Verify a ValueError naming the unsupported key of `export_options`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    options = {'unexpected_key': "value"}
    expected_msg = f"For `export_options` the keys {set(options)} are unsupported"

    with pytest.raises(ValueError) as exc:
        SummaryCollector(log_dir, export_options=options)

    assert expected_msg in str(exc.value)
def test_process_specified_data(self, specified_data, action, expected_result):
    """Verify the stored `_collect_specified_data` for the given input and default-action flag."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    collector = SummaryCollector(
        log_dir,
        collect_specified_data=specified_data,
        keep_default_action=action,
    )

    processed = collector._collect_specified_data
    assert processed == expected_result
def test_params_with_summary_dir_not_dir(self):
    """Test the given summary dir parameter is not a directory.

    A regular file is created inside a temporary directory and its path is
    passed as `summary_dir`; SummaryCollector is expected to raise
    NotADirectoryError.
    """
    summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    summary_file = os.path.join(summary_dir, 'temp_file.txt')
    with open(summary_file, 'w') as file_handle:
        file_handle.write('temp')

    # Check the precondition explicitly instead of printing it: the path
    # handed to SummaryCollector must be an existing regular file.
    assert os.path.isfile(summary_file)

    with pytest.raises(NotADirectoryError):
        SummaryCollector(summary_dir=summary_file)
def test_params_with_histogram_regular_value_error(self):
    """Verify a ValueError is raised when `histogram_regular` is set to '*'."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    specified = {'histogram_regular': '*'}

    with pytest.raises(ValueError) as exc:
        SummaryCollector(log_dir, collect_specified_data=specified)

    message = str(exc.value)
    assert 'For `collect_specified_data`, the value of `histogram_regular`' in message
def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
    """Verify a TypeError with the exact message for a non-string key in `collect_specified_data`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    with pytest.raises(TypeError) as exc:
        SummaryCollector(log_dir, collect_specified_data=collect_specified_data)

    # Only the first key of the mapping is used to build the expected message.
    first_key = list(collect_specified_data)[0]
    key_type = type(first_key).__name__
    expected_msg = f"For `{first_key}` the type should be a valid type of ['str'], " \
                   f"bug got {key_type}."
    assert expected_msg == str(exc.value)
def test_params_with_export_options_type_error(self, export_options):
    """Verify a TypeError with the exact message for an `export_options` of the wrong type."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    with pytest.raises(TypeError) as exc:
        SummaryCollector(log_dir, export_options=export_options)

    given_type = type(export_options).__name__
    expected_msg = f"For `export_options` the type should be a valid type of ['dict', 'NoneType'], " \
                   f"but got {given_type}."
    assert expected_msg == str(exc.value)
def test_collect_dataset_graph_success(self, mock_add_value):
    """Verify the train dataset graph is recorded under the 'dataset_graph' plugin."""
    dataset = import_module('mindspore.dataset')
    mock_add_value.side_effect = add_value

    params = _InternalCallbackParam()
    dataset_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    params.train_dataset = dataset.MnistDataset(dataset_dir=dataset_dir)
    params.mode = ModeEnum.TRAIN.value

    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
    with SummaryCollector(log_dir) as collector:
        collector._collect_dataset_graph(params)
        first_record = get_value()[0]
        plugin, name, _ = first_record

    assert plugin == 'dataset_graph'
    assert name == 'train_dataset'
def test_params_with_tensor_format_type_error(self, export_options):
    """Verify a ValueError with the exact message for an unsupported `tensor_format` in `export_options`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    with pytest.raises(ValueError) as exc:
        SummaryCollector(log_dir, export_options=export_options)

    # The message shows the rejected format as a one-element set and the
    # supported formats as a list.
    rejected = {export_options.get("tensor_format")}
    supported = list(_DEFAULT_EXPORT_OPTIONS.get("tensor_format"))
    expected_msg = f'For `export_options`, the export_format {rejected} are ' \
                   f'unsupported for tensor_format, expect the follow values: ' \
                   f'{supported}'
    assert expected_msg == str(exc.value)
def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
    """Verify a TypeError with the exact message for a wrongly typed value in `collect_specified_data`."""
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    with pytest.raises(TypeError) as exc:
        SummaryCollector(log_dir, collect_specified_data=collect_specified_data)

    # Only the first entry of the mapping is used to build the expected message.
    first_key = list(collect_specified_data)[0]
    first_value = collect_specified_data[first_key]

    # 'histogram_regular' accepts a string or None; every other key expects a bool.
    if first_key == 'histogram_regular':
        expected_type = "['str', 'NoneType']"
    else:
        expected_type = "['bool']"

    expected_msg = f'For `{first_key}` the type should be a valid type of {expected_type}, ' \
                   f'bug got {type(first_value).__name__}.'
    assert expected_msg == str(exc.value)
def test_graph_summary_callback():
    """Train for one epoch in graph mode with a SummaryCollector that only collects the graph."""
    train_set = get_dataset()
    network = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(network.trainable_params(), 0.1, 0.9)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(network, loss_fn=loss_fn, optimizer=optimizer, metrics=None)

    # Default collection is switched off, so only the graph is requested.
    collector = SummaryCollector(
        SUMMARY_DIR,
        collect_freq=1,
        keep_default_action=False,
        collect_specified_data={'collect_graph': True},
    )
    model.train(1, train_set, callbacks=[collector])
def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
    """Verify the TypeError message for each kind of invalid `custom_lineage_data`.

    Three cases are distinguished: the argument is not a dict, its first
    key is not a string, or its first value has an unsupported type.
    """
    log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)

    with pytest.raises(TypeError) as exc:
        SummaryCollector(log_dir, custom_lineage_data=custom_lineage_data)

    if isinstance(custom_lineage_data, dict):
        # Only the first entry of the mapping is used to build the expected message.
        first_key = list(custom_lineage_data)[0]
        first_value = custom_lineage_data[first_key]
        if isinstance(first_key, str):
            arg_name = f'the value of custom_lineage_data -> {first_key}'
            expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
                           f'bug got {type(first_value).__name__}.'
        else:
            arg_name = f'custom_lineage_data -> {first_key}'
            expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
                           f'bug got {type(first_key).__name__}.'
    else:
        expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
                       f"bug got {type(custom_lineage_data).__name__}."

    assert expected_msg == str(exc.value)
0.01, 0.9) model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'}) # as for train, users could use model.train if args_opt.do_train: dataset = create_dataset() config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=1000) ckpoint_cb = ModelCheckpoint(prefix="", directory=os.path.join( summary_dir, "weights"), config=config_ck) # data_saver_callback = DataSaverCallback(summary_dir=summary_dir, save_interval=1) data_saver_callback = DataSaverCallbackGPU(summary_dir=summary_dir) summary_cb = SummaryCollector(summary_dir=summary_dir, collect_freq=1000) model.train(epoch_size, dataset, callbacks=[ LossMonitor(), data_saver_callback, summary_cb, ckpoint_cb ], dataset_sink_mode=False) # model.train(epoch_size, dataset, callbacks=[LossMonitor(), summary_cb], dataset_sink_mode=False) # as for evaluation, users could use model.eval if args_opt.do_eval: if args_opt.checkpoint_path: param_dict = load_checkpoint(args_opt.checkpoint_path) load_param_into_net(net, param_dict) eval_dataset = create_dataset(training=False)
def _build_training_pipeline(config: GNMTConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Builds the loss network, optimizer, dynamic loss scaling and callbacks,
    then hands everything to `_train`.

    Args:
        config (GNMTConfig): Config of mass model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.

    Raises:
        ValueError: If neither a pre-training nor a fine-tune dataset is given.
    """
    net_with_loss = GNMTNetworkWithLoss(config, is_training=True, use_one_hot_embeddings=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    # The pre-training dataset takes precedence over the fine-tune dataset.
    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError("pre-training dataset or fine-tuning dataset must be provided one.")

    # Queried once and reused for both the step count and the time monitor.
    dataset_size = dataset.get_dataset_size()
    update_steps = config.epochs * dataset_size

    lr = _get_lr(config, update_steps)
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                            scale_factor=config.loss_scale_factor,
                                            scale_window=config.scale_window)
    net_with_grads = GNMTTrainOneStepWithLossScaleCell(
        network=net_with_loss, optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell()
    )
    net_with_grads.set_train(True)
    model = Model(net_with_grads)

    loss_monitor = LossCallBack(config)
    time_cb = TimeMonitor(data_size=dataset_size)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                                   keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = [time_cb, loss_monitor]

    device_num = None if rank_size is None else int(rank_size)
    # Single-device run: RANK_SIZE is unset or equal to 1.
    single_device = device_num is None or device_num == 1
    # Multi-device run: only ranks divisible by 8 save checkpoints and
    # summaries. The rank is queried only when more than one device is
    # configured (short-circuit evaluation).
    saving_rank = (device_num is not None and device_num > 1
                   and MultiAscend.get_rank() % 8 == 0)

    if single_device or saving_rank:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path, 'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config
        )
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary", collect_freq=50)
        callbacks.append(summary_callback)

    print(" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model, config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
def main():
    """Train the DFCNN network with CTC loss as described by the module-level `config`.

    Reads the training settings (optionally from a previous run's config when
    `config['resume']` is set), builds the dataset, network, optimizer and
    callbacks, and then runs `model.train`.

    Raises:
        ValueError: If the pad mode is neither 'pad' nor 'same', or the
            optimizer type is not one of 'adam', 'rms' or 'sgd'.
    """
    set_seed(1)
    date = time.strftime("%Y%m%d%H%M%S", time.localtime())
    print(f'* Preparing to train model {date}')

    # ************** configuration ****************
    # - training setting
    resume = config['resume']
    if config['mode'] == 'PYNATIVE':
        mode = context.PYNATIVE_MODE
    else:
        mode = context.GRAPH_MODE

    device = config['device']
    device_id = config['device_id']
    dataset_sink_mode = config['dataset_sink_mode']

    # use in dataset
    div = 8

    # When resuming, the model structure and hyper-parameters come from the
    # saved config of the resumed run; otherwise from the current config.
    if resume:
        print('* Resuming model...')
        resume_config_log = config['resume_config_log']
        resume_config = get_eval_config(resume_config_log)
        if 'best_ckpt' in resume_config:
            resume_model_path = resume_config['best_ckpt']
        else:
            resume_model_path = resume_config['latest_model']
            print('* [WARNING] Not using the best model, but latest saved model instead.')
        settings = resume_config
    else:
        settings = config

    # setting bias and padding
    has_bias = settings['has_bias']
    use_dropout = settings['use_dropout']
    pad_mode = settings['pad_mode']
    if pad_mode == 'pad':
        padding = settings['padding']
    elif pad_mode == 'same':
        padding = 0
    else:
        raise ValueError(f"invalid pad mode: {pad_mode}!")

    if resume:
        best_acc = resume_config['best_acc']
        # Equal to resume_config['best_ckpt'] when that key exists; falls back
        # to the latest saved model instead of raising KeyError when it does not.
        best_ckpt = resume_model_path
        print('* The best accuracy in dev dataset for the current resumed model is {:.2f}%'.format(best_acc * 100))

    # hyper-parameters
    batch_size = settings['batch_size']
    opt_type = settings['opt']
    use_dynamic_lr = settings['use_dynamic_lr']
    warmup_step = settings['warmup_step']
    warmup_ratio = settings['warmup_ratio']

    test_dev_batch_size = config['test_dev_batch_size']
    learning_rate = float(config['learning_rate'])
    epochs = config['epochs']
    loss_scale = config['loss_scale']

    # configuration of saving model checkpoint
    save_checkpoint_steps = config['save_checkpoint_steps']
    keep_checkpoint_max = config['keep_checkpoint_max']
    prefix = config['prefix'] + '_' + date
    model_dir = config['model_dir']

    # loss monitor
    loss_monitor_step = config['loss_monitor_step']

    # whether to use mindInsight summary
    use_summary = config['use_summary']

    # step_eval
    use_step_eval = config['use_step_eval']
    eval_step = config['eval_step']
    eval_epoch = config['eval_epoch']
    patience = config['patience']

    # eval in steps or epochs: an eval_step of -1 turns step evaluation off
    step_eval = eval_step != -1

    # ************** end of configuration **************
    if device == 'GPU':
        context.set_context(mode=mode, device_target=device, device_id=device_id)
    elif device == 'Ascend':
        import moxing as mox
        from utils.const import DATA_PATH, MODEL_PATH, BEST_MODEL_PATH, LOG_PATH
        obs_datapath = config['obs_datapath']
        obs_saved_model = config['obs_saved_model']
        obs_best_model = config['obs_best_model']
        obs_log = config['obs_log']
        mox.file.copy_parallel(obs_datapath, DATA_PATH)
        mox.file.copy_parallel(MODEL_PATH, obs_saved_model)
        mox.file.copy_parallel(BEST_MODEL_PATH, obs_best_model)
        mox.file.copy_parallel(LOG_PATH, obs_log)
        context.set_context(mode=mode, device_target=device)
        # Summary collection is always switched off on Ascend.
        use_summary = False

    # callbacks function
    callbacks = []

    # data
    train_loader, idx2label, label2idx = get_dataset(batch_size=batch_size, phase='train',
                                                     test_dev_batch_size=test_dev_batch_size, div=div,
                                                     num_parallel_workers=4)

    # An eval_step of 0 means "evaluate once per pass over the dataset".
    if eval_step == 0:
        eval_step = train_loader.get_dataset_size()

    # network
    net = DFCNN(num_classes=len(label2idx), padding=padding, pad_mode=pad_mode,
                has_bias=has_bias, use_dropout=use_dropout)

    # Criterion
    criterion = CTCLoss()

    # resume
    if resume:
        print("* Loading parameters...")
        param_dict = load_checkpoint(resume_model_path)
        # load the parameter into net
        load_param_into_net(net, param_dict)
        print(f'* Parameters loading from {resume_model_path} succeeded!')

    net.set_train(True)
    net.set_grad(True)

    # lr schedule
    if use_dynamic_lr:
        dataset_size = train_loader.get_dataset_size()
        learning_rate = Tensor(dynamic_lr(base_lr=learning_rate, warmup_step=warmup_step,
                                          warmup_ratio=warmup_ratio, epochs=epochs,
                                          steps_per_epoch=dataset_size), mstype.float32)
        print('* Using dynamic learning rate, which will be set up as :', learning_rate.asnumpy())

    # optim
    if opt_type == 'adam':
        opt = nn.Adam(net.trainable_params(), learning_rate=learning_rate, beta1=0.9, beta2=0.999,
                      weight_decay=0.0, eps=10e-8)
    elif opt_type == 'rms':
        opt = nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=learning_rate,
                         momentum=0.9, loss_scale=loss_scale)
    elif opt_type == 'sgd':
        opt = nn.SGD(params=net.trainable_params(), learning_rate=learning_rate)
    else:
        raise ValueError(f"optimizer: {opt_type} is not supported for now!")

    if resume:
        # load the parameter into optimizer
        load_param_into_net(opt, param_dict)

    # save_model
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix=prefix, directory=model_dir, config=config_ck)

    # logger
    the_logger = logger(config, date)
    log = Logging(logger=the_logger, model_ckpt=ckpt_cb)

    callbacks.append(ckpt_cb)
    callbacks.append(log)

    net = WithLossCell(net, criterion)
    scaling_sens = Tensor(np.full((1), loss_scale), dtype=mstype.float32)

    net = DFCNNCTCTrainOneStepWithLossScaleCell(net, opt, scaling_sens)
    net.set_train(True)
    model = Model(net)

    if use_step_eval:
        # step evaluation (kept under its own name so the boolean
        # `step_eval` flag is not overwritten by the callback object)
        step_eval_cb = StepAccInfo(model=model, name=prefix, div=div,
                                   test_dev_batch_size=test_dev_batch_size,
                                   step_eval=step_eval, eval_step=eval_step,
                                   eval_epoch=eval_epoch, logger=the_logger,
                                   patience=patience,
                                   dataset_size=train_loader.get_dataset_size())
        callbacks.append(step_eval_cb)

    # loss monitor
    loss_monitor = LossMonitor(loss_monitor_step)
    callbacks.append(loss_monitor)

    if use_summary:
        summary_dir = os.path.join(SUMMARY_DIR, date)
        # Creates missing parent directories and tolerates an existing
        # directory, without a separate existence check.
        os.makedirs(summary_dir, exist_ok=True)
        # mindInsight
        summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=1,
                                             max_file_size=4 * 1024 ** 3)
        callbacks.append(summary_collector)

    if resume:
        the_logger.update_acc_ckpt(best_acc, best_ckpt)

    print('* Start training...')
    model.train(epochs, train_loader, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)