Example #1
    def test_get_optimizer_failed(self):
        """Test get optimizer failed."""
        class Net(Cell):
            """Define net."""
            def __init__(self):
                super(Net, self).__init__()
                self.add = TensorAdd()

            def construct(self, data):
                return data

        cb_params = _InternalCallbackParam()
        cb_params.optimizer = None
        cb_params.train_network = Net()
        cb_params.mode = ModeEnum.TRAIN.value
        summary_collector = SummaryCollector(
            tempfile.mkdtemp(dir=self.base_summary_dir))
        optimizer = summary_collector._get_optimizer(cb_params)
        assert optimizer is None
        assert summary_collector._temp_optimizer == 'Failed'

        # Test get optimizer again
        optimizer = summary_collector._get_optimizer(cb_params)
        assert optimizer is None
        assert summary_collector._temp_optimizer == 'Failed'
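
The two assertions above exercise a lookup-and-cache pattern: once the optimizer lookup fails, the 'Failed' sentinel is stored and short-circuits every later call. A minimal sketch of that pattern (hypothetical names, not the actual MindSpore source):

class _OptimizerLookupSketch:
    """Hypothetical illustration of the caching behavior tested above."""
    def __init__(self):
        self._temp_optimizer = None

    def get_optimizer(self, optimizer):
        if self._temp_optimizer == 'Failed':
            return None  # repeated lookups return the cached failure
        if optimizer is None:
            self._temp_optimizer = 'Failed'  # remember that the lookup failed
            return None
        self._temp_optimizer = optimizer
        return optimizer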
Example #2
 def test_collect_input_data_with_train_dataset_element_none(self):
     """Test the param 'train_dataset_element' in cb_params is none."""
     cb_params = _InternalCallbackParam()
     cb_params.train_dataset_element = None
     summary_collector = SummaryCollector(
         tempfile.mkdtemp(dir=self.base_summary_dir))
     summary_collector._collect_input_data(cb_params)
     assert not summary_collector._collect_specified_data[
         'collect_input_data']
Example #3
 def test_collect_input_data_with_train_dataset_element_invalid(self):
     """Test the param 'train_dataset_element' in cb_params is invalid."""
     cb_params = _InternalCallbackParam()
     for invalid in ((), [], None):
         cb_params.train_dataset_element = invalid
         summary_collector = SummaryCollector(
             tempfile.mkdtemp(dir=self.base_summary_dir))
         summary_collector._collect_input_data(cb_params)
         assert not summary_collector._collect_specified_data[
             'collect_input_data']
Example #4
 def test_check_callback_with_multi_instances(self):
     """Use multi SummaryCollector instances to test check_callback function."""
     cb_params = _InternalCallbackParam()
     cb_params.list_callback = [
         SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)),
         SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
     ]
     with pytest.raises(ValueError) as exc:
         SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))._check_callbacks(cb_params)
     assert "more than one SummaryCollector instance in callback list" in str(exc.value)
Example #5
    def test_get_optimizer_from_cb_params_success(self):
        """Test get optimizer success from cb params."""
        cb_params = _InternalCallbackParam()
        cb_params.optimizer = Optimizer(learning_rate=0.1, parameters=[Parameter(Tensor(1), 'weight')])
        summary_collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
        optimizer = summary_collector._get_optimizer(cb_params)
        assert optimizer == cb_params.optimizer

        # Test get optimizer again
        assert summary_collector._get_optimizer(cb_params) == cb_params.optimizer
Example #6
 def test_params_with_summary_dir_value_error(self, summary_dir):
     """Test the exception scenario for summary dir."""
     if isinstance(summary_dir, str):
         with pytest.raises(ValueError) as exc:
             SummaryCollector(summary_dir=summary_dir)
         assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
                                  'but got empty string.'
     else:
         with pytest.raises(TypeError) as exc:
             SummaryCollector(summary_dir=summary_dir)
         assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
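
Note: tests in this collection that take extra arguments (summary_dir, collect_freq, mode, net_output, and so on) are driven by @pytest.mark.parametrize decorators that these excerpts omit. A plausible decorator for the test above, with purely illustrative values:

 # Illustrative parameter set only; the real values are not shown in these excerpts.
 @pytest.mark.parametrize("summary_dir", ['', 123, True])
 def test_params_with_summary_dir_value_error(self, summary_dir):
     ...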
Example #7
    def test_get_loss(self, net_output, expected_loss):
        """Test get loss success and failed."""
        cb_params = _InternalCallbackParam()
        cb_params.net_outputs = net_output
        summary_collector = SummaryCollector(
            tempfile.mkdtemp(dir=self.base_summary_dir))

        assert summary_collector._is_parse_loss_success
        assert summary_collector._get_loss(cb_params) == expected_loss

        if expected_loss is None:
            assert not summary_collector._is_parse_loss_success
Example #8
 def test_get_optimizer_from_network(self, mode):
     """Get optimizer from train network"""
     cb_params = _InternalCallbackParam()
     cb_params.optimizer = None
     cb_params.mode = mode
     if mode == ModeEnum.TRAIN.value:
         cb_params.train_network = CustomNet()
     else:
         cb_params.eval_network = CustomNet()
     summary_collector = SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir))
     optimizer = summary_collector._get_optimizer(cb_params)
     assert isinstance(optimizer, Optimizer)
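
CustomNet is not defined in this excerpt. For _get_optimizer to recover an optimizer from the train or eval network, the helper presumably carries an Optimizer instance as an attribute; a sketch under that assumption:

 class CustomNet(Cell):
     """Hypothetical net holding an optimizer for _get_optimizer to discover."""
     def __init__(self):
         super(CustomNet, self).__init__()
         self.optimizer = Optimizer(learning_rate=0.1,
                                    parameters=[Parameter(Tensor(1), 'weight')])

     def construct(self, data):
         return data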
Example #9
 def test_params_with_collect_freq_exception(self, collect_freq):
     """Test the exception scenario for collect freq."""
     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     if isinstance(collect_freq, int):
         with pytest.raises(ValueError) as exc:
             SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
         expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
         assert expected_msg == str(exc.value)
     else:
         with pytest.raises(TypeError) as exc:
             SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
         expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
                        f'bug got {type(collect_freq).__name__}.'
         assert expected_msg == str(exc.value)
Example #10
    def _run_network(self, dataset_sink_mode=False, num_samples=2, **kwargs):
        lenet = LeNet5()
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        optim = Momentum(lenet.trainable_params(),
                         learning_rate=0.1,
                         momentum=0.9)
        model = Model(lenet,
                      loss_fn=loss,
                      optimizer=optim,
                      metrics={'loss': Loss()})
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_collector = SummaryCollector(summary_dir=summary_dir,
                                             collect_freq=2,
                                             **kwargs)

        ds_train = create_dataset(os.path.join(self.mnist_path, "train"),
                                  num_samples=num_samples)
        model.train(1,
                    ds_train,
                    callbacks=[summary_collector],
                    dataset_sink_mode=dataset_sink_mode)

        ds_eval = create_dataset(os.path.join(self.mnist_path, "test"))
        model.eval(ds_eval,
                   dataset_sink_mode=dataset_sink_mode,
                   callbacks=[summary_collector])
        return summary_dir
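
Since _run_network forwards **kwargs to SummaryCollector and returns the summary directory, callers can exercise extra collector options directly; an illustrative call (argument values are examples only):

    summary_dir = self._run_network(dataset_sink_mode=True,
                                    num_samples=4,
                                    custom_lineage_data={'version': 'v1'})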
Example #11
 def test_collect_histogram_from_regular(self, mock_add_value,
                                         histogram_regular, expected_names,
                                         expected_values):
     """Test collect histogram from regular success."""
     mock_add_value.side_effect = add_value
     cb_params = _InternalCallbackParam()
     parameters = [
         Parameter(Tensor(1), 'conv1.weight1'),
         Parameter(Tensor(2), 'conv2.weight2'),
         Parameter(Tensor(3), 'conv1.bias1'),
         Parameter(Tensor(4), 'conv3.bias'),
         Parameter(Tensor(5), 'conv5.bias'),
         Parameter(Tensor(6), 'conv6.bias'),
     ]
     cb_params.optimizer = Optimizer(learning_rate=0.1,
                                     parameters=parameters)
     with SummaryCollector(tempfile.mkdtemp(
             dir=self.base_summary_dir)) as summary_collector:
         summary_collector._collect_specified_data[
             'histogram_regular'] = histogram_regular
         summary_collector._collect_histogram(cb_params)
     result = get_value()
     assert PluginEnum.HISTOGRAM.value == result[0][0]
     assert expected_names == [data[1] for data in result]
     assert expected_values == [data[2] for data in result]
Example #12
def main(data_path,
         device_target='Ascend',
         summary_dir='./summary_dir',
         learning_rate=0.01):
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    momentum = 0.9
    epoch_size = 1
    batch_size = 32

    network = LeNet5()
    network.set_train()
    net_loss = CrossEntropyLoss()
    net_opt = nn.Momentum(network.trainable_params(), learning_rate, momentum)
    model = Model(network, net_loss, net_opt)

    # Initialize a SummaryCollector callback to record summary data during model.train or model.eval
    summary_collector = SummaryCollector(summary_dir=summary_dir,
                                         collect_freq=10)

    ds = create_dataset(os.path.join(data_path, "train"),
                        batch_size=batch_size)

    print("============== Starting Training ==============")
    model.train(epoch_size,
                ds,
                callbacks=[summary_collector],
                dataset_sink_mode=False)
    print("============== Train End =====================")
Example #13
    def _run_network(self, dataset_sink_mode=True):
        lenet = LeNet5()
        loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False,
                                                sparse=True,
                                                reduction="mean")
        optim = Momentum(lenet.trainable_params(),
                         learning_rate=0.1,
                         momentum=0.9)
        model = Model(lenet,
                      loss_fn=loss,
                      optimizer=optim,
                      metrics={'acc': Accuracy()})
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_collector = SummaryCollector(summary_dir=summary_dir,
                                             collect_freq=1)

        ds_train = create_dataset(os.path.join(self.mnist_path, "train"))
        model.train(1,
                    ds_train,
                    callbacks=[summary_collector],
                    dataset_sink_mode=dataset_sink_mode)

        ds_eval = create_dataset(os.path.join(self.mnist_path, "test"))
        model.eval(ds_eval,
                   dataset_sink_mode=dataset_sink_mode,
                   callbacks=[summary_collector])

        self._check_summary_result(summary_dir)
Example #14
 def test_params_with_action_exception(self, action):
     """Test the exception scenario for action."""
     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     with pytest.raises(TypeError) as exc:
         SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
     expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
                    f"bug got {type(action).__name__}."
     assert expected_msg == str(exc.value)
Example #15
 def test_params_with_collect_specified_data_unexpected_key(self):
     """Test the collect_specified_data parameter with unexpected key."""
     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     data = {'unexpected_key': True}
     with pytest.raises(ValueError) as exc:
         SummaryCollector(summary_dir, collect_specified_data=data)
     expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported"
     assert expected_msg in str(exc.value)
Example #16
 def test_collect_input_data_success(self, mock_add_value):
     """Mock a image data, and collect image data success."""
     mock_add_value.side_effect = add_value
     cb_params = _InternalCallbackParam()
     image_data = Tensor(np.random.randint(0, 255, size=(1, 1, 1, 1)).astype(np.uint8))
     cb_params.train_dataset_element = image_data
     with SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) as summary_collector:
         summary_collector._collect_input_data(cb_params)
Example #17
 def test_params_with_export_options_unexpected_key(self):
     """Test the export_options parameter with unexpected key."""
     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     data = {'unexpected_key': "value"}
     with pytest.raises(ValueError) as exc:
         SummaryCollector(summary_dir, export_options=data)
     expected_msg = f"For `export_options` the keys {set(data)} are unsupported"
     assert expected_msg in str(exc.value)
Example #18
    def test_process_specified_data(self, specified_data, action, expected_result):
        """Test process specified data."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        summary_collector = SummaryCollector(summary_dir,
                                             collect_specified_data=specified_data,
                                             keep_default_action=action)

        assert summary_collector._collect_specified_data == expected_result
Example #19
 def test_params_with_summary_dir_not_dir(self):
     """Test the given summary dir parameter is not a directory."""
     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     summary_file = os.path.join(summary_dir, 'temp_file.txt')
     with open(summary_file, 'w') as file_handle:
         file_handle.write('temp')
     with pytest.raises(NotADirectoryError):
         SummaryCollector(summary_dir=summary_file)
Example #20
    def test_params_with_histogram_regular_value_error(self):
        """Test histogram regular."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(ValueError) as exc:
            SummaryCollector(summary_dir,
                             collect_specified_data={'histogram_regular': '*'})

        assert 'For `collect_specified_data`, the value of `histogram_regular`' in str(
            exc.value)
Example #21
    def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
        """Test the key of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
                       f"bug got {type(param_name).__name__}."
        assert expected_msg == str(exc.value)
Example #22
    def test_params_with_export_options_type_error(self, export_options):
        """Test the type error scenario for the export_options param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, export_options=export_options)

        expected_msg = f"For `export_options` the type should be a valid type of ['dict', 'NoneType'], " \
                       f"but got {type(export_options).__name__}."

        assert expected_msg == str(exc.value)
Example #23
 def test_collect_dataset_graph_success(self, mock_add_value):
     """Test collect dataset graph."""
     dataset = import_module('mindspore.dataset')
     mock_add_value.side_effect = add_value
     cb_params = _InternalCallbackParam()
     cb_params.train_dataset = dataset.MnistDataset(dataset_dir=tempfile.mkdtemp(dir=self.base_summary_dir))
     cb_params.mode = ModeEnum.TRAIN.value
     with SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) as summary_collector:
         summary_collector._collect_dataset_graph(cb_params)
         plugin, name, _ = get_value()[0]
     assert plugin == 'dataset_graph'
     assert name == 'train_dataset'
Example #24
    def test_params_with_tensor_format_type_error(self, export_options):
        """Test the value error scenario for an unsupported tensor_format in export_options."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(ValueError) as exc:
            SummaryCollector(summary_dir, export_options=export_options)

        unexpected_format = {export_options.get("tensor_format")}
        expected_msg = f'For `export_options`, the export_format {unexpected_format} are ' \
                       f'unsupported for tensor_format, expect the follow values: ' \
                       f'{list(_DEFAULT_EXPORT_OPTIONS.get("tensor_format"))}'

        assert expected_msg == str(exc.value)
Example #25
    def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
        """Test the value of collect specified data param."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)

        param_name = list(collect_specified_data)[0]
        param_value = collect_specified_data[param_name]
        expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
        expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
                       f'but got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)
Example #26
def test_graph_summary_callback():
    dataset = get_dataset()
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    optim = Momentum(net.trainable_params(), 0.1, 0.9)
    context.set_context(mode=context.GRAPH_MODE)
    model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    summary_collector = SummaryCollector(
        SUMMARY_DIR,
        collect_freq=1,
        keep_default_action=False,
        collect_specified_data={'collect_graph': True})
    model.train(1, dataset, callbacks=[summary_collector])
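
With keep_default_action=False, only the explicitly listed items are recorded. A collector that also captures parameter histograms could be configured along these lines (a sketch; the regular expression is illustrative):

    summary_collector = SummaryCollector(
        SUMMARY_DIR,
        collect_freq=1,
        keep_default_action=False,
        collect_specified_data={'collect_graph': True,
                                'histogram_regular': '^conv1.*'})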
Example #27
    def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
        """Test the custom lineage data parameter type error."""
        summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        with pytest.raises(TypeError) as exc:
            SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)

        if not isinstance(custom_lineage_data, dict):
            expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
                           f"bug got {type(custom_lineage_data).__name__}."
        else:
            param_name = list(custom_lineage_data)[0]
            param_value = custom_lineage_data[param_name]
            if not isinstance(param_name, str):
                arg_name = f'custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
                               f'bug got {type(param_name).__name__}.'
            else:
                arg_name = f'the value of custom_lineage_data -> {param_name}'
                expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
                               f'bug got {type(param_value).__name__}.'

        assert expected_msg == str(exc.value)
Example #28
    opt = Momentum(net.trainable_params(), 0.01, 0.9)

    model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})

    # for training, users can call model.train
    if args_opt.do_train:
        dataset = create_dataset()
        config_ck = CheckpointConfig(save_checkpoint_steps=10,
                                     keep_checkpoint_max=1000)
        ckpoint_cb = ModelCheckpoint(prefix="",
                                     directory=os.path.join(
                                         summary_dir, "weights"),
                                     config=config_ck)
        # data_saver_callback = DataSaverCallback(summary_dir=summary_dir, save_interval=1)
        data_saver_callback = DataSaverCallbackGPU(summary_dir=summary_dir)
        summary_cb = SummaryCollector(summary_dir=summary_dir,
                                      collect_freq=1000)
        model.train(epoch_size,
                    dataset,
                    callbacks=[
                        LossMonitor(), data_saver_callback, summary_cb,
                        ckpoint_cb
                    ],
                    dataset_sink_mode=False)
        # model.train(epoch_size, dataset, callbacks=[LossMonitor(), summary_cb], dataset_sink_mode=False)

    # for evaluation, users can call model.eval
    if args_opt.do_eval:
        if args_opt.checkpoint_path:
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)
        eval_dataset = create_dataset(training=False)
Example #29
def _build_training_pipeline(config: GNMTConfig,
                             pre_training_dataset=None,
                             fine_tune_dataset=None,
                             test_dataset=None):
    """
    Build training pipeline.

    Args:
        config (GNMTConfig): Config of the GNMT model.
        pre_training_dataset (Dataset): Pre-training dataset.
        fine_tune_dataset (Dataset): Fine-tune dataset.
        test_dataset (Dataset): Test dataset.
    """
    net_with_loss = GNMTNetworkWithLoss(config,
                                        is_training=True,
                                        use_one_hot_embeddings=True)
    net_with_loss.init_parameters_data()
    _load_checkpoint_to_net(config, net_with_loss)

    dataset = pre_training_dataset if pre_training_dataset is not None \
        else fine_tune_dataset

    if dataset is None:
        raise ValueError(
            "Either a pre-training dataset or a fine-tuning dataset must be provided."
        )

    update_steps = config.epochs * dataset.get_dataset_size()

    lr = _get_lr(config, update_steps)
    optimizer = _get_optimizer(config, net_with_loss, lr)

    # Dynamic loss scale.
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.init_loss_scale,
        scale_factor=config.loss_scale_factor,
        scale_window=config.scale_window)
    net_with_grads = GNMTTrainOneStepWithLossScaleCell(
        network=net_with_loss,
        optimizer=optimizer,
        scale_update_cell=scale_manager.get_update_cell())
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
    loss_monitor = LossCallBack(config)
    dataset_size = dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_ckpt_steps,
        keep_checkpoint_max=config.keep_ckpt_max)

    rank_size = os.getenv('RANK_SIZE')
    callbacks = [time_cb, loss_monitor]
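    # In distributed runs (RANK_SIZE > 1), only the first rank of each
    # 8-device group writes checkpoints and summary data; standalone runs
    # (RANK_SIZE unset or 1) always write them.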
    if rank_size is not None and int(
            rank_size) > 1 and MultiAscend.get_rank() % 8 == 0:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary",
                                            collect_freq=50)
        callbacks.append(summary_callback)

    if rank_size is None or int(rank_size) == 1:
        ckpt_callback = ModelCheckpoint(
            prefix=config.ckpt_prefix,
            directory=os.path.join(config.ckpt_path,
                                   'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
            config=ckpt_config)
        callbacks.append(ckpt_callback)
        summary_callback = SummaryCollector(summary_dir="./summary",
                                            collect_freq=50)
        callbacks.append(summary_callback)

    print(f" | ALL SET, PREPARE TO TRAIN.")
    _train(model=model,
           config=config,
           pre_training_dataset=pre_training_dataset,
           fine_tune_dataset=fine_tune_dataset,
           test_dataset=test_dataset,
           callbacks=callbacks)
Example #30
def main():
    set_seed(1)
    date = time.strftime("%Y%m%d%H%M%S", time.localtime())
    print(f'* Preparing to train model {date}')

    # ************** configuration ****************
    # - training setting
    resume = config['resume']
    if config['mode'] == 'PYNATIVE':
        mode = context.PYNATIVE_MODE
    else:
        mode = context.GRAPH_MODE

    device = config['device']
    device_id = config['device_id']
    dataset_sink_mode = config['dataset_sink_mode']

    # used when building the dataset
    div = 8

    # setting bias and padding
    if resume:
        print('* Resuming model...')
        resume_config_log = config['resume_config_log']
        resume_config = get_eval_config(resume_config_log)
        if 'best_ckpt' in resume_config.keys():
            resume_model_path = resume_config['best_ckpt']
        else:
            resume_model_path = resume_config['latest_model']
            print('* [WARNING] Not using the best model, but the latest saved model instead.')

        has_bias = resume_config['has_bias']
        use_dropout = resume_config['use_dropout']

        pad_mode = resume_config['pad_mode']

        if pad_mode == 'pad':
            padding = resume_config['padding']
        elif pad_mode == 'same':
            padding = 0
        else:
            raise ValueError(f"invalid pad mode: {pad_mode}!")

        best_acc = resume_config['best_acc']
        best_ckpt = resume_config['best_ckpt']
        print('* The best accuracy in dev dataset for the current resumed model is {:.2f}%'.format(best_acc * 100))

    else:
        has_bias = config['has_bias']
        use_dropout = config['use_dropout']

        pad_mode = config['pad_mode']

        if pad_mode == 'pad':
            padding = config['padding']
        elif pad_mode == 'same':
            padding = 0
        else:
            raise ValueError(f"invalid pad mode: {pad_mode}!")

    # hyper-parameters
    if resume:
        batch_size = resume_config['batch_size']
        opt_type = resume_config['opt']
        use_dynamic_lr = resume_config['use_dynamic_lr']
        warmup_step = resume_config['warmup_step']
        warmup_ratio = resume_config['warmup_ratio']
    else:
        batch_size = config['batch_size']
        opt_type = config['opt']
        use_dynamic_lr = config['use_dynamic_lr']
        warmup_step = config['warmup_step']
        warmup_ratio = config['warmup_ratio']

    test_dev_batch_size = config['test_dev_batch_size']
    learning_rate = float(config['learning_rate'])
    epochs = config['epochs']
    loss_scale = config['loss_scale']

    # configuration of saving model checkpoint
    save_checkpoint_steps = config['save_checkpoint_steps']
    keep_checkpoint_max = config['keep_checkpoint_max']
    prefix = config['prefix'] + '_' + date
    model_dir = config['model_dir']

    # loss monitor
    loss_monitor_step = config['loss_monitor_step']

    # whether to use the MindInsight summary
    use_summary = config['use_summary']

    # step_eval
    use_step_eval = config['use_step_eval']
    eval_step = config['eval_step']
    eval_epoch = config['eval_epoch']
    patience = config['patience']

    # eval in steps or epochs
    step_eval = True

    if eval_step == -1:
        step_eval = False

    # ************** end of configuration **************
    if device == 'GPU':
        context.set_context(mode=mode, device_target=device, device_id=device_id)
    elif device == 'Ascend':
        import moxing as mox
        from utils.const import DATA_PATH, MODEL_PATH, BEST_MODEL_PATH, LOG_PATH
        obs_datapath = config['obs_datapath']
        obs_saved_model = config['obs_saved_model']
        obs_best_model = config['obs_best_model']
        obs_log = config['obs_log']
        mox.file.copy_parallel(obs_datapath, DATA_PATH)
        mox.file.copy_parallel(MODEL_PATH, obs_saved_model)
        mox.file.copy_parallel(BEST_MODEL_PATH, obs_best_model)
        mox.file.copy_parallel(LOG_PATH, obs_log)
        context.set_context(mode=mode, device_target=device)
        use_summary = False

    # callbacks
    callbacks = []

    # data
    train_loader, idx2label, label2idx = get_dataset(batch_size=batch_size, phase='train',
                                                     test_dev_batch_size=test_dev_batch_size, div=div,
                                                     num_parallel_workers=4)

    if eval_step == 0:
        eval_step = train_loader.get_dataset_size()

    # network
    net = DFCNN(num_classes=len(label2idx), padding=padding, pad_mode=pad_mode,
                has_bias=has_bias, use_dropout=use_dropout)

    # Criterion
    criterion = CTCLoss()

    # resume
    if resume:
        print("* Loading parameters...")
        param_dict = load_checkpoint(resume_model_path)
        # load the parameter into net
        load_param_into_net(net, param_dict)
        print(f'* Parameters loading from {resume_model_path} succeeded!')

    net.set_train(True)
    net.set_grad(True)

    # lr schedule
    if use_dynamic_lr:
        dataset_size = train_loader.get_dataset_size()
        learning_rate = Tensor(dynamic_lr(base_lr=learning_rate, warmup_step=warmup_step,
                                          warmup_ratio=warmup_ratio, epochs=epochs,
                                          steps_per_epoch=dataset_size), mstype.float32)
        print('* Using dynamic learning rate, which will be set up as :', learning_rate.asnumpy())

    # optim
    if opt_type == 'adam':
        opt = nn.Adam(net.trainable_params(), learning_rate=learning_rate, beta1=0.9, beta2=0.999, weight_decay=0.0,
                      eps=10e-8)
    elif opt_type == 'rms':
        opt = nn.RMSProp(params=net.trainable_params(),
                         centered=True,
                         learning_rate=learning_rate,
                         momentum=0.9,
                         loss_scale=loss_scale)
    elif opt_type == 'sgd':
        opt = nn.SGD(params=net.trainable_params(), learning_rate=learning_rate)
    else:
        raise ValueError(f"optimizer: {opt_type} is not supported for now!")

    if resume:
        # load the parameter into optimizer
        load_param_into_net(opt, param_dict)

    # save_model
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, keep_checkpoint_max=keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix=prefix, directory=model_dir, config=config_ck)

    # logger
    the_logger = logger(config, date)
    log = Logging(logger=the_logger, model_ckpt=ckpt_cb)

    callbacks.append(ckpt_cb)
    callbacks.append(log)

    net = WithLossCell(net, criterion)
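    # Wrap the fixed loss-scale value as a 1-element tensor for the
    # loss-scaling train cell below.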
    scaling_sens = Tensor(np.full((1,), loss_scale), dtype=mstype.float32)

    net = DFCNNCTCTrainOneStepWithLossScaleCell(net, opt, scaling_sens)
    net.set_train(True)
    model = Model(net)

    if use_step_eval:
        # step evaluation
        step_eval = StepAccInfo(model=model, name=prefix, div=div, test_dev_batch_size=test_dev_batch_size,
                                step_eval=step_eval, eval_step=eval_step, eval_epoch=eval_epoch,
                                logger=the_logger, patience=patience, dataset_size=train_loader.get_dataset_size())

        callbacks.append(step_eval)

    # loss monitor
    loss_monitor = LossMonitor(loss_monitor_step)

    callbacks.append(loss_monitor)

    if use_summary:
        summary_dir = os.path.join(SUMMARY_DIR, date)
        if not os.path.exists(summary_dir):
            os.mkdir(summary_dir)
        # MindInsight summary collector
        summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=1, max_file_size=4 * 1024 ** 3)
        callbacks.append(summary_collector)

    if resume:
        the_logger.update_acc_ckpt(best_acc, best_ckpt)

    print('* Start training...')
    model.train(epochs,
                train_loader,
                callbacks=callbacks,
                dataset_sink_mode=dataset_sink_mode)
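
The script presumably ends with the standard entry-point guard, which the excerpt omits:

if __name__ == '__main__':
    main()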