def getOperators(self):
    """Collect every registered operator and return them in display order.

    Populates two attributes on ``self``:

    * ``self.operators``: op name -> object built by ``self.getOperatorDoc``.
    * ``self.engines``: base op name -> list of objects built by
      ``self.getOperatorEngine``. Only schema-less names containing
      ``"_ENGINE_"`` are treated as engines.

    Each engine list is then attached (via ``addEngines``) to its base
    operator when that base operator was collected.

    Priority comes from the directory of the schema's defining file:

    * 0 for ``caffe2/caffe2/operators``;
    * 2 if a path component is ``contrib``;
    * 3 if a path component is ``experiments``;
    * 1 for anything else;
    * 4 for operators that have no schema (and are not engines).

    Returns:
        list: the operator objects sorted by ascending priority, then by name.
    """
    # map: op_name -> operator
    self.operators = {}
    # map: op_name -> [engine, engine]
    self.engines = {}

    def filePriority(x):
        if x == "caffe2/caffe2/operators":
            return 0
        if 'contrib' in x.split('/'):
            return 2
        if 'experiments' in x.split('/'):
            return 3
        return 1

    for name in core._GetRegisteredOperators():
        schema = OpSchema.get(name)
        if schema:
            priority = filePriority(os.path.dirname(schema.file))
            operator = self.getOperatorDoc(name, schema, priority)
            self.operators[name] = operator
        # Engine
        elif name.find("_ENGINE_") != -1:
            engine = self.getOperatorEngine(name)
            if engine.base_op_name in self.engines:
                self.engines[engine.base_op_name].append(engine)
            else:
                self.engines[engine.base_op_name] = [engine]
        # No schema
        else:
            priority = 4
            self.operators[name] = self.getOperatorDoc(name, schema, priority)

    for name, engines in self.engines.items():
        # Engines whose base operator was not collected are not attached.
        if name in self.operators:
            self.operators[name].addEngines(engines)

    # Generate a sorted list of operators.
    # sorted(..., cmp=...) does not exist on Python 3. A (priority, name) key
    # gives the intended ordering on both Python 2 and 3, and it never
    # compares the operator objects themselves.
    return sorted(
        self.operators.values(), key=lambda op: (op.priority, op.name)
    )
def gen_coverage_sets(source_dir):
    """Split the registered operators into coverage buckets.

    Operators whose schema is marked ``private`` are ignored entirely.

    Args:
        source_dir: forwarded to ``gen_covered_ops`` to obtain the set of
            operators considered covered.

    Returns:
        tuple: ``(covered_ops, not_covered_ops, schemaless_ops)``.

        * ``covered_ops`` is exactly what ``gen_covered_ops`` returned.
        * ``not_covered_ops`` is a set of operators that have a schema but
          are not in ``covered_ops``.
        * ``schemaless_ops`` is a list, in registration order, of operators
          with no schema whose name does not contain ``"_ENGINE_"``.
    """
    covered = gen_covered_ops(source_dir)
    uncovered = set()
    without_schema = []

    for op in core._GetRegisteredOperators():
        schema = OpSchema.get(op)
        if schema is not None and schema.private:
            # Private operators are excluded from every bucket.
            continue
        if not schema:
            # Engine variants have no schema of their own; skip them here.
            if "_ENGINE_" not in op:
                without_schema.append(op)
        elif op not in covered:
            uncovered.add(op)

    return covered, uncovered, without_schema
def getOperators(self):
    """Gather all registered operators, group engines, and order the result.

    After this call:

    * ``self.operators`` maps each non-engine op name to the object produced
      by ``self.getOperatorDoc``.
    * ``self.engines`` maps each base op name to the list of objects produced
      by ``self.getOperatorEngine``. Only schema-less names containing
      ``"_ENGINE_"`` count as engines.

    Engine lists are handed to their base operator's ``addEngines`` when that
    base operator exists.

    The rank passed to ``getOperatorDoc`` is:

    * 0 for the ``caffe2/caffe2/operators`` directory;
    * 2 when a path component is ``contrib``;
    * 3 when a path component is ``experiments``;
    * 1 otherwise;
    * 4 when there is no schema.

    Returns:
        list: operator objects ordered by (priority, name), ascending.
    """
    self.operators = {}  # op name -> operator doc object
    self.engines = {}  # base op name -> [engine, ...]

    def dir_rank(directory):
        # Lower rank sorts earlier.
        if directory == "caffe2/caffe2/operators":
            return 0
        components = directory.split('/')
        if 'contrib' in components:
            return 2
        if 'experiments' in components:
            return 3
        return 1

    for op_name in core._GetRegisteredOperators():
        op_schema = OpSchema.get(op_name)
        has_schema = bool(op_schema)

        if not has_schema and "_ENGINE_" in op_name:
            # Schema-less engine variants are grouped under their base op.
            eng = self.getOperatorEngine(op_name)
            self.engines.setdefault(eng.base_op_name, []).append(eng)
            continue

        rank = dir_rank(os.path.dirname(op_schema.file)) if has_schema else 4
        self.operators[op_name] = self.getOperatorDoc(op_name, op_schema, rank)

    for base_name, engine_list in viewitems(self.engines):
        if base_name in self.operators:
            self.operators[base_name].addEngines(engine_list)

    def sort_key(op):
        return op.priority, op.name

    ordered = list(viewvalues(self.operators))
    ordered.sort(key=sort_key)
    return ordered