def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode,
                     sink_size=-1, epoch_num=1, iter_first_order=9):
    """Initializes dataset."""
    need_wrap = False
    if dataset_sink_mode:
        # remove later to deal with loop sink
        if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                and not context.get_context("enable_ge"):
            need_wrap = True

        if not is_train:
            dataset.__loop_size__ = 1

    dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

    # remove later to deal with loop sink
    if need_wrap:
        network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)

    network.set_train(is_train)
    network.phase = phase

    if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        network.set_auto_parallel()

    return dataset_helper, network
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
    """Initializes dataset."""
    need_wrap = False
    if dataset_sink_mode:
        # remove later to deal with loop sink
        if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                and not context.get_context("enable_ge"):
            need_wrap = True

        if not is_train:
            dataset.__loop_size__ = 1

    dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)

    # remove later to deal with loop sink
    if need_wrap:
        network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)

    network.set_train(is_train)
    network.phase = phase

    return dataset_helper, network
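# A standalone sketch (not part of the original file): the need_wrap decision
# used in both _exec_preprocess variants above, extracted as a pure function so
# the branch can be checked without a real MindSpore context.
# `needs_data_wrapper` and its parameter names are hypothetical; the arguments
# stand in for the hasattr(dataset, '__ME_INITED__') check and the
# context.get_context() lookups.

def needs_data_wrapper(dataset_inited, device_target, enable_ge, dataset_sink_mode):
    """Mirror the need_wrap branch: wrap only in Ascend sink mode without GE,
    and only when the dataset has not been initialized for ME yet."""
    if not dataset_sink_mode:
        return False
    return (not dataset_inited) and device_target == "Ascend" and not enable_ge

# Sink mode on Ascend without GE and an uninitialized dataset -> wrap.
assert needs_data_wrapper(False, "Ascend", False, True) is True
# An already-initialized dataset (has __ME_INITED__) is never re-wrapped.
assert needs_data_wrapper(True, "Ascend", False, True) is False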
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
    """
    Evaluation. The data would be passed to network through dataset channel.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        list_callback (ListCallback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Returns:
        Dict, returns the loss value & metrics values for the model in test mode.
    """
    _device_number_check(self._parallel_mode, self._device_number)

    run_context = RunContext(cb_params)

    # remove later to deal with loop sink
    need_wrap = False
    if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
            and not context.get_context("enable_ge"):
        need_wrap = True
        valid_dataset.__loop_size__ = 1

    dataset_helper = DatasetHelper(valid_dataset)

    # remove later to deal with loop sink
    if need_wrap:
        self._eval_network = nn.DataWrapper(
            self._eval_network, *(dataset_helper.types_shapes()), valid_dataset.__ME_INITED__)

    self._eval_network.set_train(mode=False)
    self._eval_network.phase = 'eval'

    list_callback.begin(run_context)

    for inputs in dataset_helper:
        cb_params.cur_step_num += 1
        list_callback.step_begin(run_context)

        outputs = self._eval_network(*inputs)

        cb_params.net_outputs = outputs
        list_callback.step_end(run_context)
        self._update_metrics(outputs)

    metrics = self._get_metrics()
    cb_params.metrics = metrics
    list_callback.end(run_context)

    return metrics
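# A standalone sketch (an assumption, not part of the original file) of the
# per-batch flow in _eval_dataset_sink_process above: each batch bumps
# cb_params.cur_step_num, runs the network, and feeds the output into the
# metrics before the aggregate result is read out at the end. `FakeMetric`
# and the list-based "dataset" are hypothetical stand-ins for a MindSpore
# metric object and DatasetHelper.

class FakeMetric:
    """Accumulate outputs the way _update_metrics feeds real metrics."""
    def __init__(self):
        self.values = []

    def update(self, output):
        self.values.append(output)

    def eval(self):
        return sum(self.values) / len(self.values)

metric = FakeMetric()
cur_step_num = 0
for inputs in [(0.5,), (1.5,)]:        # stands in for `for inputs in dataset_helper:`
    cur_step_num += 1                  # mirrors cb_params.cur_step_num += 1
    outputs = inputs[0]                # stands in for self._eval_network(*inputs)
    metric.update(outputs)             # mirrors self._update_metrics(outputs)
print({'mean_output': metric.eval()})  # mirrors self._get_metrics(); prints 1.0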
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
    """
    Training process. The data would be passed to network through dataset channel.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
                                 loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                 returned and passed to the network. Otherwise, a tuple (data, label) should
                                 be returned, and the data and label are passed to the network and loss
                                 function respectively.
        list_callback (_ListCallback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.
    """
    # remove later to deal with loop sink
    iter_first_order = 277
    iter_second_order = 1
    train_dataset.__loop_size__ = iter_second_order
    need_wrap = False
    if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
            and not context.get_context("enable_ge"):
        need_wrap = True

    dataset_helper = DatasetHelper(train_dataset, iter_first_order)
    # remove later to deal with loop sink
    if need_wrap:
        self._train_network = nn.DataWrapper(
            self._train_network, *(dataset_helper.types_shapes()), train_dataset.__ME_INITED__)
        cb_params.train_network = self._train_network
    self._train_network.set_train()
    cb_params.cur_step_num = 0
    loop_size = dataset_helper.loop_size()
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)

    # used to stop training for early stop, such as stopAtTime or stopAtStep
    should_stop = False
    has_do_train1_dataset = False
    checkpoint_branch_one = True
    for i in range(epoch):
        cb_params.cur_epoch_num = i + 1
        list_callback.epoch_begin(run_context)

        # for data sink, dataset_helper iterates only once; otherwise it iterates epoch_size times.
        for inputs in dataset_helper:
            list_callback.step_begin(run_context)
            if checkpoint_branch_one:
                cb_params.cur_step_num += loop_size
                self._train_network.set_second_order(True)
                self._train_network.phase = 'train0'
            else:
                cb_params.cur_step_num += iter_first_order
                self._train_network.set_second_order(False)
                self._train_network.phase = 'train1'
                if not has_do_train1_dataset:
                    _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                    has_do_train1_dataset = True
            checkpoint_branch_one = not checkpoint_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)

        list_callback.epoch_end(run_context)
        should_stop = should_stop or run_context.get_stop_requested()
        if should_stop:
            break

    list_callback.end(run_context)
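# A standalone sketch (not part of the original file) of the alternating
# step-count schedule used above: a sunk second-order pass of `loop_size`
# steps ('train0'), then `iter_first_order` first-order steps ('train1'),
# repeated. iter_first_order=277 comes from the function above; loop_size=1
# assumes dataset_helper.loop_size() returns iter_second_order = 1.
# `simulate_step_schedule` is a hypothetical helper, not a MindSpore API.

def simulate_step_schedule(num_sink_iterations, loop_size=1, iter_first_order=277):
    """Replay the cur_step_num bookkeeping from _train_dataset_sink_process."""
    cur_step_num = 0
    checkpoint_branch_one = True
    phases = []
    for _ in range(num_sink_iterations):
        if checkpoint_branch_one:
            cur_step_num += loop_size          # second-order branch ('train0')
        else:
            cur_step_num += iter_first_order   # first-order branch ('train1')
        phases.append(('train0' if checkpoint_branch_one else 'train1', cur_step_num))
        checkpoint_branch_one = not checkpoint_branch_one
    return phases

# Four sink iterations alternate train0/train1 and reach steps 1, 278, 279, 556.
print(simulate_step_schedule(4))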