Пример #1
0
def run_as_local_main():
    """Deploy a PyTorch model from the command line as a batch job or an endpoint.

    Arguments come from ``parse_infer_args()``. The attributes read here are
    ``model_file``, ``infer_mode``, ``output_dir`` and ``input_file``.

    ``infer_mode`` selects the action:

    * ``'bt'``: run a batch transform on ``input_file`` (sent as
      ``text/csv``), write results to ``output_dir``, and block until the
      job finishes.
    * ``'ep'``: deploy the model to a real-time endpoint on one
      ``ml.c5.xlarge`` instance.

    Raises:
        ValueError: if ``infer_mode`` is neither ``'bt'`` nor ``'ep'``.
    """
    args = parse_infer_args()
    sess = sagemaker.Session()
    region = sess.boto_session.region_name
    model = PyTorchModel(
        model_data=args.model_file,
        # Package this script's directory so the ``inference.py`` entry point
        # (and anything it imports from here) is shipped with the model.
        source_dir=os.path.abspath(os.path.dirname(__file__)),
        role=get_sm_execution_role(ON_SAGEMAKER_NOTEBOOK, region),
        framework_version='1.0.0',
        entry_point='inference.py',
    )

    infer_mode = args.infer_mode
    if infer_mode == 'bt':
        transformer = model.transformer(
            instance_count=1,
            instance_type='ml.c5.xlarge',
            output_path=args.output_dir,
            max_payload=99,
            # Passed to the serving container; presumably raises the model
            # server's request timeout (TODO confirm units/semantics).
            env={'MODEL_SERVER_TIMEOUT': '120'},
            max_concurrent_transforms=1,
            tags=[{"Key": "Project", "Value": "SM Example"}],
        )
        transformer.transform(args.input_file, content_type="text/csv")
        transformer.wait()
    elif infer_mode == 'ep':
        model.deploy(instance_type='ml.c5.xlarge', initial_instance_count=1)
    else:
        # ValueError subclasses Exception, so existing broad handlers still match.
        raise ValueError(f'Unknown inference mode {infer_mode}')
def batch_inference(session, client, model_name, setting, pytorch):
    """Start a SageMaker batch transform job for ``model_name``.

    If ``_model_exists(client, model_name)`` is true, a ``Transformer`` is
    built directly for the registered model from the settings' ``deploy``
    section. Otherwise a new ``PyTorchModel`` or ``ChainerModel`` is built
    from the ``model`` section. Calling ``model.transformer`` with the
    ``deploy`` section then registers it. In both cases the job is started
    with the ``transform`` section as keyword arguments.

    No ``wait()`` is called here. Whether this returns before the job
    finishes depends on the ``transform`` settings and SDK defaults.

    Args:
        session: boto3 session used to build the SageMaker session.
        client: SageMaker client. It is used for the session and for the
            existence check.
        model_name: Name of the model to reuse or register. It is also used
            as the base transform-job name when reusing an existing model.
        setting: Path to a YAML file with ``model``, ``deploy`` and
            ``transform`` sections.
        pytorch: If true, a newly registered model is a ``PyTorchModel``.
            Otherwise it is a ``ChainerModel``.
    """
    sagemaker_session = sagemaker.Session(boto_session=session,
                                          sagemaker_client=client)

    # The settings file is plain YAML, so safe_load is sufficient. A bare
    # yaml.load() without a Loader is unsafe and is rejected by PyYAML >= 6.
    # The ``with`` block guarantees the file handle is closed.
    with open(setting) as f:
        conf = yaml.safe_load(f)

    # check the target model exists
    if _model_exists(client, model_name):
        logger.info('use the registered model.')
        # The ``deploy`` section holds Transformer constructor arguments. The
        # identifying fields and session are filled in here.
        deploy_args = conf['deploy']
        deploy_args['model_name'] = model_name
        deploy_args['base_transform_job_name'] = model_name
        deploy_args['sagemaker_session'] = sagemaker_session

        transformer = Transformer(**deploy_args)

    else:
        # TODO: handle the update case (delete and re-create the model).
        # Models can be referenced by multiple endpoints and inference jobs,
        # so deleting one is not straightforward.
        logger.info('register the new model.')
        model_args = conf['model']
        model_args['sagemaker_session'] = sagemaker_session
        model_args['name'] = model_name
        if pytorch:
            model = PyTorchModel(**model_args)
        else:
            model = ChainerModel(**model_args)

        deploy_args = conf['deploy']
        transformer = model.transformer(**deploy_args)  # register model

    transform_args = conf['transform']
    # use default job_name (model_name + datetime.now())
    transformer.transform(**transform_args)