コード例 #1
0
def _create_iobinding(io_binding, inputs, model, device):
    '''Bind every input and output of ``model`` on ``io_binding``.

    :param io_binding: IO binding object to populate; ``bind_ortvalue_input``
        is called once per graph input and ``bind_output`` once per graph output.
    :param inputs: indexable collection of torch tensors in the same order as
        ``model.graph.input``; the i-th tensor is bound to the i-th graph input.
        Indexing past its end raises (``IndexError`` for a list/tuple).
    :param model: model whose ``graph.input`` / ``graph.output`` entries supply
        the names to bind (presumably an ``onnx.ModelProto`` — confirm).
    :param device: device whose ``type`` and index are used for every output.
    '''
    for idx, value_info in enumerate(model.graph.input):
        ortvalue = _ortvalue_from_torch_tensor(inputs[idx])
        # _ortvalue_from_torch_tensor may already return an ``OrtValue`` wrapper.
        # ``OrtValue.__init__`` expects the raw C-level value, so wrapping a
        # wrapper fails; only wrap when the helper handed back the raw value.
        if not isinstance(ortvalue, OrtValue):
            ortvalue = OrtValue(ortvalue)
        io_binding.bind_ortvalue_input(value_info.name, ortvalue)

    # All outputs are bound to the same device, so resolve it once.
    device_type = device.type
    device_id = get_device_index(device)
    for value_info in model.graph.output:
        io_binding.bind_output(value_info.name, device_type, device_id=device_id)
コード例 #2
0
 def run_forward(self, iobinding, run_options):
     """
     Run the forward subgraph until execution reaches the Yield op.

     :param iobinding: IO binding whose graph inputs/outputs have already been bound.
     :param run_options: See :class:`onnxruntime.RunOptions`.
     :return: ``(outputs, run_id)``. ``outputs`` is a list holding each value
         returned by the training agent wrapped in :class:`OrtValue`.
         ``run_id`` is passed through unchanged from the training agent.
     """
     agent = self._training_agent
     # Hand the agent the underlying binding object (``_iobinding``), not the wrapper.
     raw_ortvalues, run_id = agent.run_forward(iobinding._iobinding, run_options)
     outputs = list(map(OrtValue, raw_ortvalues))
     return outputs, run_id
コード例 #3
0
def _ortvalue_from_torch_tensor(torch_tensor):
    """Convert ``torch_tensor`` into an :class:`OrtValue` through DLPack.

    :param torch_tensor: the torch tensor to export.
    :return: an :class:`OrtValue` wrapping the ``C.OrtValue`` built from the
        tensor's DLPack capsule.

    NOTE(review): DLPack exchange is normally zero-copy, so the result likely
    shares storage with ``torch_tensor`` — confirm before mutating either.
    """
    capsule = to_dlpack(torch_tensor)
    # from_dlpack is told explicitly whether the source tensor is torch.bool;
    # presumably the capsule alone does not carry that distinction — confirm.
    is_bool_tensor = torch_tensor.dtype == torch.bool
    native_value = C.OrtValue.from_dlpack(capsule, is_bool_tensor)
    return OrtValue(native_value)