コード例 #1
0
def create_training_session(training_onnx,
                            weights_to_train,
                            loss_output_name='loss',
                            training_optimizer_name='SGDOptimizer',
                            device='cpu'):
    """
    Creates an instance of class `TrainingSession`.

    Every weight listed in *weights_to_train* is registered for optimization
    with an empty float-attribute map and an empty int-attribute map, and the
    session is created with deterministic compute enabled.

    :param training_onnx: ONNX graph used to train (must expose
        ``SerializeToString()``)
    :param weights_to_train: names of initializers to be optimized; any
        iterable is accepted, including a one-shot generator
    :param loss_output_name: name of the loss output
    :param training_optimizer_name: optimizer name
    :param device: `'cpu'`, a string starting with `'cuda'`, or a device
        object exposing `device_type()`, `cpu()` and `cuda()`
    :return: instance of `TrainingSession`
    :raises ValueError: if *device* is not recognised
    """
    # Resolve the execution provider first so an invalid device fails fast,
    # before any ONNX Runtime object is built.
    if hasattr(device, 'device_type'):
        device_type = device.device_type()
        if device_type == device.cpu():
            providers = ['CPUExecutionProvider']
        elif device_type == device.cuda():
            providers = ['CUDAExecutionProvider']
        else:
            raise ValueError(f"Unexpected device {device!r}.")
    elif device == 'cpu':
        providers = ['CPUExecutionProvider']
    elif isinstance(device, str) and device.startswith("cuda"):
        # Accepts 'cuda' as well as indexed forms such as 'cuda:0'.
        providers = ['CUDAExecutionProvider']
    else:
        # Also covers non-string values (e.g. None), which previously
        # crashed with AttributeError on `.startswith`.
        raise ValueError(f"Unexpected device {device!r}.")

    # Materialise the names once: they are iterated three times below, and a
    # generator would otherwise be exhausted by the first pass, leaving the
    # optimizer attribute maps silently empty.
    weight_names = list(weights_to_train)

    ort_parameters = TrainingParameters()
    ort_parameters.loss_output_name = loss_output_name
    ort_parameters.weights_to_train = set(weight_names)
    ort_parameters.training_optimizer_name = training_optimizer_name
    # One empty attribute dictionary per trained weight, for float and int
    # optimizer attributes respectively.
    ort_parameters.optimizer_attributes_map = {
        name: {} for name in weight_names}
    ort_parameters.optimizer_int_attributes_map = {
        name: {} for name in weight_names}

    session_options = SessionOptions()
    session_options.use_deterministic_compute = True

    return TrainingSession(training_onnx.SerializeToString(),
                           ort_parameters,
                           session_options,
                           providers=providers)
コード例 #2
0
def create_training_session(training_onnx,
                            weights_to_train,
                            loss_output_name='loss',
                            training_optimizer_name='SGDOptimizer'):
    """
    Creates an instance of class `TrainingSession`.

    Every weight listed in *weights_to_train* is registered for optimization
    with an empty float-attribute map and an empty int-attribute map, and the
    session is created with deterministic compute enabled. No execution
    provider is passed explicitly, so `TrainingSession` uses its own default.

    :param training_onnx: ONNX graph used to train (must expose
        ``SerializeToString()``)
    :param weights_to_train: names of initializers to be optimized; any
        iterable is accepted, including a one-shot generator
    :param loss_output_name: name of the loss output
    :param training_optimizer_name: optimizer name
    :return: instance of `TrainingSession`
    """
    # Materialise the names once: they are iterated three times below, and a
    # generator would otherwise be exhausted by the first pass, leaving the
    # optimizer attribute maps silently empty.
    weight_names = list(weights_to_train)

    ort_parameters = TrainingParameters()
    ort_parameters.loss_output_name = loss_output_name
    ort_parameters.weights_to_train = set(weight_names)
    ort_parameters.training_optimizer_name = training_optimizer_name
    # One empty attribute dictionary per trained weight, for float and int
    # optimizer attributes respectively.
    ort_parameters.optimizer_attributes_map = {
        name: {} for name in weight_names}
    ort_parameters.optimizer_int_attributes_map = {
        name: {} for name in weight_names}

    session_options = SessionOptions()
    session_options.use_deterministic_compute = True

    return TrainingSession(training_onnx.SerializeToString(),
                           ort_parameters, session_options)