def _unify_attrs(attrs, op_proto):
    # type: (typing.Dict[str, typing.Any], OpProto)-> typing.Dict[str, typing.Any]
    """Build a normalized attribute dict keyed by primary argument names.

    For every non-tensor argument prototype reported by ``op_proto`` the value
    stored in ``attrs`` under the prototype's primary name is copied to the
    result. Values of array-typed arguments are passed through
    ``utils.listify`` unless they are ``None``.

    Entries of ``attrs`` that do not correspond to a non-tensor argument
    prototype are not copied. A ``KeyError`` propagates if ``attrs`` lacks a
    primary argument name.
    """
    unified = {}
    for proto in op_proto.list_nontensor_arg_protos():
        key = proto.primary_arg_name
        raw = attrs[key]
        # None is kept as-is so that "not given" stays distinguishable
        # from an empty list.
        wrap_in_list = proto.is_array and raw is not None
        unified[key] = utils.listify(raw) if wrap_in_list else raw
    return unified
def generic_convert_reduce(op, target_name):
    # type: (TFOperation, str)->None
    """Rewrite a reduce operation in place under the name ``target_name``.

    The second input of ``op`` must carry constant data (the axes); it is
    converted to a list and stored in the ``axis`` attribute, after which only
    the first input is kept. The ``keep_dims`` attribute is carried over as
    ``keepdims``; all other previous attributes are dropped.

    Raises ``AssertionError`` if the axes input has no constant data. Note
    that ``op.name`` has already been changed at that point.
    """
    op.name = target_name

    axes_tensor = op.inputs[1]
    assert axes_tensor.data is not None, \
        "{} is only supported with constant axes".format(op.name)

    reduce_axes = utils.listify(axes_tensor.data.tolist())

    # The axes are now an attribute, so the axes tensor is no longer an input.
    op.inputs = (op.inputs[0],)
    op.attribs = {
        "axis": reduce_axes,
        "keepdims": op.attribs["keep_dims"],
    }
def _match(self, value, settings):
    """Match this operation pattern against the operation ``value``.

    The checks run in this order and the first failure returns an empty
    ``Match()``:

    1. If ``settings.allow_multi_consumer`` is false, no output of the
       operation may have more than one consumer.
    2. If ``self.name`` is set, the operation's name must be among
       ``utils.listify(self.name)``.
    3. If ``self.inputs`` is set, at least one alternative produced by
       ``self._pattern_list_list`` must match via ``self._match_inputs``.
    4. If ``self.attribs`` is set, ``self._match_attribs`` must succeed.
    5. If ``self.outputs`` is set, at least one alternative must match via
       ``self._match_outputs``.

    On success the returned match is rooted at the operation and its dict is
    the union of ``{self: op}`` with the dicts of all sub-matches.
    """
    assert self not in settings.dict_so_far, "Operation cannot be matched multiple times"
    assert isinstance(value, BaseOperation)
    op = value

    if not settings.allow_multi_consumer:
        if any(len(result.consumers) > 1 for result in op.outputs):
            return Match()

    if self.name is not None:
        if op.name not in utils.listify(self.name):
            return Match()

    def merged_with(current, addition):
        # Fold a successful sub-match into the match accumulated so far.
        return Match(True, root=op, dict_=utils.dict_union(current.dict, addition.dict))

    result = Match(True, root=op, dict_={self: op})

    if self.inputs is not None:
        # Take the first alternative input arrangement that matches.
        for candidate in self._pattern_list_list(self.inputs, op.inputs):
            inputs_match = self._match_inputs(op, settings, candidate)
            if inputs_match:
                break
        else:
            return Match()
        result = merged_with(result, inputs_match)

    if self.attribs is not None:
        assert isinstance(self.attribs, dict)
        attribs_match = self._match_attribs(op, settings, self.attribs)
        if not attribs_match:
            return Match()
        result = merged_with(result, attribs_match)

    if self.outputs is not None:
        # Take the first alternative output arrangement that matches.
        for candidate in self._pattern_list_list(self.outputs, op.outputs):
            outputs_match = self._match_outputs(op, settings, candidate)
            if outputs_match:
                break
        else:
            return Match()
        result = merged_with(result, outputs_match)

    return result
def _get_arg(op_name, args, arg_proto):
    # type: (str, typing.Dict[str, typing.Any], ArgProto)->typing.Any
    """Fetch the value of an argument under any of its accepted names.

    ``args`` is searched for every name in ``arg_proto.arg_names``. At most
    one of the names present may hold a non-None value; that value is
    returned, passed through ``utils.listify`` for array-typed arguments.

    Returns:
        The value found; ``None`` if every name present maps to ``None``, or
        if no name is present and the argument is optional.

    Raises:
        AssertionError: if more than one name carries a non-None value, or if
            no name is present in ``args`` and the argument is not optional.
    """
    # Must start as False: otherwise a missing required argument would be
    # silently reported as None and the error below could never be reached.
    found = False
    value = None
    for arg_name in arg_proto.arg_names:
        if arg_name not in args:
            continue
        found = True
        candidate = args[arg_name]
        if candidate is None:
            continue
        # Two aliases of the same argument must not both be given.
        assert value is None
        value = utils.listify(candidate) if arg_proto.is_array else candidate

    if found:
        return value
    if arg_proto.is_optional:
        return None
    # Raised explicitly (rather than `assert False`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    raise AssertionError(
        "Arg '{}' not found for op '{}'".format(arg_proto.primary_arg_name, op_name))