def run_as_local_main():
    """Run inference on SageMaker according to the parsed CLI arguments.

    The model artifact named by ``args.model_file`` is wrapped in a
    ``PyTorchModel`` whose entry point is ``inference.py`` located next to
    this file. What happens next depends on ``args.infer_mode``:

    * ``'bt'`` -- create a transformer, run it over ``args.input_file``
      (sent as ``text/csv``) with results going to ``args.output_dir``,
      and block until the job finishes.
    * ``'ep'`` -- deploy the model on a single ``ml.c5.xlarge`` instance.

    Raises:
        ValueError: if ``args.infer_mode`` is neither ``'bt'`` nor ``'ep'``.
    """
    args = parse_infer_args()
    infer_mode = args.infer_mode

    # Reject an unsupported mode before any SageMaker objects are built, so
    # a typo on the command line fails immediately.
    if infer_mode not in ('bt', 'ep'):
        raise ValueError(f'Unknown inference mode {infer_mode}')

    region = sagemaker.Session().boto_session.region_name
    model = PyTorchModel(
        model_data=args.model_file,
        source_dir=os.path.abspath(os.path.dirname(__file__)),
        role=get_sm_execution_role(ON_SAGEMAKER_NOTEBOOK, region),
        framework_version='1.0.0',
        entry_point='inference.py',
    )

    if infer_mode == 'ep':
        model.deploy(instance_type='ml.c5.xlarge', initial_instance_count=1)
        return

    # infer_mode == 'bt'
    transformer = model.transformer(
        instance_count=1,
        instance_type='ml.c5.xlarge',
        output_path=args.output_dir,
        max_payload=99,
        env={'MODEL_SERVER_TIMEOUT': '120'},
        max_concurrent_transforms=1,
        tags=[{"Key": "Project", "Value": "SM Example"}],
    )
    transformer.transform(args.input_file, content_type="text/csv")
    transformer.wait()
def batch_inference(session, client, model_name, setting, pytorch):
    """Start a SageMaker transform job for ``model_name``.

    If ``_model_exists`` reports that the model is already registered, a
    ``Transformer`` is built directly from the ``deploy`` settings.
    Otherwise a new ``PyTorchModel`` or ``ChainerModel`` is created from the
    ``model`` settings and its ``transformer()`` is used. In both cases the
    job is started with the ``transform`` settings.

    Args:
        session: Passed to ``sagemaker.Session`` as ``boto_session``.
        client: Passed to ``sagemaker.Session`` as ``sagemaker_client`` and
            used to check whether the model exists.
        model_name: Name of the model to reuse, or to register when missing.
        setting: Path to a YAML file providing ``deploy`` and ``transform``
            mappings, plus a ``model`` mapping when the model is not yet
            registered.
        pytorch: Truthy to create a ``PyTorchModel``; otherwise a
            ``ChainerModel`` is created.
    """
    sagemaker_session = sagemaker.Session(boto_session=session, sagemaker_client=client)

    # Close the settings file deterministically instead of leaking the handle.
    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and can construct arbitrary objects on older releases, so the safe
    # loader is used.
    # NOTE(review): assumes the settings file contains only plain YAML
    # (no python-specific tags) -- confirm against existing config files.
    with open(setting) as settings_file:
        conf = yaml.safe_load(settings_file)

    # check the target model exists
    if _model_exists(client, model_name):
        logger.info('use the registered model.')
        deploy_args = conf['deploy']
        deploy_args['model_name'] = model_name
        deploy_args['base_transform_job_name'] = model_name
        deploy_args['sagemaker_session'] = sagemaker_session
        transformer = Transformer(**deploy_args)
    else:
        # [TODO] updating case (delete and create).
        # Basically, models have dependencies on multiple endpoints and
        # inference jobs, so it is not easy to delete it.
        logger.info('register the new model.')
        model_args = conf['model']
        model_args['sagemaker_session'] = sagemaker_session
        model_args['name'] = model_name
        model_cls = PyTorchModel if pytorch else ChainerModel
        model = model_cls(**model_args)
        deploy_args = conf['deploy']
        # register model
        transformer = model.transformer(**deploy_args)

    transform_args = conf['transform']
    # use default job_name (model_name + datetime.now())
    transformer.transform(**transform_args)