def _init_session(self):
    if self.onnx_model_ is None:
        return

    if self._enable_internal_postprocess:
        self.onnx_model_ = postprocess.run_postprocess(self.onnx_model_)

    if self._extra_postprocess:
        self._extra_postprocess(self.onnx_model_)

    self._verify_fully_optimized_model(self.onnx_model_)
    self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
        create_ort_training_session_with_optimizer(
            self.onnx_model_, self.device_,
            self.training_optimizer_name_, self.learning_rate_description_.name_,
            self.map_optimizer_attributes_,
            self.world_rank, self.world_size,
            self.gradient_accumulation_steps, bind_parameters=False,
            use_mixed_precision=self.use_mixed_precision,
            allreduce_post_accumulation=self.allreduce_post_accumulation_,
            deepspeed_zero_stage=self.deepspeed_zero_stage_,
            enable_grad_norm_clip=self.enable_grad_norm_clip_,
            frozen_weights=self.frozen_weights_,
            opset_version=self.opset_version_)

    self.loss_scale_input_name = self.session.loss_scale_input_name

    if self.use_mixed_precision:
        self.input_desc_with_lr_and_loss_scale = [
            *self.input_desc_with_lr,
            IODescription(self.loss_scale_input_name, [], torch.float32)]

    # The ORT backend may have modified a model output dtype from float32 to float16.
    for o_desc in self.model_desc_.outputs_:
        if self.use_mixed_precision and o_desc.dtype_ == torch.float32 and \
                not self.session.is_output_fp32_node(o_desc.name_):
            o_desc.eval_dtype_ = torch.float16
        else:
            o_desc.eval_dtype_ = o_desc.dtype_

    # Gradient accumulation buffers are connected to a single node with a boolean,
    # dimension-1 tensor output. Add a matching output to drive gradient accumulation.
    if self.gradient_accumulation_steps > 1:
        self.output_desc_with_group_accumulated_gradients = [
            *self.model_desc_.outputs_,
            IODescription(
                get_group_accumulated_gradients_output_node_arg_name(self.session),
                [1], torch.bool)]

    if self.use_mixed_precision:
        # When using accumulated gradients with mixed precision, we need to fetch the
        # all-gradients-finite flag to determine whether the gradients are usable.
        self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
            *self.model_desc_.outputs_,
            IODescription(get_all_gradients_finite_arg_name(self.session),
                          [1], torch.bool)]

    if self.state_dict_:
        self.load_state_dict(self.state_dict_, self.strict_)
    self.state_dict_ = None
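
# Illustrative note, not part of the original module: _extra_postprocess above is an
# optional user callback that receives the exported onnx.ModelProto and may mutate it
# in place before the training session is created. A minimal hypothetical hook:
#
#   def my_extra_postprocess(onnx_model):
#       # any in-place ModelProto edit works here, e.g. tagging the graph
#       onnx_model.graph.doc_string = 'patched before session creation'
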
def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs,
                                  opset_version=DEFAULT_OPSET_VERSION,
                                  _enable_internal_postprocess=True):
    # example: {input0: {0: 'batch'}, input1: {0: 'batch'}}
    dynamic_axes = {}
    for input in model_desc.inputs_:
        symbolic_axis = {}
        for i, axis in enumerate(input.shape_):
            if isinstance(axis, str):
                symbolic_axis[i] = axis
        if len(symbolic_axis):
            dynamic_axes[input.name_] = symbolic_axis

    for output in model_desc.outputs_:
        symbolic_axis = {}
        for i, axis in enumerate(output.shape_):
            if isinstance(axis, str):
                symbolic_axis[i] = axis
        if len(symbolic_axis):
            dynamic_axes[output.name_] = symbolic_axis

    input_names = [input.name_ for input in model_desc.inputs_]
    output_names = [output.name_ for output in model_desc.outputs_]

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]
    if isinstance(inputs, dict):
        sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_]
    elif isinstance(inputs, (list, tuple)):
        sample_inputs = [input.to(device=device) for i, input in enumerate(inputs)
                         if i < len(model_desc.inputs_)]
    else:
        raise RuntimeError(
            "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.")

    # The pytorch onnx exporter/trace does not try to match argument names.
    # e.g. for models with optional inputs, it requires all inputs to be present.
    # This is a problem because the model graph depends on the inputs provided.
    model = wrap_for_input_match(model, loss_fn, input_names)

    model.eval()
    with torch.no_grad():
        sample_outputs = model(*sample_inputs)
    if isinstance(sample_outputs, torch.Tensor):
        sample_outputs = [sample_outputs]
    for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
        output_desc.dtype_ = sample_output.dtype
    model.train()

    f = io.BytesIO()

    # Other export options to use (this is for backward compatibility).
    other_export_options = {}
    other_export_options['training'] = True

    # This option was added after the 1.4 release.
    if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
        other_export_options['enable_onnx_checker'] = False
    # This option was added in the 1.6 release.
    if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
        other_export_options['training'] = torch.onnx.TrainingMode.TRAINING

    torch.onnx._export(model, tuple(sample_inputs), f,
                       input_names=input_names,
                       output_names=output_names,
                       opset_version=opset_version,
                       dynamic_axes=dynamic_axes,
                       _retain_param_name=True,
                       example_outputs=tuple(sample_outputs),
                       do_constant_folding=False,
                       **other_export_options)

    onnx_model = onnx.load_model_from_string(f.getvalue())

    # Remove the 'model_.' prefix introduced by the model wrapper for initializers.
    replace_name_dict = {}
    for n in onnx_model.graph.initializer:
        if n.name.startswith('model_.'):
            replace_name_dict[n.name] = n.name[len('model_.'):]
            n.name = replace_name_dict[n.name]
    for n in onnx_model.graph.node:
        for i, name in enumerate(n.input):
            if name in replace_name_dict:
                n.input[i] = replace_name_dict[name]

    # The onnx model initializers may contain non-trainable registered buffers that are
    # not part of the pytorch model's named parameters.
    named_parameters = model.model_.named_parameters() if hasattr(model, 'model_') \
        else model.named_parameters()
    assert set([n for n, t in named_parameters]).issubset(
        set([n.name for n in onnx_model.graph.initializer])), \
        "Initializer names do not match between PyTorch model and ONNX model, " \
        "please report a bug to ONNX Runtime."

    if _enable_internal_postprocess:
        onnx_model = postprocess.run_postprocess(onnx_model)

    return onnx_model
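
# A minimal usage sketch (illustrative only: the network, loss, and shapes below are
# hypothetical; IODescription/ModelDescription are this module's descriptor types):
#
#   model_desc = ModelDescription(
#       inputs=[IODescription('x', ['batch', 10], torch.float32),
#               IODescription('target', ['batch'], torch.int64)],
#       outputs=[IODescription('loss', [], torch.float32)])
#   onnx_model = convert_model_loss_fn_to_onnx(
#       MyNet(), lambda out, target: torch.nn.functional.cross_entropy(out, target),
#       model_desc, torch.device('cpu'),
#       inputs=(torch.randn(3, 10), torch.randint(0, 5, (3,))))
#
# With a list/tuple input, sample inputs map positionally onto model_desc.inputs_;
# with a dict, they are looked up by each input description's name.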