def _create_execution_agent(self):
    """Creates a TrainingAgent that can run the forward and backward graph on the training model.

    The forward feed names are the optimized model's graph inputs, and the backward fetch
    names are its graph outputs. Each forward user output and each backward fetch is assigned
    an ``OrtDevice`` derived from ``self._device`` with default memory. The session options,
    providers and provider options come from ``self._get_session_config()``.

    The constructed agent is stored on ``self._execution_agent``; nothing is returned.
    """
    session_options, providers, provider_options = self._get_session_config()
    optimized_graph = self._optimized_onnx_model.graph

    def _output_device():
        # One freshly built OrtDevice per output, matching the device this model runs on.
        return C.OrtDevice(
            get_ort_device_type(self._device.type),
            C.OrtDevice.default_memory(),
            _utils.get_device_index(self._device),
        )

    fw_feed_names = [graph_input.name for graph_input in optimized_graph.input]
    num_user_outputs = len(self._graph_info.user_output_names)
    fw_outputs_device_info = [_output_device() for _ in range(num_user_outputs)]

    bw_fetches_names = [graph_output.name for graph_output in optimized_graph.output]
    bw_outputs_device_info = [_output_device() for _ in range(len(bw_fetches_names))]

    self._execution_agent = TrainingAgent(
        self._optimized_onnx_model.SerializeToString(),
        fw_feed_names,
        fw_outputs_device_info,
        bw_fetches_names,
        bw_outputs_device_info,
        session_options,
        providers,
        provider_options,
    )
def _create_execution_agent(self):
    """Creates a TrainingAgent that can run the forward and backward graph on the training model.

    The forward feed names are the optimized model's graph inputs, and the backward fetch
    names are its graph outputs. The forward pass gets one output device per user output
    plus one per entry in ``self._graph_info.frontier_node_arg_map``. The backward pass gets
    one output device per fetch name.

    ``self._device`` may be a device object exposing ``.type`` / ``.index`` or a string of the
    form ``"type[:index]"``. For the ``"ort"`` device type, the output device comes from
    ``C.get_ort_device`` and the local rank is the device index. For any other type, an
    ``OrtDevice`` with default memory is built, and the local rank is
    ``_utils.get_device_index(self._device)``.

    The constructed agent is stored on ``self._execution_agent``; nothing is returned.
    """
    session_options, providers, provider_options = self._get_session_config()
    optimized_model = self._onnx_models.optimized_model

    # Normalize the target device once. Previously a string device passed the type check
    # but then hit ``self._device.index`` (the bound ``str.index`` method) or
    # ``self._device.type`` (AttributeError), so the string path could never succeed.
    if isinstance(self._device, str):
        type_part, _, index_part = self._device.partition(":")
        device_type = type_part.lower()
        device_index = int(index_part) if index_part else None
    else:
        device_type = self._device.type.lower()
        device_index = self._device.index

    if device_type == "ort":
        output_device = C.get_ort_device(device_index)
        local_device_rank = device_index
    else:
        local_device_rank = _utils.get_device_index(self._device)
        output_device = C.OrtDevice(
            get_ort_device_type(device_type, device_index),
            C.OrtDevice.default_memory(),
            local_device_rank,
        )

    fw_feed_names = [graph_input.name for graph_input in optimized_model.graph.input]
    num_fw_outputs = len(self._graph_info.user_output_names) + len(
        self._graph_info.frontier_node_arg_map
    )
    # The same device object is repeated for every output. The previous implementation
    # already shared one object per list via ``[device] * n``.
    fw_outputs_device_info = [output_device] * num_fw_outputs

    bw_fetches_names = [graph_output.name for graph_output in optimized_model.graph.output]
    bw_outputs_device_info = [output_device] * len(bw_fetches_names)

    self._execution_agent = TrainingAgent(
        optimized_model.SerializeToString(),
        fw_feed_names,
        fw_outputs_device_info,
        bw_fetches_names,
        bw_outputs_device_info,
        session_options,
        providers,
        provider_options,
        local_device_rank,
    )