Esempio n. 1
0
    def get_net_runtime(
        self,
        model,
        blobs,
        input_dims,
        input_dtype=core.DataType.FLOAT,
        device="SM-G950U-7.0-24",
        mode="ave",
        verbose=True,
    ):
        """
        Calculate the run-time latency of a model from the look-up table.

        The look-up table for ``device`` is read from
        ``<self.db_dir>/<device>-op_lut_database.json``. It is only reloaded
        when that path differs from the one loaded by the previous call.

        Args:
            model: caffe2 pb; its ``str()`` is recorded as the model name of
                every op query.
            blobs: forwarded to ``get_ops_from_net``.
            input_dims: input dimensions, forwarded to ``get_ops_from_net``.
            input_dtype: not used by this method (per-op input dtypes come
                from ``get_ops_from_net``); kept for backward compatibility.
            device: device used for benchmarking; also selects the database
                file.
            mode: forwarded to ``get_op_runtime``.
            verbose: forwarded to ``get_op_runtime``.

        Returns:
            Tuple ``(op_time, net_time, all_found)``:
              * ``op_time``: list with one ``{"op": op_query, "time": latency}``
                dict per op, where ``latency`` may be ``None``.
              * ``net_time``: sum of all latencies that are not ``None``.
              * ``all_found``: ``False`` if any op's latency was ``None``.
        """
        db_file = os.path.join(
            self.db_dir, device + "-" + "op_lut_database.json"
        )

        # Reload the look-up table only when switching database files.
        if self.current_db != db_file:
            self.current_db = db_file
            self.op_lut.load(db_file)

        ops, input_shapes, input_dtypes = get_ops_from_net(
            model, blobs, input_dims
        )

        # Turn every per-op shape into a concrete list (updated in place).
        # A bare ``map(...)`` here would be a lazy one-shot iterator on
        # Python 3 and would come back empty if consumed a second time.
        input_shapes[:] = [
            [list(shape) for shape in op_shapes] for op_shapes in input_shapes
        ]

        model_name = str(model)
        op_time, net_time, all_found = [], 0, True
        # A single pass over the ops suffices; there is no wait/retry logic.
        for i, op in enumerate(ops):
            op_query = LUTSchema()
            load_caffe2_op(
                op_query,
                op,
                input_shapes=input_shapes[i],
                input_dtypes=input_dtypes[i],
                device=device,
                model_name=model_name,
            )

            latency = self.get_op_runtime(
                op_query, mode=mode, verbose=verbose
            )
            op_time.append({"op": op_query, "time": latency})

            if latency is None:
                all_found = False
            else:
                net_time += latency

        return op_time, net_time, all_found
Esempio n. 2
0
    def load(self, dbfile):
        """
        Load a database file and convert all records to LUTSchema.

        Each non-blank line of the file is parsed as one JSON record and
        converted with ``LUTSchema.load_from_json``. Any previously loaded
        records in ``self.ops`` are discarded first.

        Args:
            dbfile: path to the database file

        Raises:
            AssertionError: if ``dbfile`` does not exist (only while Python
                assertions are enabled, i.e. not under ``-O``).
        """
        self.ops = []
        assert os.path.exists(dbfile), "database file not found: %s" % dbfile

        # ``with`` guarantees the handle is closed even if a record fails to
        # parse; iterating the handle avoids reading the whole file at once.
        with open(dbfile, "r", encoding="utf-8") as db_file:
            for record in db_file:
                # Skip blank lines (e.g. a trailing empty line), which
                # ``json.loads`` would reject.
                if not record.strip():
                    continue
                op = LUTSchema()
                op.load_from_json(json.loads(record))
                self.ops.append(op)