def _match(self, value, settings):
    """Match this operation pattern against *value* (a BaseOperation).

    Returns a truthy Match binding this pattern (and all sub-patterns)
    to the matched graph objects, or a falsy Match() on failure.
    """
    assert self not in settings.dict_so_far, "Operation cannot be matched multiple times"
    assert isinstance(value, BaseOperation)
    op = value

    # Reject ops whose results feed multiple consumers unless allowed.
    if not settings.allow_multi_consumer:
        if any(len(tensor.consumers) > 1 for tensor in op.outputs):
            return Match()

    # Name filter: the pattern name may be a single name or a list of names.
    if self.name is not None and op.name not in utils.listify(self.name):
        return Match()

    bindings = {self: op}

    def first_hit(pattern_lists, match_one):
        # Try each candidate pattern list in order; keep the first success.
        result = Match()
        for patterns in pattern_lists:
            result = match_one(op, settings, patterns)
            if result:
                break
        return result

    if self.inputs is not None:
        sub = first_hit(self._pattern_list_list(self.inputs, op.inputs),
                        self._match_inputs)
        if not sub:
            return Match()
        bindings = utils.dict_union(bindings, sub.dict)

    if self.attribs is not None:
        assert isinstance(self.attribs, dict)
        sub = self._match_attribs(op, settings, self.attribs)
        if not sub:
            return Match()
        bindings = utils.dict_union(bindings, sub.dict)

    if self.outputs is not None:
        sub = first_hit(self._pattern_list_list(self.outputs, op.outputs),
                        self._match_outputs)
        if not sub:
            return Match()
        bindings = utils.dict_union(bindings, sub.dict)

    return Match(True, root=op, dict_=bindings)
def transform_fuse_activations(tf_graph):
    # type: (TFGraph)->None
    """Fold a trailing activation into the op that produces its input.

    Pass 1 fuses ``tf.nn.relu`` as ``fused_activation_function='RELU'``;
    pass 2 fuses ``tf.clip_by_value(x, 0, 6)`` as ``'RELU6'``. An op that
    already carries a fused activation is left untouched.
    """
    fuse_to = [
        "tf.add",
        "tf.subtract",
        "tf.multiply",
        "tf.divide",
        "tf.nn.conv2d",
        "tf.nn.depthwise_conv2d",
        "tf.nn.max_pool",
        "tf.nn.avg_pool",
        # "tf.nn.conv2d_transpose", (not working yet)
        "tf.matmul",
        "tf.nn.l2_normalize",
        # "tf.concat" (not working yet)
    ]

    # Pass 1: producer followed by tf.nn.relu.
    out_tensor = matcher.Tensor()
    producer = matcher.Operation(name=fuse_to, outputs=out_tensor)
    relu = matcher.Operation(name="tf.nn.relu", inputs={0: out_tensor})

    def make_relu(m):
        return TFOperation(graph=tf_graph,
                           name=m[producer].name,
                           attribs=utils.dict_union(m[producer].attribs,
                                                    dict(fused_activation_function='RELU')),
                           inputs=m[producer].inputs,
                           outputs=m[relu].outputs)

    matcher.replace(tf_graph, relu, make_relu,
                    lambda m: not m[producer].attribs.get('fused_activation_function'))

    # Pass 2: producer followed by tf.clip_by_value with bounds [0, 6].
    out_tensor = matcher.Tensor()
    producer = matcher.Operation(name=fuse_to, outputs=out_tensor)
    clip = matcher.Operation(name="tf.clip_by_value", inputs={0: out_tensor})

    def make_relu6(m):
        return TFOperation(graph=tf_graph,
                           name=m[producer].name,
                           attribs=utils.dict_union(m[producer].attribs,
                                                    dict(fused_activation_function='RELU6')),
                           inputs=m[producer].inputs,
                           outputs=m[clip].outputs)

    def is_relu6(m):
        return (m[clip].inputs[1].data == [0]
                and m[clip].inputs[2].data == [6]
                and not m[producer].attribs.get('fused_activation_function'))

    matcher.replace(graph=tf_graph,
                    pattern=clip,
                    replacement=make_relu6,
                    condition=is_relu6)
