Example #1
# Assumed module-level imports for this snippet; local helpers such as delete_input_with_name,
# get_device_index, dtype_torch_to_numpy and create_and_bind_grad_or_grad_accumulate_buffer
# are defined elsewhere in the same module and are not shown here.
import torch
import onnxruntime as ort
from onnx import helper, numpy_helper


def create_ort_training_session_bind_parameters(model, device, world_rank=-1, world_size=1,
                                                gradient_accumulation_steps=1):
    # The first graph output is assumed to be the loss
    output_name = model.graph.output[0].name
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = output_name
    ort_parameters.use_mixed_precision = False
    ort_parameters.world_rank = world_rank
    ort_parameters.world_size = world_size
    ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps

    torch_params = {}
    output_types = {}
    for output in model.graph.output:
        output_types[output.name] = output.type.tensor_type

    # Convert each initializer to a torch Parameter on the target device and
    # re-declare it as a graph input so it can be fed through IO binding.
    for initializer in model.graph.initializer:
        torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device))
        delete_input_with_name(model.graph.input, initializer.name)
        model.graph.input.extend(
            [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)])

        torch_params[initializer.name] = torch_tensor

    # The weights now live in torch tensors, so drop the initializers from the graph.
    del model.graph.initializer[:]

    ort_parameters.weights_to_train = set(torch_params.keys())

    # Make ORT allocate on the same CUDA device that holds the torch parameters.
    if device.type == 'cuda' and hasattr(device, "index") and device.index is not None:
        from onnxruntime.capi._pybind_state import set_cuda_device_id
        set_cuda_device_id(device.index)
    session = ort.TrainingSession(model.SerializeToString(), ort_parameters)

    train_io_binding = session.io_binding()
    eval_io_binding = session.io_binding()

    enable_grad_accumulation = gradient_accumulation_steps > 1
    # Bind each parameter's device memory as an input of both the train and eval graphs,
    # and create gradient (or gradient-accumulation) buffers for the training binding.
    for param in torch_params.keys():
        torch_tensor = torch_params[param]

        train_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                    dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                    torch_tensor.data_ptr())
        eval_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device),
                                   dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()),
                                   torch_tensor.data_ptr())

        device_index = get_device_index(device)
        create_and_bind_grad_or_grad_accumulate_buffer(train_io_binding, torch_tensor, param, enable_grad_accumulation, device, device_index)

    return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
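
A minimal usage sketch for the function above, under stated assumptions: the file name "model.onnx", the graph input name "input1", and the buffer shapes are placeholders, and any remaining graph inputs (labels, learning rate, and so on) would be bound the same way; the bind_input, bind_output, and run_with_iobinding calls follow the standard onnxruntime IOBinding API.

import numpy as np
import onnx
import torch

onnx_model = onnx.load("model.onnx")  # placeholder path; a training graph whose first output is the loss
session, train_io_binding, eval_io_binding, output_name, torch_params, output_types = \
    create_ort_training_session_bind_parameters(onnx_model, torch.device("cpu"))

# Preallocate buffers; ORT reads the input and writes the loss in place.
input_buffer = torch.zeros(8, 16, dtype=torch.float32)  # placeholder shape for the graph input
loss_buffer = torch.zeros(1, dtype=torch.float32)

train_io_binding.bind_input("input1", "cpu", 0, np.float32,
                            list(input_buffer.size()), input_buffer.data_ptr())
train_io_binding.bind_output(output_name, "cpu", 0, np.float32,
                             list(loss_buffer.size()), loss_buffer.data_ptr())

session.run_with_iobinding(train_io_binding)
print(loss_buffer.item())
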
Example #2
    def _create_ort_training_session(self):
        # Validate that every name in frozen_weights exists in the ONNX model
        unused_frozen_weights = [n for n in self.options.utils.frozen_weights
                                 if n not in [i.name for i in self._onnx_model.graph.initializer]]
        if unused_frozen_weights:
            raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format(
                unused_frozen_weights))

        # Get loss name from model description
        loss_name = [item.name for item in self.model_desc.outputs if item.is_loss]
        assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
        loss_name = loss_name[0]

        # Parse optimizer parameters
        optimizer_attributes_map = {}
        optimizer_int_attributes_map = {}
        trainable_params = set()
        for initializer in self._onnx_model.graph.initializer:
            if initializer.name in self.options.utils.frozen_weights:
                continue  # only trainable parameters are passed to the backend
            trainable_params.add(initializer.name)
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}
            for param_group in self.optim_config.params:
                if initializer.name not in param_group['params']:
                    continue  # keep looking for a matching param_group
                for k, v in param_group.items():
                    if k == 'params':
                        continue  # 'params' is not a hyperparameter, skip it
                    if isinstance(v, float):
                        optimizer_attributes_map[initializer.name][k] = v
                    elif isinstance(v, int):
                        optimizer_int_attributes_map[initializer.name][k] = v
                    else:
                        raise ValueError("Optimizer attributes must be either float or int.")

        # TrainingParameters
        ort_parameters = ort.TrainingParameters()
        ort_parameters.loss_output_name = loss_name
        ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
        ort_parameters.world_rank = self.options.distributed.world_rank
        ort_parameters.world_size = self.options.distributed.world_size
        ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
        ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
        ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
        ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
        ort_parameters.set_gradients_as_graph_outputs = False
        ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
        ort_parameters.training_optimizer_name = self.optim_config.name
        ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
        ort_parameters.weights_to_train = trainable_params
        ort_parameters.optimizer_attributes_map = optimizer_attributes_map
        ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

        # SessionOptions
        session_options = ort.SessionOptions()
        session_options.use_deterministic_compute = self.options.debug.deterministic_compute

        # TrainingSession
        self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(),
                                                     ort_parameters,
                                                     session_options)

        # I/O bindings
        self._train_io_binding = self._training_session.io_binding()
        self._eval_io_binding = self._training_session.io_binding()
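
For reference, a sketch of the param-group structure the loop above consumes, with assumed weight and hyperparameter names: each entry of self.optim_config.params is a dict containing a 'params' list of initializer names plus per-group hyperparameters, which the loop splits into the float and int attribute maps.

# Hypothetical param groups; the hyperparameter names depend on the chosen optimizer.
params = [
    {"params": ["encoder.weight", "encoder.bias"], "alpha": 0.9, "epsilon": 1e-8},
    {"params": ["decoder.weight"], "lambda_coef": 0.01, "do_bias_correction": 1},
]
# For "decoder.weight" the loop above would yield:
#   optimizer_attributes_map["decoder.weight"]     == {"lambda_coef": 0.01}
#   optimizer_int_attributes_map["decoder.weight"] == {"do_bias_correction": 1}
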
Example #3
def create_ort_training_session_with_optimizer(
        model,
        device,
        training_optimizer_name,
        lr_params_feed_name,
        map_optimizer_attributes,
        world_rank=-1,
        world_size=1,
        gradient_accumulation_steps=1,
        bind_parameters=False,
        use_mixed_precision=False,
        allreduce_post_accumulation=False,
        deepspeed_zero_stage=0,
        enable_grad_norm_clip=True,
        frozen_weights=[],
        opset_version=DEFAULT_OPSET_VERSION):
    output_name = model.graph.output[0].name
    ort_parameters = ort.TrainingParameters()
    ort_parameters.loss_output_name = output_name
    ort_parameters.use_mixed_precision = use_mixed_precision
    ort_parameters.world_rank = world_rank
    ort_parameters.world_size = world_size
    ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps
    ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation
    ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage
    ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip
    ort_parameters.set_gradients_as_graph_outputs = False

    output_types = {}
    for output in model.graph.output:
        output_types[output.name] = output.type.tensor_type

    # pybind does not allow adding names directly to ort_parameters.weights_to_train,
    # so temporary containers are built here and assigned to ort_parameters below.
    torch_params = {}
    optimizer_attributes_map = {}
    optimizer_int_attributes_map = {}

    unused_frozen_weights = [
        n for n in frozen_weights
        if n not in [i.name for i in model.graph.initializer]
    ]
    if unused_frozen_weights:
        raise RuntimeError(
            "{} in frozen_weights not found in model weights.".format(
                unused_frozen_weights))

    weights_to_train = set()
    for initializer in model.graph.initializer:
        if initializer.name in frozen_weights:
            continue
        weights_to_train.add(initializer.name)
        if map_optimizer_attributes is not None:
            attributes = map_optimizer_attributes(initializer.name)
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}
            for k, v in attributes.items():
                if isinstance(v, float):
                    optimizer_attributes_map[initializer.name][k] = v
                elif isinstance(v, int):
                    optimizer_int_attributes_map[initializer.name][k] = v
                else:
                    raise ValueError(
                        "Optimizer attributes must be either float or int.")
        else:
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}

    if bind_parameters:
        for initializer in model.graph.initializer:
            torch_tensor = torch.nn.Parameter(
                torch.as_tensor(numpy_helper.to_array(initializer),
                                device=device))
            delete_input_with_name(model.graph.input, initializer.name)
            model.graph.input.extend([
                helper.make_tensor_value_info(initializer.name,
                                              initializer.data_type,
                                              initializer.dims)
            ])
            torch_params[initializer.name] = torch_tensor

        del model.graph.initializer[:]

    ort_parameters.weights_to_train = weights_to_train
    ort_parameters.training_optimizer_name = training_optimizer_name
    ort_parameters.lr_params_feed_name = lr_params_feed_name
    ort_parameters.optimizer_attributes_map = optimizer_attributes_map
    ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

    session = ort.TrainingSession(model.SerializeToString(), ort_parameters)
    train_io_binding = session.io_binding()
    eval_io_binding = session.io_binding()

    if bind_parameters:
        for param in torch_params.keys():
            torch_tensor = torch_params[param]

            train_io_binding.bind_input(
                param, torch_tensor.device.type,
                get_device_index(torch_tensor.device),
                dtype_torch_to_numpy(torch_params[param].dtype),
                list(torch_tensor.size()), torch_tensor.data_ptr())
            eval_io_binding.bind_input(
                param, torch_tensor.device.type,
                get_device_index(torch_tensor.device),
                dtype_torch_to_numpy(torch_params[param].dtype),
                list(torch_tensor.size()), torch_tensor.data_ptr())

    return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types
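
The map_optimizer_attributes callable, if given, is invoked once per trainable weight name and must return a flat dict of float or int hyperparameters for that weight (anything else raises the ValueError above). A minimal sketch, with attribute names that are assumptions rather than values from the source:

# Hypothetical per-weight hyperparameter callback; the expected keys depend on the
# optimizer selected via training_optimizer_name.
def my_map_optimizer_attributes(weight_name):
    # A common convention: no weight decay for bias and LayerNorm parameters.
    no_decay = "bias" in weight_name or "LayerNorm" in weight_name
    return {"lambda_coef": 0.0 if no_decay else 0.01, "epsilon": 1e-6}
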
Example #4
    def _create_ort_training_session(self):
        # Validate that every name in frozen_weights exists in the ONNX model
        unused_frozen_weights = [n for n in self.options.utils.frozen_weights
                                 if n not in [i.name for i in self._onnx_model.graph.initializer]]
        if unused_frozen_weights:
            raise RuntimeError(
                "{} params from 'frozen_weights' not found in the ONNX model.".
                format(unused_frozen_weights))

        # Get loss name from model description
        loss_name = [item.name for item in self.model_desc.outputs if item.is_loss]
        assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)"
        loss_name = loss_name[0]

        # Parse optimizer parameters
        optimizer_attributes_map = {}
        optimizer_int_attributes_map = {}
        trainable_params = set()
        for initializer in self._onnx_model.graph.initializer:
            if initializer.name in self.options.utils.frozen_weights:
                continue  # only trainable parameters are passed to the backend
            trainable_params.add(initializer.name)
            optimizer_attributes_map[initializer.name] = {}
            optimizer_int_attributes_map[initializer.name] = {}
            not_in_param_groups = True
            for param_group in self.optim_config.params:
                if initializer.name not in param_group['params']:
                    continue  # keep looking for a matching param_group
                not_in_param_groups = False
                for k, v in param_group.items():
                    # 'params' is not a hyperparameter and a per-weight 'lr' is not supported; skip both
                    if k == 'params' or k == 'lr':
                        continue
                    if isinstance(v, float):
                        optimizer_attributes_map[initializer.name][k] = v
                    elif isinstance(v, int):
                        optimizer_int_attributes_map[initializer.name][k] = v
                    else:
                        raise ValueError(
                            "Optimizer attributes must be either float or int."
                        )

            # set default values for params not found in groups
            if not_in_param_groups:
                for k, v in self.optim_config.defaults.items():
                    if k == 'lr':
                        continue
                    if isinstance(v, float):
                        optimizer_attributes_map[initializer.name][k] = v
                    elif isinstance(v, int):
                        optimizer_int_attributes_map[initializer.name][k] = v
                    else:
                        raise ValueError(
                            "Optimizer attributes must be either float or int."
                        )

        # TrainingParameters
        ort_parameters = ort.TrainingParameters()
        ort_parameters.loss_output_name = loss_name
        ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
        ort_parameters.world_rank = self.options.distributed.world_rank
        ort_parameters.world_size = self.options.distributed.world_size
        ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
        ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
        ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
        ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
        ort_parameters.set_gradients_as_graph_outputs = False
        ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
        ort_parameters.training_optimizer_name = self.optim_config.name
        ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
        ort_parameters.weights_to_train = trainable_params
        ort_parameters.optimizer_attributes_map = optimizer_attributes_map
        ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map

        ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute
        ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute
        ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute
        ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers
        ort_parameters.model_with_training_graph_path = self.options.debug.model_with_training_graph_path

        # SessionOptions
        session_options = ort.SessionOptions()
        session_options.use_deterministic_compute = self.options.debug.deterministic_compute
        if (self.options.graph_transformer.attn_dropout_recompute
                or self.options.graph_transformer.gelu_recompute
                or self.options.graph_transformer.transformer_layer_recompute):
            session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED

        # An old ORT session may still exist and occupy GPU memory while the new session is created,
        # which can cause an OOM error. For example, load_state_dict may be called before this method
        # returns, and it calls _init_session again, so release the previous session first.
        del self._training_session
        # TrainingSession
        self._training_session = ort.TrainingSession(
            self._onnx_model.SerializeToString(), ort_parameters,
            session_options)

        # I/O bindings
        self._train_io_binding = self._training_session.io_binding()
        self._eval_io_binding = self._training_session.io_binding()
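
For orientation, a sketch of the nested options this method reads, written as a plain dict with assumed values; the field paths mirror the attribute accesses above rather than any particular release's ORTTrainerOptions schema.

# Hypothetical option values; only the fields referenced in the method are shown.
opts = {
    "batch": {"gradient_accumulation_steps": 4},
    "mixed_precision": {"enabled": False},
    "distributed": {
        "world_rank": 0,
        "world_size": 1,
        "allreduce_post_accumulation": False,
        "deepspeed_zero_optimization": {"stage": 0},
    },
    "utils": {
        "frozen_weights": [],
        "grad_norm_clip": True,
        "invertible_layer_norm_gradient": False,
    },
    "graph_transformer": {
        "attn_dropout_recompute": False,
        "gelu_recompute": False,
        "transformer_layer_recompute": False,
        "number_recompute_layers": 0,
    },
    "debug": {
        "deterministic_compute": False,
        "model_with_training_graph_path": None,
    },
}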