コード例 #1
0
def _parse_model(config, model_name, batch_size):
    if config.max_batch_size == 0:
        if batch_size != 1:
            raise Exception("batching not supported for model '" + model_name + "'")
    else:  # max_batch_size > 0
        if batch_size > config.max_batch_size:
            raise Exception(
                "expecting batch size <= {} for model '{}'".format(config.max_batch_size, model_name))

    input_ = config.input[0]
    output = config.output[0]
    if input_.format == model_config_pb2.ModelInput.FORMAT_NHWC:
        h = input_.dims[0]
        w = input_.dims[1]
        c = input_.dims[2]
    else:
        c = input_.dims[0]
        h = input_.dims[1]
        w = input_.dims[2]

    return input_.name, output.name, c, h, w, input_.format, tensor_dtype_to_np_dtype(input_.data_type)
コード例 #2
0
ファイル: image_client.py プロジェクト: zhaohainan666/Adlik
def parse_model(config, model_name, batch_size):
    """
    Validate that *config* describes an image-classification model usable by
    this client, and return its key properties.

    Checks performed, in order: the config name matches *model_name*;
    exactly one input and one output; the output is DT_FLOAT and logically a
    vector; the requested *batch_size* is allowed by ``max_batch_size``; the
    input has 3 dims in either NCHW or NHWC layout.

    Returns a tuple of (input name, output name, c, h, w, input format,
    numpy dtype of the input, output size in bytes).
    """
    if config.name != model_name:
        raise Exception("Get model config exception, config name: {}".format(config.name))

    if len(config.input) != 1:
        raise Exception("expecting 1 input, got {}".format(len(config.input)))
    if len(config.output) != 1:
        raise Exception("expecting 1 output, got {}".format(len(config.output)))

    input_ = config.input[0]
    output = config.output[0]

    if output.data_type != types_pb2.DT_FLOAT:
        raise Exception("expecting output data type to be DT_FLOAT, model '" +
                        model_name + "' output type is " +
                        model_config.DataType.Name(output.data_type))

    # The output must be logically a vector: at most one dimension may be
    # larger than 1 (e.g. { 10 }, { 1, 10 } and { 10, 1, 1 } are all fine).
    if sum(1 for dim in output.dims if dim > 1) > 1:
        raise Exception("expecting model output to be a vector")

    # A max_batch_size of 0 means batching is unsupported: the input tensors
    # carry no "N" dimension, so only one image may be inferred at a time.
    max_batch = config.max_batch_size
    if max_batch == 0 and batch_size != 1:
        raise Exception("batching not supported for model '" + model_name + "'")
    if max_batch > 0 and batch_size > max_batch:
        raise Exception(
            "expecting batch size <= {} for model '{}'".format(max_batch, model_name))

    # The input must be a 3-dim image tensor, laid out CHW or HWC.
    if len(input_.dims) != 3:
        raise Exception(
            "expecting input to have 3 dimensions, model '{}' input has {}".format(
                model_name, len(input_.dims)))

    fmt = input_.format
    if fmt not in (model_config.ModelInput.FORMAT_NCHW,
                   model_config.ModelInput.FORMAT_NHWC):
        raise Exception("unexpected input format " + model_config.ModelInput.Format.Name(fmt) +
                        ", expecting " +
                        model_config.ModelInput.Format.Name(model_config.ModelInput.FORMAT_NCHW) +
                        " or " +
                        model_config.ModelInput.Format.Name(model_config.ModelInput.FORMAT_NHWC))

    # NHWC -> dims are (H, W, C); NCHW -> dims are (C, H, W).
    if fmt == model_config.ModelInput.FORMAT_NHWC:
        h, w, c = input_.dims[0], input_.dims[1], input_.dims[2]
    else:
        c, h, w = input_.dims[0], input_.dims[1], input_.dims[2]

    # Output size in bytes: element size times the product of all dims.
    output_size = np.dtype(tensor_dtype_to_np_dtype(output.data_type)).itemsize
    for dim in output.dims:
        output_size *= dim

    return (input_.name, output.name, c, h, w, fmt,
            tensor_dtype_to_np_dtype(input_.data_type), output_size)