示例#1
0
    def getOperators(self):
        """Collect documentation objects for every registered operator.

        Side effects:
            self.operators -- rebuilt as a dict mapping op name to the object
                returned by ``self.getOperatorDoc(name, schema, priority)``.
            self.engines -- rebuilt as a dict mapping a base op name
                (``engine.base_op_name``) to the list of objects returned by
                ``self.getOperatorEngine(name)`` for schema-less registered
                names that contain "_ENGINE_".

        Priorities passed to getOperatorDoc (lower sorts first):
            0 -- schema file's directory is exactly "caffe2/caffe2/operators"
            1 -- any other schema file location
            2 -- schema file path has a "contrib" component
            3 -- schema file path has an "experiments" component
            4 -- no schema and not an "_ENGINE_" name

        Returns:
            A list of the values of ``self.operators`` sorted by each entry's
            ``.priority`` and then ``.name``.
        """
        # map: op_name -> operator
        self.operators = {}
        # map: op_name -> [engine, engine]
        self.engines = {}

        def filePriority(x):
            # Exact match for the core operator directory; otherwise inspect
            # the individual path components ("contrib" wins over
            # "experiments" when both are present).
            if x == "caffe2/caffe2/operators":
                return 0
            parts = x.split('/')
            if 'contrib' in parts:
                return 2
            if 'experiments' in parts:
                return 3
            return 1

        for name in core._GetRegisteredOperators():
            schema = OpSchema.get(name)
            if schema:
                priority = filePriority(os.path.dirname(schema.file))
                self.operators[name] = self.getOperatorDoc(
                    name, schema, priority)

            # Engine: schema-less "_ENGINE_" names are grouped under their
            # base op instead of being documented on their own.
            elif name.find("_ENGINE_") != -1:
                engine = self.getOperatorEngine(name)
                self.engines.setdefault(engine.base_op_name, []).append(engine)

            # No schema
            else:
                self.operators[name] = self.getOperatorDoc(name, schema, 4)

        # Engines whose base op was never registered remain only in
        # self.engines; they are not attached to any operator doc.
        for name, engines in self.engines.items():
            if name in self.operators:
                self.operators[name].addEngines(engines)

        # Sort by (priority, name). The previous `sorted(..., cmp=compare)`
        # raises TypeError on Python 3, where the `cmp` keyword was removed.
        # A key tuple yields the same order and works on Python 2 and 3.
        return sorted(
            self.operators.values(),
            key=lambda op: (op.priority, op.name),
        )
示例#2
0
def gen_coverage_sets(source_dir):
    """Partition the registered operators by schema and coverage status.

    Args:
        source_dir: passed straight through to ``gen_covered_ops``.

    Returns:
        A 3-tuple ``(covered_ops, not_covered_ops, schemaless_ops)``:
          * ``covered_ops`` -- whatever ``gen_covered_ops(source_dir)`` returned;
          * ``not_covered_ops`` -- set of names whose schema exists, is not
            private, and whose name is not in ``covered_ops``;
          * ``schemaless_ops`` -- list, in registration-iteration order, of
            names with no (truthy) schema, excluding names containing
            "_ENGINE_".
    """
    covered_ops = gen_covered_ops(source_dir)

    missing = set()
    no_schema = []
    for registered_name in core._GetRegisteredOperators():
        op_schema = OpSchema.get(registered_name)

        # Private schemas are excluded from every result bucket.
        if op_schema is not None and op_schema.private:
            continue

        if not op_schema:
            # Engine variants are not reported as schema-less operators.
            if "_ENGINE_" not in registered_name:
                no_schema.append(registered_name)
        elif registered_name not in covered_ops:
            missing.add(registered_name)

    return (covered_ops, missing, no_schema)
示例#3
0
    def getOperators(self):
        """Build per-operator documentation and return it in display order.

        Rebuilds two attributes:
            self.operators -- op name -> result of ``self.getOperatorDoc``.
            self.engines -- base op name (``engine.base_op_name``) -> list of
                results of ``self.getOperatorEngine`` for schema-less names
                containing "_ENGINE_".

        Priority handed to getOperatorDoc (smaller sorts earlier):
            0 -- schema file directory equals "caffe2/caffe2/operators"
            1 -- any other schema location
            2 -- a "contrib" path component
            3 -- an "experiments" path component
            4 -- no schema and not an "_ENGINE_" name

        Returns:
            List of ``self.operators`` values ordered by ``.priority`` and
            then ``.name``.
        """
        self.operators = {}
        self.engines = {}

        def _dir_priority(directory):
            # "contrib" is checked before "experiments", so it wins when a
            # path contains both components.
            components = directory.split('/')
            if directory == "caffe2/caffe2/operators":
                return 0
            if 'contrib' in components:
                return 2
            if 'experiments' in components:
                return 3
            return 1

        for op_name in core._GetRegisteredOperators():
            op_schema = OpSchema.get(op_name)
            if op_schema:
                self.operators[op_name] = self.getOperatorDoc(
                    op_name,
                    op_schema,
                    _dir_priority(os.path.dirname(op_schema.file)),
                )
            elif "_ENGINE_" in op_name:
                # Schema-less engine variant: group it under its base op.
                variant = self.getOperatorEngine(op_name)
                self.engines.setdefault(variant.base_op_name, []).append(variant)
            else:
                self.operators[op_name] = self.getOperatorDoc(
                    op_name, op_schema, 4)

        # Attach grouped engines only to base ops that were documented above.
        for base_name, variants in viewitems(self.engines):
            if base_name in self.operators:
                self.operators[base_name].addEngines(variants)

        ordered = list(viewvalues(self.operators))
        ordered.sort(key=lambda doc: (doc.priority, doc.name))
        return ordered