Example #1
    def _init_session(self):
        if self.onnx_model_ is None:
            return

        if self._enable_internal_postprocess:
            self.onnx_model_ = postprocess.run_postprocess(self.onnx_model_)

        if self._extra_postprocess:
            self._extra_postprocess(self.onnx_model_)

        self._verify_fully_optimized_model(self.onnx_model_)
        self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \
            create_ort_training_session_with_optimizer(
                self.onnx_model_, self.device_,
                self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_,
                self.world_rank, self.world_size,
                self.gradient_accumulation_steps, bind_parameters=False,
                use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_,
                deepspeed_zero_stage=self.deepspeed_zero_stage_,
                enable_grad_norm_clip=self.enable_grad_norm_clip_,
                frozen_weights=self.frozen_weights_, opset_version=self.opset_version_)

        self.loss_scale_input_name = self.session.loss_scale_input_name

        if self.use_mixed_precision:
            self.input_desc_with_lr_and_loss_scale = [
                *self.input_desc_with_lr,
                IODescription(self.loss_scale_input_name, [], torch.float32)
            ]

        # The ORT backend may have changed model output dtypes from float32 to float16 under mixed precision.
        for o_desc in self.model_desc_.outputs_:
            if self.use_mixed_precision and o_desc.dtype_ == torch.float32 and not self.session.is_output_fp32_node(
                    o_desc.name_):
                o_desc.eval_dtype_ = torch.float16
            else:
                o_desc.eval_dtype_ = o_desc.dtype_

        # Gradient accumulation buffers are all connected to a single node whose output is
        # a boolean tensor of shape [1]; add a matching output description so that node is
        # fetched to drive gradient accumulation.
        if self.gradient_accumulation_steps > 1:
            self.output_desc_with_group_accumulated_gradients = [
                *self.model_desc_.outputs_,
                IODescription(
                    get_group_accumulated_gradients_output_node_arg_name(
                        self.session), [1], torch.bool)
            ]

        if self.use_mixed_precision:
            # When using accumulated gradients with mixed precision, we need to fetch the
            # all-gradients-finite flag to determine whether the gradients are usable.
            self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [
                *self.model_desc_.outputs_,
                IODescription(get_all_gradients_finite_arg_name(self.session),
                              [1], torch.bool)
            ]

        if self.state_dict_:
            self.load_state_dict(self.state_dict_, self.strict_)
        self.state_dict_ = None
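
The extra IODescription entries built above are what let a later run fetch auxiliary graph outputs (the accumulated-gradient flag and the all-gradients-finite flag) alongside the model's own outputs. Below is a minimal, self-contained sketch of that pattern using a stand-in IODescription dataclass; the class and the flag name are illustrative, not the real ones.

import torch
from dataclasses import dataclass

@dataclass
class IODescription:  # stand-in mirroring the fields used above; not the real class
    name_: str
    shape_: list
    dtype_: object = None

# The model's declared outputs, e.g. a scalar loss.
model_outputs = [IODescription('loss', [], torch.float32)]

# Hypothetical flag name; in the real code it is queried from the training session.
all_finite_name = 'all_gradients_finite'
outputs_with_flag = [*model_outputs,
                     IODescription(all_finite_name, [1], torch.bool)]

for desc in outputs_with_flag:
    print(desc.name_, desc.shape_, desc.dtype_)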
Example #2
def convert_model_loss_fn_to_onnx(model,
                                  loss_fn,
                                  model_desc,
                                  device,
                                  inputs,
                                  opset_version=DEFAULT_OPSET_VERSION,
                                  _enable_internal_postprocess=True):
    # example: {input0:{0:'batch'}, input1:{0:'batch'}}
    dynamic_axes = {}
    for input in model_desc.inputs_:
        symbolic_axis = {}
        for i, axis in enumerate(input.shape_):
            if isinstance(axis, str):
                symbolic_axis[i] = axis
        if symbolic_axis:
            dynamic_axes[input.name_] = symbolic_axis

    for output in model_desc.outputs_:
        symbolic_axis = {}
        for i, axis in enumerate(output.shape_):
            if isinstance(axis, str):
                symbolic_axis[i] = axis
        if symbolic_axis:
            dynamic_axes[output.name_] = symbolic_axis

    input_names = [input.name_ for input in model_desc.inputs_]
    output_names = [output.name_ for output in model_desc.outputs_]

    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]
    if isinstance(inputs, dict):
        sample_inputs = [
            inputs[k.name_].to(device=device) for k in model_desc.inputs_
        ]
    elif isinstance(inputs, (list, tuple)):
        sample_inputs = [
            input.to(device=device) for i, input in enumerate(inputs)
            if i < len(model_desc.inputs_)
        ]
    else:
        raise RuntimeError(
            "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported."
        )

    # The PyTorch ONNX exporter/tracer does not try to match argument names;
    # e.g., for models with optional inputs, it requires all inputs to be present.
    # This is a problem because the exported graph depends on which inputs are provided.
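    # Illustration (hypothetical model): for forward(self, input_ids, attention_mask=None),
    # tracing with only input_ids yields a graph with no attention_mask input at all.
    # wrap_for_input_match matches the declared input_names against the model's (and
    # loss_fn's) arguments so the exported graph always sees the same set of inputs.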
    model = wrap_for_input_match(model, loss_fn, input_names)

    model.eval()
    with torch.no_grad():
        sample_outputs = model(*sample_inputs)
    if isinstance(sample_outputs, torch.Tensor):
        sample_outputs = [sample_outputs]
    for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_):
        output_desc.dtype_ = sample_output.dtype
    model.train()

    f = io.BytesIO()

    # Other export options to use (kept for backward compatibility).
    other_export_options = {}
    other_export_options['training'] = True

    # This option was added after the 1.4 release.
    if LooseVersion(torch.__version__) > LooseVersion('1.4.0'):
        other_export_options['enable_onnx_checker'] = False
    # This option was added in the 1.6 release.
    if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
        other_export_options['training'] = torch.onnx.TrainingMode.TRAINING

    torch.onnx._export(model,
                       tuple(sample_inputs),
                       f,
                       input_names=input_names,
                       output_names=output_names,
                       opset_version=opset_version,
                       dynamic_axes=dynamic_axes,
                       _retain_param_name=True,
                       example_outputs=tuple(sample_outputs),
                       do_constant_folding=False,
                       **other_export_options)

    onnx_model = onnx.load_model_from_string(f.getvalue())

    # Remove 'model_.' prefix introduced by model wrapper for initializers.
    replace_name_dict = {}
    for n in onnx_model.graph.initializer:
        if n.name.startswith('model_.'):
            replace_name_dict[n.name] = n.name[len('model_.'):]
            n.name = replace_name_dict[n.name]
    for n in onnx_model.graph.node:
        for i, name in enumerate(n.input):
            if name in replace_name_dict:
                n.input[i] = replace_name_dict[name]

    # The ONNX model's initializers may contain non-trainable registered buffers that are
    # not part of the PyTorch model's named parameters.
    named_parameters = model.model_.named_parameters() if hasattr(
        model, 'model_') else model.named_parameters()
    assert set([n for n, t in named_parameters]).issubset(
        set([n.name for n in onnx_model.graph.initializer])), \
        "Initializer names do not match between PyTorch model and ONNX model, " \
        "please report a bug to ONNX Runtime."

    if _enable_internal_postprocess:
        onnx_model = postprocess.run_postprocess(onnx_model)

    return onnx_model
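
A minimal usage sketch follows. Assumptions: IODescription and ModelDescription come from the same legacy ort_trainer module as the function above, with IODescription(name, shape, dtype=None) and ModelDescription(inputs, outputs) signatures, and wrap_for_input_match routes the leftover 'target' input to loss_fn; all names below are illustrative.

import torch
import torch.nn.functional as F

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

def loss_fn(logits, target):
    return F.cross_entropy(logits, target)

# 'batch' marks the symbolic axis that the dynamic_axes loop above will pick up.
model_desc = ModelDescription(
    [IODescription('x', ['batch', 4], torch.float32),
     IODescription('target', ['batch'], torch.int64)],
    [IODescription('loss', [], torch.float32)])

sample_inputs = [torch.rand(3, 4), torch.randint(0, 2, (3,))]
onnx_model = convert_model_loss_fn_to_onnx(
    TinyModel(), loss_fn, model_desc, torch.device('cpu'), sample_inputs)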