def local_predict(input_data, model_dir):
    """Runs prediction locally.

    Loads the session bundle from ``model_dir``. Each input file is read line
    by line, and every line is preprocessed with the model's FeatureProducer
    and serialized. All lines of one file are fed to the model in a single
    ``session.run`` call. For every prediction row, one JSON object keyed by
    output alias is printed to stdout. Empty input files produce no output.

    Args:
      input_data: list of input files to run prediction on.
      model_dir: path to Tensorflow model folder.
    """
    session, _ = session_bundle.load_session_bundle_from_path(model_dir)
    try:
        # get the mappings between aliases and tensor names
        # for both inputs and outputs
        input_alias_map = json.loads(session.graph.get_collection('inputs')[0])
        output_alias_map = json.loads(session.graph.get_collection('outputs')[0])
        aliases, tensor_names = zip(*output_alias_map.items())
        # Only the first input tensor is fed. Resolve its name once instead of
        # on every line. next(iter(...)) also works where dict.values() is a
        # non-indexable view.
        input_tensor_name = next(iter(input_alias_map.values()))

        metadata_path = os.path.join(model_dir, 'metadata.yaml')
        transformer = features.FeatureProducer(metadata_path)
        for input_file in input_data:
            with open(input_file) as f:
                examples = [
                    transformer.preprocess(line).SerializeToString()
                    for line in f
                ]
            if not examples:
                # An empty file would otherwise reach session.run without
                # feeding the input tensor. There is nothing to predict, so
                # skip it.
                continue
            result = session.run(
                fetches=tensor_names,
                feed_dict={input_tensor_name: examples})
            for row in zip(*result):
                # Values exposing tolist() (presumably numpy arrays/scalars)
                # are converted so json.dumps can serialize them.
                print(json.dumps({
                    name: (value.tolist() if getattr(value, 'tolist', None)
                           else value)
                    for name, value in zip(aliases, row)
                }))
    finally:
        # The session is created here and never handed out, so release it.
        session.close()
def from_client(cls, client, model_path, skip_preprocessing=False):
    """Construct ``cls(client, preprocess_fn)``.

    ``preprocess_fn`` is None when ``skip_preprocessing`` is set, or when
    ``_get_metadata_path(model_path)`` yields a falsy path. Otherwise it is the
    ``preprocess`` method of a ``features.FeatureProducer`` built from that
    metadata path. The method is used only if the producer is truthy.

    Args:
      client: passed unchanged as the first constructor argument.
      model_path: location used to look up the model metadata.
      skip_preprocessing: when True, no metadata lookup is done and no
        preprocessing function is attached.

    Returns:
      The newly constructed ``cls`` instance.
    """
    if skip_preprocessing:
        return cls(client, None)

    meta = _get_metadata_path(model_path)
    if not meta:
        return cls(client, None)

    producer = features.FeatureProducer(meta)
    return cls(client, producer.preprocess if producer else None)