import onnxruntime as ort
import torch
from onnx import helper, numpy_helper

# Helpers such as delete_input_with_name, get_device_index, dtype_torch_to_numpy,
# create_and_bind_grad_or_grad_accumulate_buffer and the DEFAULT_OPSET_VERSION constant
# are assumed to be defined elsewhere in this module.


def create_ort_training_session_bind_parameters(model, device, world_rank=-1, world_size=1,
                                                 gradient_accumulation_steps=1):
    output_name = model.graph.output[0].name
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = output_name
    ort_parameters.use_mixed_precision = False
    ort_parameters.world_rank = world_rank
    ort_parameters.world_size = world_size
    ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps

    torch_params = {}
    output_types = {}
    for output in model.graph.output:
        output_types[output.name] = output.type.tensor_type

    # Move initializers out of the graph: each becomes a torch Parameter and a graph input.
    for initializer in model.graph.initializer:
        torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device))
        delete_input_with_name(model.graph.input, initializer.name)
        model.graph.input.extend(
            [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)])
        torch_params[initializer.name] = torch_tensor

    del model.graph.initializer[:]

    ort_parameters.weights_to_train = set(torch_params.keys())

    if device.type == 'cuda' and hasattr(device, "index") and device.index is not None:
        from onnxruntime.capi._pybind_state import set_cuda_device_id
        set_cuda_device_id(device.index)

    session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
    train_io_binding = session.io_binding()
    eval_io_binding = session.io_binding()

    # Bind each parameter tensor (and its gradient / gradient-accumulation buffer) to the session.
    enable_grad_accumulation = gradient_accumulation_steps > 1
    for param in torch_params.keys():
        torch_tensor = torch_params[param]

        train_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                    dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                    torch_tensor.data_ptr())
        eval_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                   dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                   torch_tensor.data_ptr())

        device_index = get_device_index(device)
        create_and_bind_grad_or_grad_accumulate_buffer(train_io_binding, torch_tensor, param,
                                                       enable_grad_accumulation, device, device_index)

    return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
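# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how create_ort_training_session_bind_parameters might be called.
# The model file name and device choice are assumptions made for illustration only.
def _example_create_session_bind_parameters():
    import onnx

    model = onnx.load("model_with_loss.onnx")  # hypothetical model whose first graph output is the loss
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    session, train_io_binding, eval_io_binding, loss_name, torch_params, output_types = \
        create_ort_training_session_bind_parameters(model, device, gradient_accumulation_steps=1)
    return session, loss_name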
def _create_ort_training_session(self):
    # Validating frozen_weights names
    unused_frozen_weights = [n for n in self.options.utils.frozen_weights
                             if n not in [i.name for i in self._onnx_model.graph.initializer]]
    if unused_frozen_weights:
        raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format(
            unused_frozen_weights))

    # Get loss name from model description
    loss_name = [item.name for item in self.model_desc.outputs if item.is_loss]
    assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
    loss_name = loss_name[0]

    # Parse optimizer parameters
    optimizer_attributes_map = {}
    optimizer_int_attributes_map = {}
    trainable_params = set()
    for initializer in self._onnx_model.graph.initializer:
        if initializer.name in self.options.utils.frozen_weights:
            continue  # only trainable parameters are passed to the backend
        trainable_params.add(initializer.name)
        optimizer_attributes_map[initializer.name] = {}
        optimizer_int_attributes_map[initializer.name] = {}
        for param_group in self.optim_config.params:
            if initializer.name not in param_group['params']:
                continue  # keep looking for a matching param_group
            for k, v in param_group.items():
                if k == 'params':
                    continue  # 'params' is not a hyper parameter, skip it
                if isinstance(v, float):
                    optimizer_attributes_map[initializer.name][k] = v
                elif isinstance(v, int):
                    optimizer_int_attributes_map[initializer.name][k] = v
                else:
                    raise ValueError("Optimizer attributes must be either float or int.")

    # TrainingParameters
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = loss_name
    ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
    ort_parameters.world_rank = self.options.distributed.world_rank
    ort_parameters.world_size = self.options.distributed.world_size
    ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
    ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
    ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
    ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
    ort_parameters.set_gradients_as_graph_outputs = False
    ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
    ort_parameters.training_optimizer_name = self.optim_config.name
    ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
    ort_parameters.weights_to_train = trainable_params
    ort_parameters.optimizer_attributes_map = optimizer_attributes_map
    ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

    # SessionOptions
    session_options = ort.SessionOptions()
    session_options.use_deterministic_compute = self.options.debug.deterministic_compute

    # TrainingSession
    self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
                                                 ort_parameters, session_options)

    # I/O bindings
    self._train_io_binding = self._training_session.io_binding()
    self._eval_io_binding = self._training_session.io_binding()
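# Worked example (illustrative, not from the original module) of how the param-group parsing above
# splits hyperparameters by type. Given an optimizer config such as
#     self.optim_config.params = [{'params': ['fc1.weight', 'fc1.bias'], 'alpha': 0.9, 'do_bias_correction': 1}]
# a trainable initializer named 'fc1.weight' would end up with
#     optimizer_attributes_map['fc1.weight']     == {'alpha': 0.9}              # float-valued attributes
#     optimizer_int_attributes_map['fc1.weight'] == {'do_bias_correction': 1}   # int-valued attributes
# while 'params' itself is skipped, since it only lists the weights the group applies to.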
def create_ort_training_session_with_optimizer(model, device, training_optimizer_name, lr_params_feed_name,
                                               map_optimizer_attributes, world_rank=-1, world_size=1,
                                               gradient_accumulation_steps=1, bind_parameters=False,
                                               use_mixed_precision=False, allreduce_post_accumulation=False,
                                               deepspeed_zero_stage=0, enable_grad_norm_clip=True,
                                               frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION):
    output_name = model.graph.output[0].name
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = output_name
    ort_parameters.use_mixed_precision = use_mixed_precision
    ort_parameters.world_rank = world_rank
    ort_parameters.world_size = world_size
    ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps
    ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation
    ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage
    ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip
    ort_parameters.set_gradients_as_graph_outputs = False

    output_types = {}
    for output in model.graph.output:
        output_types[output.name] = output.type.tensor_type

    # pybind does not allow to add directly to ort_parameters.weights_to_train.
    # Have to work around by using a temporary weights_to_train.
    torch_params = {}
    optimizer_attributes_map = {}
    optimizer_int_attributes_map = {}

    unused_frozen_weights = [n for n in frozen_weights
                             if n not in [i.name for i in model.graph.initializer]]
    if unused_frozen_weights:
        raise RuntimeError("{} in frozen_weights not found in model weights.".format(unused_frozen_weights))

    weights_to_train = set()
    for initializer in model.graph.initializer:
        if initializer.name in frozen_weights:
            continue
        weights_to_train.add(initializer.name)
        if map_optimizer_attributes is not None:
            attributes = map_optimizer_attributes(initializer.name)
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}
            for k, v in attributes.items():
                if isinstance(v, float):
                    optimizer_attributes_map[initializer.name][k] = v
                elif isinstance(v, int):
                    optimizer_int_attributes_map[initializer.name][k] = v
                else:
                    raise ValueError("Optimizer attributes must be either float or int.")
        else:
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}

    if bind_parameters:
        # Move initializers out of the graph: each becomes a torch Parameter and a graph input.
        for initializer in model.graph.initializer:
            torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device))
            delete_input_with_name(model.graph.input, initializer.name)
            model.graph.input.extend(
                [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)])
            torch_params[initializer.name] = torch_tensor

        del model.graph.initializer[:]

    ort_parameters.weights_to_train = weights_to_train
    ort_parameters.training_optimizer_name = training_optimizer_name
    ort_parameters.lr_params_feed_name = lr_params_feed_name
    ort_parameters.optimizer_attributes_map = optimizer_attributes_map
    ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

    session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
    train_io_binding = session.io_binding()
    eval_io_binding = session.io_binding()

    if bind_parameters:
        for param in torch_params.keys():
            torch_tensor = torch_params[param]

            train_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                        dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                        torch_tensor.data_ptr())
            eval_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                       dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                       torch_tensor.data_ptr())

    return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
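# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how create_ort_training_session_with_optimizer might be called. The model
# file name, the optimizer name, the learning-rate feed name and the attribute values below are
# assumptions made for illustration; only the function signature above is taken from this module.
def _example_create_session_with_optimizer():
    import onnx

    model = onnx.load("model_with_loss.onnx")  # hypothetical model whose first graph output is the loss
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    # Per-weight optimizer attributes: called once for each trainable initializer name.
    def map_optimizer_attributes(weight_name):
        if weight_name.endswith("bias"):
            return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
        return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

    session, train_io_binding, eval_io_binding, loss_name, torch_params, output_types = \
        create_ort_training_session_with_optimizer(
            model, device,
            training_optimizer_name="AdamOptimizer",   # assumed optimizer name for the example
            lr_params_feed_name="Learning_Rate",       # assumed learning-rate input name
            map_optimizer_attributes=map_optimizer_attributes,
            gradient_accumulation_steps=1,
            use_mixed_precision=False)
    return session, loss_name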
def _create_ort_training_session(self):
    # Validating frozen_weights names
    unused_frozen_weights = [n for n in self.options.utils.frozen_weights
                             if n not in [i.name for i in self._onnx_model.graph.initializer]]
    if unused_frozen_weights:
        raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format(
            unused_frozen_weights))

    # Get loss name from model description
    loss_name = [item.name for item in self.model_desc.outputs if item.is_loss]
    assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
    loss_name = loss_name[0]

    # Parse optimizer parameters
    optimizer_attributes_map = {}
    optimizer_int_attributes_map = {}
    trainable_params = set()
    for initializer in self._onnx_model.graph.initializer:
        if initializer.name in self.options.utils.frozen_weights:
            continue  # only trainable parameters are passed to the backend
        trainable_params.add(initializer.name)
        optimizer_attributes_map[initializer.name] = {}
        optimizer_int_attributes_map[initializer.name] = {}
        not_in_param_groups = True
        for param_group in self.optim_config.params:
            if initializer.name not in param_group['params']:
                continue  # keep looking for a matching param_group
            not_in_param_groups = False
            for k, v in param_group.items():
                # 'params' is not a hyper parameter, skip it. 'lr' per weight is not supported
                if k == 'params' or k == 'lr':
                    continue
                if isinstance(v, float):
                    optimizer_attributes_map[initializer.name][k] = v
                elif isinstance(v, int):
                    optimizer_int_attributes_map[initializer.name][k] = v
                else:
                    raise ValueError("Optimizer attributes must be either float or int.")

        # set default values for params not found in groups
        if not_in_param_groups:
            for k, v in self.optim_config.defaults.items():
                if k == 'lr':
                    continue
                if isinstance(v, float):
                    optimizer_attributes_map[initializer.name][k] = v
                elif isinstance(v, int):
                    optimizer_int_attributes_map[initializer.name][k] = v
                else:
                    raise ValueError("Optimizer attributes must be either float or int.")

    # TrainingParameters
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = loss_name
    ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
    ort_parameters.world_rank = self.options.distributed.world_rank
    ort_parameters.world_size = self.options.distributed.world_size
    ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
    ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
    ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
    ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
    ort_parameters.set_gradients_as_graph_outputs = False
    ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
    ort_parameters.training_optimizer_name = self.optim_config.name
    ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
    ort_parameters.weights_to_train = trainable_params
    ort_parameters.optimizer_attributes_map = optimizer_attributes_map
    ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map
    ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute
    ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute
    ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute
    ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers
    ort_parameters.model_with_training_graph_path = self.options.debug.model_with_training_graph_path

    # SessionOptions
    session_options = ort.SessionOptions()
    session_options.use_deterministic_compute = self.options.debug.deterministic_compute
    if (self.options.graph_transformer.attn_dropout_recompute or
            self.options.graph_transformer.gelu_recompute or
            self.options.graph_transformer.transformer_layer_recompute):
        session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED

    # old ort session may already exist and occupy GPU memory when creating the new session, which may cause an OOM error.
    # for example, load_state_dict will be called before returning from the function, and it calls _init_session again
    del self._training_session

    # TrainingSession
    self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
                                                 ort_parameters, session_options)

    # I/O bindings
    self._train_io_binding = self._training_session.io_binding()
    self._eval_io_binding = self._training_session.io_binding()
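# --- Illustrative sketch (not part of the original module) of the option fields read above ---
# The method consumes a nested options object; a configuration with the same shape might look
# like the dict below. The keys mirror the attribute accesses in _create_ort_training_session;
# the concrete values are assumed defaults for illustration only.
_EXAMPLE_TRAINER_OPTIONS = {
    'batch': {'gradient_accumulation_steps': 1},
    'mixed_precision': {'enabled': False},
    'distributed': {
        'world_rank': 0,
        'world_size': 1,
        'allreduce_post_accumulation': False,
        'deepspeed_zero_optimization': {'stage': 0},
    },
    'graph_transformer': {
        'attn_dropout_recompute': False,
        'gelu_recompute': False,
        'transformer_layer_recompute': False,
        'number_recompute_layers': 0,
    },
    'utils': {
        'frozen_weights': [],
        'grad_norm_clip': True,
        'invertible_layer_norm_gradient': False,
    },
    'debug': {
        'deterministic_compute': False,
        'model_with_training_graph_path': None,
    },
}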