Example #1
    def _exec_preprocess(self,
                         network,
                         is_train,
                         phase,
                         dataset,
                         dataset_sink_mode,
                         sink_size=-1,
                         epoch_num=1,
                         iter_first_order=9):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
            # remove later to deal with loop sink
            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                    and not context.get_context("enable_ge"):
                need_wrap = True

            if not is_train:
                dataset.__loop_size__ = 1

        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size,
                                       epoch_num, iter_first_order)

        # remove later to deal with loop sink
        if need_wrap:
            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()),
                                     dataset.__ME_INITED__)
            network.set_train(is_train)
            network.phase = phase

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                   ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()

        return dataset_helper, network
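A minimal usage sketch of this hook, assuming a Model-like caller: model, net, and train_dataset are illustrative names and not part of the snippet above; only _exec_preprocess and its parameters are taken from it.

    # Hypothetical caller-side sketch; names outside _exec_preprocess are assumptions.
    dataset_helper, net = model._exec_preprocess(
        network=net,
        is_train=True,
        phase='train',
        dataset=train_dataset,
        dataset_sink_mode=True,
        sink_size=-1,        # -1 presumably sinks the full dataset per epoch
        epoch_num=1,
        iter_first_order=9)
    # The helper drives the sink loop; the (possibly DataWrapper-wrapped)
    # network consumes the data pushed through the dataset channel.
    for inputs in dataset_helper:
        outputs = net(*inputs)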
Example #2
    def _exec_preprocess(self,
                         network,
                         is_train,
                         phase,
                         dataset,
                         dataset_sink_mode,
                         iter_first_order=1):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
            # remove later to deal with loop sink
            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                    and not context.get_context("enable_ge"):
                need_wrap = True

            if not is_train:
                dataset.__loop_size__ = 1

        dataset_helper = DatasetHelper(dataset, dataset_sink_mode,
                                       iter_first_order)

        # remove later to deal with loop sink
        if need_wrap:
            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()),
                                     dataset.__ME_INITED__)
            network.set_train(is_train)
            network.phase = phase

        return dataset_helper, network
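This variant drops sink_size and epoch_num from the signature. A hedged sketch of a non-sink call, again with model, net, and valid_dataset as illustrative names:

    # With dataset_sink_mode=False the whole sink branch is skipped, so
    # need_wrap stays False and no DataWrapper is applied.
    dataset_helper, net = model._exec_preprocess(
        network=net,
        is_train=False,
        phase='eval',
        dataset=valid_dataset,
        dataset_sink_mode=False,
        iter_first_order=1)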
Example #3
    def _eval_dataset_sink_process(self,
                                   valid_dataset,
                                   list_callback=None,
                                   cb_params=None):
        """
        Evaluation. The data is passed to the network through the dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, the loss value and metrics values for the model in test mode.
        """
        _device_number_check(self._parallel_mode, self._device_number)

        run_context = RunContext(cb_params)

        # remove later to deal with loop sink
        need_wrap = False
        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
                and not context.get_context("enable_ge"):
            need_wrap = True

        valid_dataset.__loop_size__ = 1
        dataset_helper = DatasetHelper(valid_dataset)

        # remove later to deal with loop sink
        if need_wrap:
            self._eval_network = nn.DataWrapper(
                self._eval_network, *(dataset_helper.types_shapes()),
                valid_dataset.__ME_INITED__)
            self._eval_network.set_train(mode=False)
            self._eval_network.phase = 'eval'
        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics
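A hedged sketch of how a caller might drive this evaluation loop. _InternalCallbackParam and the callback executor are only named in the docstring, so their construction here is an assumption:

    cb_params = _InternalCallbackParam()   # assumed default-constructible
    cb_params.cur_step_num = 0             # the loop above increments this
    list_callback = ListCallback([])       # assumed: executor over an empty callback list
    metrics = model._eval_dataset_sink_process(valid_dataset,
                                               list_callback=list_callback,
                                               cb_params=cb_params)
    # metrics is a dict keyed by the configured metric names, e.g. {'accuracy': ...}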
Example #4
    def _train_dataset_sink_process(self,
                                    epoch,
                                    train_dataset,
                                    list_callback=None,
                                    cb_params=None):
        """
        Training process. The data is passed to the network through the dataset channel.

        Args:
            epoch (int): Total number of epochs, i.e. full passes over the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data items (data1, data2, data3, ...) should
                                     be returned and passed to the network. Otherwise, a tuple (data, label)
                                     should be returned, and the data and label are passed to the network and
                                     loss function respectively.
            list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        # remove later to deal with loop sink
        iter_first_order = 277
        iter_second_order = 1
        train_dataset.__loop_size__ = iter_second_order
        need_wrap = False
        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
                and not context.get_context("enable_ge"):
            need_wrap = True

        dataset_helper = DatasetHelper(train_dataset, iter_first_order)
        # remove later to deal with loop sink
        if need_wrap:
            self._train_network = nn.DataWrapper(
                self._train_network, *(dataset_helper.types_shapes()),
                train_dataset.__ME_INITED__)
            cb_params.train_network = self._train_network
            self._train_network.set_train()

        cb_params.cur_step_num = 0
        loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        # used to stop training early, such as stopAtTime or stopAtStep
        should_stop = False
        has_do_train1_dataset = False
        checkpoint_branch_one = True
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.epoch_begin(run_context)

            # with dataset sink, dataset_helper iterates only once; otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                list_callback.step_begin(run_context)
                if checkpoint_branch_one:
                    cb_params.cur_step_num += loop_size
                    self._train_network.set_second_order(True)
                    self._train_network.phase = 'train0'
                else:
                    cb_params.cur_step_num += iter_first_order
                    self._train_network.set_second_order(False)
                    self._train_network.phase = 'train1'
                    if not has_do_train1_dataset:
                        _exec_datagraph(train_dataset,
                                        iter_first_order,
                                        phase='train1_dataset')
                        has_do_train1_dataset = True
                checkpoint_branch_one = not checkpoint_branch_one
                outputs = self._train_network(*inputs)
                cb_params.net_outputs = outputs
                list_callback.step_end(run_context)

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)
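The loop above alternates each sink iteration between a second-order phase ('train0', advancing loop_size steps) and a first-order phase ('train1', advancing iter_first_order steps), which looks like a THOR-style schedule. A self-contained sketch of just the step accounting:

    # Standalone model of the cur_step_num bookkeeping in the loop above.
    iter_first_order = 277
    loop_size = 1                # iter_second_order in the snippet above
    cur_step_num = 0
    branch_one = True
    for _ in range(4):           # four sink iterations
        if branch_one:
            cur_step_num += loop_size         # 'train0': one second-order step
        else:
            cur_step_num += iter_first_order  # 'train1': 277 first-order steps
        branch_one = not branch_one
    print(cur_step_num)          # 556 == 2 * (1 + 277)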