def create_training_session(training_onnx, weights_to_train,
                            loss_output_name='loss',
                            training_optimizer_name='SGDOptimizer',
                            device='cpu'):
    """
    Creates an instance of class `TrainingSession`.

    :param training_onnx: ONNX graph used to train
    :param weights_to_train: names of initializers to be optimized
        (any iterable; it is read exactly once)
    :param loss_output_name: name of the loss output
    :param training_optimizer_name: optimizer name
    :param device: `'cpu'`, a string starting with `'cuda'`, or an object
        exposing `device_type()`, `cpu()` and `cuda()` methods
    :return: instance of `TrainingSession`
    :raises ValueError: if *device* is not recognised
    """
    # Materialise once: the names are consumed three times below, and a
    # generator would be exhausted after the first pass, silently leaving
    # the optimizer attribute maps empty.
    weights = list(weights_to_train)

    ort_parameters = TrainingParameters()
    ort_parameters.loss_output_name = loss_output_name
    ort_parameters.weights_to_train = set(weights)
    ort_parameters.training_optimizer_name = training_optimizer_name
    # One empty attribute dictionary per trained weight.
    ort_parameters.optimizer_attributes_map = {name: {} for name in weights}
    ort_parameters.optimizer_int_attributes_map = {
        name: {} for name in weights}

    session_options = SessionOptions()
    session_options.use_deterministic_compute = True

    if hasattr(device, 'device_type'):
        # Device object: compare its type against the constants it exposes.
        device_type = device.device_type()
        if device_type == device.cpu():
            providers = ['CPUExecutionProvider']
        elif device_type == device.cuda():
            providers = ['CUDAExecutionProvider']
        else:
            raise ValueError(f"Unexpected device {device!r}.")
    elif isinstance(device, str):
        if device == 'cpu':
            providers = ['CPUExecutionProvider']
        elif device.startswith("cuda"):
            # Accepts 'cuda' as well as indexed forms such as 'cuda:0'.
            providers = ['CUDAExecutionProvider']
        else:
            raise ValueError(f"Unexpected device {device!r}.")
    else:
        # Neither a device object nor a string: report it the same way
        # instead of failing with an AttributeError on `.startswith`.
        raise ValueError(f"Unexpected device {device!r}.")

    return TrainingSession(training_onnx.SerializeToString(), ort_parameters,
                           session_options, providers=providers)
def create_training_session(training_onnx, weights_to_train,
                            loss_output_name='loss',
                            training_optimizer_name='SGDOptimizer'):
    """
    Creates an instance of class `TrainingSession`.

    No execution providers are passed, so `TrainingSession` uses its own
    default provider selection.

    :param training_onnx: ONNX graph used to train
    :param weights_to_train: names of initializers to be optimized
        (any iterable; it is read exactly once)
    :param loss_output_name: name of the loss output
    :param training_optimizer_name: optimizer name
    :return: instance of `TrainingSession`
    """
    # Materialise once: the names are consumed three times below, and a
    # generator would be exhausted after the first pass, silently leaving
    # the optimizer attribute maps empty.
    weights = list(weights_to_train)

    ort_parameters = TrainingParameters()
    ort_parameters.loss_output_name = loss_output_name
    ort_parameters.weights_to_train = set(weights)
    ort_parameters.training_optimizer_name = training_optimizer_name
    # One empty attribute dictionary per trained weight.
    ort_parameters.optimizer_attributes_map = {name: {} for name in weights}
    ort_parameters.optimizer_int_attributes_map = {
        name: {} for name in weights}

    session_options = SessionOptions()
    session_options.use_deterministic_compute = True

    return TrainingSession(training_onnx.SerializeToString(), ort_parameters,
                           session_options)