예제 #1
0
def register(stub, model_name, mar_set_str):
    """Register ``model_name`` with TorchServe through the gRPC management API.

    If ``<model_name>.mar`` is listed in ``mar_set_str`` that archive name is
    registered as-is; otherwise the archive URL on the public TorchServe S3
    bucket is used.

    Args:
        stub: gRPC management stub exposing ``RegisterModel``.
        model_name: Name to register the model under; also the ``.mar`` stem.
        mar_set_str: Comma-separated list of available ``.mar`` file names,
            or a falsy value when there are none. Whitespace around entries
            and empty entries (e.g. from a trailing comma) are ignored.

    Raises:
        SystemExit: with status 1 when the RegisterModel RPC fails.
    """
    mar_set = set()
    if mar_set_str:
        # Strip each entry so "a.mar, b.mar" matches as well as "a.mar,b.mar".
        mar_set = {
            name.strip() for name in mar_set_str.split(',') if name.strip()
        }
    marfile = f"{model_name}.mar"
    print(f"## Check {marfile} in mar_set :", mar_set)
    if marfile not in mar_set:
        marfile = "https://torchserve.s3.amazonaws.com/mar_files/{}.mar".format(
            model_name)

    print(f"## Register marfile:{marfile}\n")
    params = {
        'url': marfile,
        'initial_workers': 1,
        'synchronous': True,
        'model_name': model_name
    }
    try:
        stub.RegisterModel(management_pb2.RegisterModelRequest(**params))
    except grpc.RpcError as e:
        print(f"Failed to register model {model_name}.")
        print(str(e.details()))
        # The ``exit`` builtin is injected by ``site`` for interactive use;
        # raising SystemExit directly works even when ``site`` is not loaded.
        raise SystemExit(1)
    print(f"Model {model_name} registered successfully")
예제 #2
0
def test_inference_apis():
    """Register, query, and unregister every model in the inference data file.

    For each entry of the JSON file at
    ``os.path.dirname(__file__) + inference_data_json``:

    1. register the model over gRPC;
    2. run one inference on ``item['file']``;
    3. if ``item['expected']`` is present, compare the prediction to it,
       within ``item['tolerance']`` for list/dict predictions;
    4. unregister the model. This step always runs, even when a check fails.
    """
    # Plain concatenation (not os.path.join) is deliberate: the configured
    # fragment may start with "/", which os.path.join would treat as an
    # absolute path and use to discard the test directory.
    test_dir = os.path.dirname(__file__)
    with open(test_dir + inference_data_json, 'rb') as f:
        test_data = json.load(f)

    for item in test_data:
        if item['url'].startswith('{{mar_path_'):
            # "{{mar_path_xxx}}" is a placeholder: drop the braces and look the
            # real archive path up in the shared table.
            path = test_utils.mar_file_table[item['url'][2:-2]]
        else:
            path = item['url']

        management_stub = test_gRPC_utils.get_management_stub()
        response = management_stub.RegisterModel(
            management_pb2.RegisterModelRequest(
                url=path,
                initial_workers=item['worker'],
                synchronous=bool(item['synchronous']),
                model_name=item['model_name']))
        print(response.msg)

        # Unregister in ``finally`` so a failing check for one model does not
        # leave it registered and break the following cases.
        try:
            model_input = test_dir + "/../" + item['file']
            prediction = __infer(test_gRPC_utils.get_inference_stub(),
                                 item['model_name'], model_input)
            print("Prediction is : ", str(prediction))

            if 'expected' in item:
                try:
                    prediction = literal_eval(prediction)
                except (SyntaxError, ValueError):
                    # Not a Python literal (a bare label such as "cat" raises
                    # ValueError): fall back to comparing it as a string.
                    pass

                expected = item['expected']
                if isinstance(prediction, list) and 'tolerance' in item:
                    assert len(prediction) == len(expected)
                    for got, want in zip(prediction, expected):
                        assert __get_change(got, want) < item['tolerance']
                elif isinstance(prediction, dict) and 'tolerance' in item:
                    assert len(prediction) == len(expected)
                    for key in prediction:
                        assert __get_change(
                            prediction[key], expected[key]) < item['tolerance']
                else:
                    assert str(prediction) == str(expected)
        finally:
            response = management_stub.UnregisterModel(
                management_pb2.UnregisterModelRequest(
                    model_name=item['model_name']))
            print(response.msg)
예제 #3
0
def register(stub, model_name, local=False):
    """Register a model archive through the gRPC management API.

    Args:
        stub: gRPC management stub exposing ``RegisterModel``.
        model_name: With ``local=True``, the archive reference passed verbatim
            as the registration URL (e.g. ``"resnet.mar"``). Otherwise, the
            archive stem looked up on the public TorchServe S3 bucket.
        local: Whether ``model_name`` is itself the archive URL/path.

    Returns:
        The response returned by ``stub.RegisterModel``.
    """
    if local:
        url = model_name
    else:
        url = "https://torchserve.s3.amazonaws.com/mar_files/{}.mar".format(
            model_name)

    # Strip only a trailing ".mar"; str.replace would also remove ".mar"
    # occurring inside the name (e.g. "my.marker.mar" -> "myker").
    suffix = ".mar"
    registered_name = (model_name[:-len(suffix)]
                       if model_name.endswith(suffix) else model_name)
    params = {
        'url': url,
        'initial_workers': 1,
        'synchronous': True,
        'model_name': registered_name
    }
    response = stub.RegisterModel(
        management_pb2.RegisterModelRequest(**params))
    return response
예제 #4
0
def register(stub, model_name, url):
    """Register the archive at ``url`` under ``model_name`` via gRPC.

    Args:
        stub: gRPC management stub exposing ``RegisterModel``.
        model_name: Name to register the model under.
        url: Location of the model archive, passed to TorchServe unchanged.

    Raises:
        SystemExit: with status 1 when the RegisterModel RPC fails.
    """
    print("Registering ", model_name)
    params = {
        'url': url,
        'initial_workers': 1,
        'synchronous': True,
        'model_name': model_name
    }
    try:
        stub.RegisterModel(management_pb2.RegisterModelRequest(**params))
    except grpc.RpcError as e:
        print(f"Failed to register model {model_name}.")
        print(str(e.details()))
        # The ``exit`` builtin comes from ``site`` and is meant for the REPL;
        # SystemExit is the program-safe equivalent.
        raise SystemExit(1)
    print(f"Model {model_name} registered successfully")
예제 #5
0
def register(stub, model_name):
    """Register ``model_name`` from the public TorchServe S3 bucket via gRPC.

    The archive URL is
    ``https://torchserve.s3.amazonaws.com/mar_files/<model_name>.mar``.

    Args:
        stub: gRPC management stub exposing ``RegisterModel``.
        model_name: Archive stem on the bucket; also the registered name.

    Raises:
        SystemExit: with status 1 when the RegisterModel RPC fails.
    """
    url = "https://torchserve.s3.amazonaws.com/mar_files/{}.mar".format(
        model_name)
    params = {
        'url': url,
        'initial_workers': 1,
        'synchronous': True,
        'model_name': model_name,
    }
    try:
        stub.RegisterModel(management_pb2.RegisterModelRequest(**params))
    except grpc.RpcError as e:
        print(f"Failed to register model {model_name}.")
        print(str(e.details()))
        # The ``exit`` builtin comes from ``site`` and is meant for the REPL;
        # SystemExit is the program-safe equivalent.
        raise SystemExit(1)
    print(f"Model {model_name} registered successfully")