def _create_lp_pools(g):
    # type: (NNEFGraph)->None
    """Recognize abs -> pow(p) -> box -> pow(q) chains where q == 1/p and
    replace them with a single ``_lp_pool`` operation.

    Bug fix: the original condition tested ``p != 0`` twice and never
    tested ``q != 0``. With q == 0 and a very large p, ``np.allclose``'s
    default absolute tolerance makes ``1/p`` compare equal to 0, so a
    ``pow(x, 0)`` tail could be wrongly fused. The second check now
    tests q, as intended.
    """
    input, abs_out, pow_out, box_out, output, p, q = matcher.tensors(7)

    _abs_op = matcher.Operation(name='abs', inputs=input, outputs=abs_out)
    _pow_op = matcher.Operation(name='pow', inputs=(abs_out, p), outputs=pow_out)
    box_op = matcher.Operation(name='box', inputs=pow_out, outputs=box_out)
    pow2_op = matcher.Operation(name='pow', inputs=(box_out, q), outputs=output)

    def replacement(m):
        # The box op's attribs (kernel size, stride, ...) carry over; the
        # exponent becomes the pool's `p` attribute.
        return NNEFOperation(graph=g,
                             name="_lp_pool",
                             inputs=m[input],
                             outputs=m[output],
                             attribs=utils.dict_union(
                                 m[box_op].attribs,
                                 dict(p=float(m[p].get_numpy_array().item()))))

    def condition(m):
        # Both exponents must be known compile-time scalars.
        if m[p].rank != 0 or m[p].data is None:
            return False
        if m[q].rank != 0 or m[q].data is None:
            return False
        p_val = m[p].get_numpy_array().item()
        q_val = m[q].get_numpy_array().item()
        # was: the q-check mistakenly re-tested p != 0
        if p_val == 0 or q_val == 0:
            return False
        return bool(np.allclose(1.0 / p_val, q_val))

    matcher.replace(g, pow2_op, replacement, condition)
def replacement(m):
    """Emit the merged CaffeOperation for a successful two-op match.

    The constructed op is presumably registered on graph ``g`` by its
    constructor (nothing is returned) — same convention as the other
    Caffe transforms in this module.
    """
    merged_attribs = utils.dict_union(m[op2].attribs, m[op1].attribs)
    CaffeOperation(graph=g,
                   name=m[op2].name,
                   inputs=m[input],
                   outputs=m[output2],
                   attribs=merged_attribs)
def replacement(m):
    """Emit one CaffeOperation combining the scale/shift/power chain.

    Attribs of all three matched ops are merged; the result is presumably
    registered on graph ``g`` by the constructor (nothing is returned).
    """
    merged_attribs = utils.dict_union(m[scale_op].attribs,
                                      m[shift_op].attribs,
                                      m[power_op].attribs)
    CaffeOperation(graph=g,
                   name=m[power_op].name,
                   inputs=m[input],
                   outputs=m[powered],
                   attribs=merged_attribs)
def _match(self, value, settings):
    """Delegate matching to the wrapped pattern under an adjusted
    multi-consumer policy, then bind self to the matched root."""
    inner_settings = settings.copy(
        allow_multi_consumer_inside=self._allow_multi_consumer_inside)
    inner = self._pattern._match(value, inner_settings)
    if not inner:
        return inner
    merged = utils.dict_union(inner.dict, {self: inner.root})
    return Match(did_match=True, root=inner.root, dict_=merged)
def _match(self, value, settings):
    """Match this tensor pattern against *value* (a BaseTensor).

    If this pattern is already bound in ``settings.dict_so_far``, it must
    be bound to the same tensor. If a producer pattern is set and producer
    following is enabled, the producer must match too.

    Bug fix: on a binding conflict the original did ``return Match`` —
    returning the Match *class*, which is truthy — so a conflicting
    binding was reported as a successful match. It now returns a falsy
    ``Match()`` instance.
    """
    assert isinstance(value, BaseTensor)
    # A pattern bound earlier in this match must bind to the same tensor.
    if settings.dict_so_far.get(self, value) != value:
        return Match()  # was `return Match` (the class object, truthy)
    if self._producer_pattern and settings.follow_producer:
        if value.producer is None:
            return Match()
        # Bind self before descending so the producer pattern sees it.
        match = self._producer_pattern._match(
            value.producer,
            settings.copy(dict_so_far=utils.dict_union(settings.dict_so_far,
                                                       {self: value})))
        if not match:
            return Match()
        return Match(did_match=True,
                     root=value,
                     dict_=utils.dict_union(match.dict, {self: value}))
    else:
        return Match(did_match=True, root=value, dict_={self: value})
