def get_net_runtime(
    self,
    model,
    blobs,
    input_dims,
    input_dtype=core.DataType.FLOAT,
    device="SM-G950U-7.0-24",
    mode="ave",
    verbose=True,
):
    """
    Calculate the run-time latency of a model from the look-up table.

    The operator look-up table for ``device`` is (re)loaded into
    ``self.op_lut`` only when it differs from the database file that is
    currently loaded (``self.current_db``).

    Args:
        model: caffe2 pb; passed to ``get_ops_from_net`` and, stringified,
            used as the ``model_name`` of every operator query
        blobs: passed through to ``get_ops_from_net``
        input_dims: input dimensions, passed to ``get_ops_from_net``
        input_dtype: not used by this method; kept so existing callers
            that pass it keep working
        device: device used for benchmarking; selects the database file
            ``<db_dir>/<device>-op_lut_database.json``
        mode: forwarded unchanged to ``self.get_op_runtime``
        verbose: forwarded unchanged to ``self.get_op_runtime``

    Returns:
        A tuple ``(op_time, net_time, all_found)`` where
        ``op_time`` is a list with one ``{"op": query, "time": latency}``
        dict per operator (``latency`` is ``None`` when the lookup returned
        ``None``), ``net_time`` is the sum of all non-``None`` latencies,
        and ``all_found`` is ``False`` if any operator latency was ``None``.
    """
    db_file = os.path.join(
        self.db_dir, device + "-" + "op_lut_database.json"
    )
    # Only hit the disk when switching to a different device database.
    if self.current_db != db_file:
        self.current_db = db_file
        self.op_lut.load(db_file)

    ops, input_shapes, input_dtypes = get_ops_from_net(
        model, blobs, input_dims
    )
    # Materialise every shape as a real list. The previous ``map(...)`` call
    # yields a lazy, single-use iterator on Python 3, which is empty after
    # its first traversal and does not support len()/indexing.
    input_shapes = [
        [list(shape) for shape in op_shapes] for op_shapes in input_shapes
    ]

    op_time, net_time, all_found = [], 0, True
    for i, op in enumerate(ops):
        op_query = LUTSchema()
        load_caffe2_op(
            op_query,
            op,
            input_shapes=input_shapes[i],
            input_dtypes=input_dtypes[i],
            device=device,
            model_name=str(model),
        )
        latency = self.get_op_runtime(
            op_query, mode=mode, verbose=verbose
        )
        op_time.append({"op": op_query, "time": latency})
        if latency is not None:
            net_time += latency
        else:
            # Keep going so the caller still gets a per-op report, but flag
            # that net_time is only a partial sum.
            all_found = False
    return op_time, net_time, all_found
def load(self, dbfile):
    """
    Load a database file and convert all records to LUTSchema.

    The file is read as one JSON document per line. Any previously loaded
    records in ``self.ops`` are discarded before reading. Blank or
    whitespace-only lines are skipped.

    Args:
        dbfile: path to the database file

    Raises:
        AssertionError: if ``dbfile`` does not exist
    """
    self.ops = []
    # Kept as an assert so callers that rely on AssertionError for a missing
    # database keep working.
    assert os.path.exists(dbfile), "database file not found: %s" % dbfile
    # The context manager guarantees the handle is closed even when a record
    # fails to parse.
    with open(dbfile, "r") as db_file:
        # Stream line by line instead of buffering the whole file.
        for record in db_file:
            if not record.strip():
                # e.g. a trailing empty line; json.loads would reject it
                continue
            op = LUTSchema()
            op.load_from_json(json.loads(record))
            self.ops.append(op)