コード例 #1
0
ファイル: test_callback.py プロジェクト: opendlf/mindspore
def test_RunContext():
    """Exercise RunContext: bad-argument rejection, argument access and stop requests."""
    # A plain int is not an acceptable constructor argument here.
    with pytest.raises(TypeError):
        RunContext(666)

    params = _InternalCallbackParam()
    params.member1 = 1
    params.member2 = "abc"

    context = RunContext(params)
    # Reading the original arguments must leave the stored members untouched.
    context.original_args()
    assert params.member1 == 1
    assert params.member2 == "abc"

    # Once a stop has been requested, the context must report it.
    context.request_stop()
    assert context.get_stop_requested()
コード例 #2
0
    def _train_process(self,
                       epoch,
                       train_dataset,
                       list_callback=None,
                       cb_params=None):
        """
        Training process. The data would be passed to network directly.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
                                      It is used unconditionally, so a real object must be passed.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
                                                It is used unconditionally, so a real object must be passed.

        Raises:
            ValueError: If a loss function is set and an element produced by the
                        dataset does not contain exactly two items.
        """
        dataset_helper, _ = self._exec_preprocess(self._train_network,
                                                  is_train=True,
                                                  phase='train',
                                                  dataset=train_dataset,
                                                  dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)
        # Set once a callback asks for an early stop through run_context.
        should_stop = False

        for i in range(epoch):
            # Epoch numbers exposed to callbacks are 1-based.
            cb_params.cur_epoch_num = i + 1

            list_callback.epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError(
                        "when loss_fn is not None, train_dataset should "
                        "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1
                list_callback.step_begin(run_context)

                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(
                ):
                    # Append the current loss-scale value as an extra network input.
                    scaling_sens = self._get_scaling_sens()
                    next_element = tuple(next_element) + (Tensor(
                        scaling_sens, mstype.float32), )

                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(
                ):
                    # The network output is unpacked as three items; the middle
                    # one is the overflow flag fed back to the loss-scale manager.
                    _, overflow, _ = outputs
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

                list_callback.step_end(run_context)
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            # The dataset is reset and epoch_end still runs even when the step
            # loop was cut short by a stop request.
            train_dataset.reset()

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)
コード例 #3
0
    def _train_dataset_sink_process(self,
                                    epoch,
                                    train_dataset,
                                    list_callback=None,
                                    cb_params=None):
        """
        Run training with the data delivered to the network through the dataset channel.

        Successive steps alternate between two network phases: 'train0' (thor flag
        on, step counter advanced by the helper's loop size) and 'train1' (thor flag
        off, step counter advanced by ``self._frequency - 1``). A stop request is
        only examined after each epoch's ``epoch_end`` callbacks.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                     returned and passed to the network. Otherwise, a tuple (data, label) should
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        first_order_steps = self._frequency - 1
        # The dataset's own loop size is fixed to the single second-order step.
        train_dataset.__loop_size__ = 1
        dataset_helper, sink_network = self._exec_preprocess(
            self._train_network,
            is_train=True,
            phase='train',
            dataset=train_dataset,
            dataset_sink_mode=True,
            iter_first_order=first_order_steps)
        self._train_network = sink_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        sink_loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        first_order_graph_ready = False
        next_is_second_order = True
        for epoch_index in range(epoch):
            cb_params.cur_epoch_num = epoch_index + 1
            list_callback.epoch_begin(run_context)

            # NOTE: per the original author's remark, in sink mode the helper is
            # expected to yield only once per pass over it.
            for inputs in dataset_helper:
                list_callback.step_begin(run_context)

                # Take the phase for this step and flip it for the next one.
                second_order = next_is_second_order
                next_is_second_order = not second_order

                cb_params.cur_step_num += sink_loop_size if second_order else first_order_steps
                self._train_network.add_flags_recursive(thor=second_order)
                self._train_network.phase = 'train0' if second_order else 'train1'
                if not second_order and not first_order_graph_ready:
                    # The 'train1' data graph is set up once, on its first use.
                    _exec_datagraph(train_dataset,
                                    first_order_steps,
                                    phase='train1_dataset')
                    first_order_graph_ready = True

                cb_params.net_outputs = self._train_network(*inputs)
                list_callback.step_end(run_context)

            list_callback.epoch_end(run_context)
            # Early stop (e.g. requested by a callback) is honoured per epoch.
            if run_context.get_stop_requested():
                break

        list_callback.end(run_context)