def forward(self, *inputs):
    # Execute the NNEF graph op-by-op on the given input tensors.
    # Intermediate activations are reference-counted: each tensor's count
    # is the number of times it appears as a consumer input, plus one if
    # it is also a graph output, so it can be released as soon as its
    # last consumer has run.
    activation_tensors = _RefCountedDict({
        t.name: (sum(input is t for consumer in t.consumers for input in consumer.inputs) + (1 if t in self._nnef_graph.outputs else 0))
        for t in self._nnef_graph.tensors})

    def get_tensor(name):
        # Variables/constants are stored as attributes on self (under a
        # sanitized name); activations live in the ref-counted dict.
        if hasattr(self, self._safe_name(name)):
            return getattr(self, self._safe_name(name))
        else:
            return activation_tensors[name]

    def has_tensor(name):
        return hasattr(self, self._safe_name(name)) or name in activation_tensors

    assert len(inputs) == len(self._nnef_graph.inputs)
    # Seed the graph inputs and notify any tensor hooks.
    for torch_tensor, nnef_tensor in zip(inputs, self._nnef_graph.inputs):
        activation_tensors.ready(nnef_tensor.name, torch_tensor)
        utils.call_each(self._tensor_hooks, nnef_tensor, torch_tensor)

    # Hooks also get to see constants and variables, once per forward.
    if self._tensor_hooks:
        for nnef_tensor in self._nnef_graph.tensors:
            if nnef_tensor.is_constant or nnef_tensor.is_variable:
                utils.call_each(self._tensor_hooks, nnef_tensor, get_tensor(nnef_tensor.name))

    # NOTE: operations are presumably in topological order — every input
    # must be ready before its op runs (enforced by the assert below).
    for op in self._nnef_graph.operations:
        if op.name not in self._operations:
            raise utils.NNEFToolsException("Unsupported operation: {}".format(op.name))
        fun = self._operations[op.name]
        assert all(has_tensor(t.name) for t in op.inputs)
        # A tuple of inputs is passed positionally; a list is passed as a
        # single list argument (variadic NNEF ops).
        if isinstance(op.inputs, tuple):
            inputs = tuple(get_tensor(t.name) for t in op.inputs)
        else:
            inputs = ([get_tensor(t.name) for t in op.inputs],)
        outputs = fun(*inputs, **utils.dict_union(op.attribs, self._get_extra_attributes(op.name)))
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs,)
        for t, output in zip(op.outputs, outputs):
            activation_tensors.ready(t.name, output)
            utils.call_each(self._tensor_hooks, t, output)
        # Drop one reference per consumed activation input.
        for t in op.inputs:
            if not t.is_constant and not t.is_variable:
                activation_tensors.release(t.name)

    # Collect outputs before releasing their final (output) reference.
    outputs = [get_tensor(t.name) for t in self._nnef_graph.outputs]
    for t in self._nnef_graph.outputs:
        activation_tensors.release(t.name)
    # All counts must now be zero, otherwise the bookkeeping above is wrong.
    assert not activation_tensors, "Reference counting error in PyTorch NNEF Backend"
    return tuple(outputs)
def _match(self, value, settings):
    """Try each alternative in order; on the first success, return its
    match with self additionally bound to the matched root."""
    assert self not in settings.dict_so_far, "OrPattern cannot be matched multiple times"
    result = Match()
    for alternative in self.patterns:
        result = alternative._match(value, settings)
        if result:
            break
    if not result:
        return result
    return Match(did_match=True,
                 root=result.root,
                 dict_=utils.dict_union(result.dict, {self: result.root}))
def _match_inputs(self, op, settings, input_patterns):
    """Match op.inputs element-wise against input_patterns.

    Returns a falsy Match() on any length mismatch or failed sub-match;
    otherwise a Match binding self to op plus all sub-pattern bindings.
    """
    if len(op.inputs) != len(input_patterns):
        return Match()
    bindings = {self: op}
    for tensor, pattern in zip(op.inputs, input_patterns):
        # Earlier bindings are visible to later sub-matches via dict_so_far.
        # noinspection PyProtectedMember
        sub = pattern._match(
            tensor,
            settings.copy(allow_multi_consumer=settings.allow_multi_consumer_inside,
                          dict_so_far=utils.dict_union(settings.dict_so_far, bindings)))
        if not sub:
            return Match()
        bindings.update(sub.dict)
    return Match(did_match=True, root=op, dict_=bindings)
def _match_attribs(self, op, settings, attrib_patterns):
    """Match op.attribs against attrib_patterns (name -> pattern/value).

    Plain values are wrapped in _Const so everything is matched uniformly.
    NOTE(review): a name missing from op.attribs raises KeyError here —
    presumably guaranteed present by the calling pattern; verify.
    """
    def as_pattern(arg):
        return arg if isinstance(arg, Pattern) else _Const(arg)

    attrib_patterns = utils.recursive_transform(attrib_patterns, as_pattern)  # type: typing.Dict[str, Pattern]
    bindings = {self: op}
    for name, pattern in six.iteritems(attrib_patterns):
        sub = pattern._match(
            op.attribs[name],
            settings.copy(allow_multi_consumer=settings.allow_multi_consumer_inside,
                          dict_so_far=utils.dict_union(settings.dict_so_far, bindings)))
        if not sub:
            return Match()
        bindings.update(sub.dict)
    return Match(did_match=True, root=op, dict_=bindings)