Ejemplo n.º 1
0
    def _match(self, value, settings):
        """Try to match this operation pattern against ``value``.

        The candidate is rejected (an empty, falsy ``Match`` is returned) when
        any of its outputs has more than one consumer while
        ``settings.allow_multi_consumer`` is false, when its name is not among
        the accepted ``self.name`` values, or when any configured input,
        attribute or output pattern fails to match.

        :param value: the candidate operation; must be a ``BaseOperation``.
        :param settings: matcher settings; ``settings.dict_so_far`` must not
            already contain this pattern.
        :return: on success, a ``Match`` rooted at the operation whose dict
            maps this pattern to the operation, merged with the dicts of the
            input, attribute and output sub-matches.
        """
        assert self not in settings.dict_so_far, "Operation cannot be matched multiple times"

        assert isinstance(value, BaseOperation)
        op = value

        # Unless allowed, an op with any output feeding several consumers is rejected.
        if not settings.allow_multi_consumer:
            for output in op.outputs:
                if len(output.consumers) > 1:
                    return Match()

        if self.name is not None:
            accepted_names = utils.listify(self.name)
            if op.name not in accepted_names:
                return Match()

        def first_successful(pattern_spec, actual_items, match_fn):
            # Candidate pattern lists are tried in order. If none succeeds the
            # last (falsy) attempt is returned, or an empty Match if there
            # were no candidates at all.
            attempt = Match()
            for candidate in self._pattern_list_list(pattern_spec, actual_items):
                attempt = match_fn(op, settings, candidate)
                if attempt:
                    break
            return attempt

        bindings = {self: op}

        if self.inputs is not None:
            inputs_match = first_successful(self.inputs, op.inputs, self._match_inputs)
            if not inputs_match:
                return Match()
            bindings = utils.dict_union(bindings, inputs_match.dict)

        if self.attribs is not None:
            assert isinstance(self.attribs, dict)
            attribs_match = self._match_attribs(op, settings, self.attribs)
            if not attribs_match:
                return Match()
            bindings = utils.dict_union(bindings, attribs_match.dict)

        if self.outputs is not None:
            outputs_match = first_successful(self.outputs, op.outputs, self._match_outputs)
            if not outputs_match:
                return Match()
            bindings = utils.dict_union(bindings, outputs_match.dict)

        return Match(True, root=op, dict_=bindings)
Ejemplo n.º 2
0
def transform_fuse_activations(tf_graph):
    # type: (TFGraph)->None
    """Fold ReLU / ReLU6 activations into the operation that produces their input.

    Two rewrites are applied to ``tf_graph``, in this order:

    * ``tf.nn.relu(op(...))`` becomes ``op(...)`` with
      ``fused_activation_function='RELU'``;
    * ``tf.clip_by_value(op(...), lo, hi)`` becomes ``op(...)`` with
      ``fused_activation_function='RELU6'``, but only when the ``data`` of the
      clip's second and third inputs equals ``[0]`` and ``[6]``.

    ``op`` must be one of the names listed in ``fusable``, and a rewrite is
    skipped when ``op`` already has a truthy ``fused_activation_function``
    attribute. The replacement keeps ``op``'s name, inputs and attributes and
    takes over the activation's outputs.
    """
    fusable = [
        "tf.add",
        "tf.subtract",
        "tf.multiply",
        "tf.divide",
        "tf.nn.conv2d",
        "tf.nn.depthwise_conv2d",
        "tf.nn.max_pool",
        "tf.nn.avg_pool",
        # "tf.nn.conv2d_transpose", (not working yet)
        "tf.matmul",
        "tf.nn.l2_normalize",
        # "tf.concat" (not working yet)
    ]

    def fuse(activation_name, fused_function, activation_ok):
        # Pattern: <fusable op> --tensor--> <activation>, the tensor being input 0.
        produced = matcher.Tensor()
        producer = matcher.Operation(name=fusable, outputs=produced)
        activation = matcher.Operation(name=activation_name, inputs={0: produced})

        def build(m):
            return TFOperation(
                graph=tf_graph,
                name=m[producer].name,
                attribs=utils.dict_union(m[producer].attribs,
                                         dict(fused_activation_function=fused_function)),
                inputs=m[producer].inputs,
                outputs=m[activation].outputs)

        def allowed(m):
            return (activation_ok(m[activation])
                    and not m[producer].attribs.get('fused_activation_function'))

        matcher.replace(graph=tf_graph, pattern=activation, replacement=build, condition=allowed)

    def clips_to_relu6(clip_op):
        # clip_by_value(x, 0, 6): the bounds are read from the data of inputs 1 and 2.
        return clip_op.inputs[1].data == [0] and clip_op.inputs[2].data == [6]

    fuse("tf.nn.relu", 'RELU', lambda _op: True)
    fuse("tf.clip_by_value", 'RELU6', clips_to_relu6)
Ejemplo n.º 3
0
def _create_lp_pools(g):
    # type: (NNEFGraph)->None
    """Collapse ``pow(box(pow(abs(x), p)), q)`` chains into ``_lp_pool`` ops.

    A chain is rewritten only when ``p`` and ``q`` are both rank-0 tensors
    with data, both are non-zero, and ``q`` is numerically close to ``1 / p``
    (``np.allclose``). The new ``_lp_pool`` operation reads ``x``, writes the
    output of the final ``pow`` and carries the ``box`` operation's
    attributes plus ``p`` converted to ``float``.
    """
    input_, abs_out, pow_out, box_out, output, p, q = matcher.tensors(7)

    # NOTE: _abs_op and _pow_op are not referenced below; presumably building
    # them attaches them as the producer patterns of abs_out / pow_out, which
    # links the chain together — keep them.
    _abs_op = matcher.Operation(name='abs', inputs=input_, outputs=abs_out)
    _pow_op = matcher.Operation(name='pow', inputs=(abs_out, p), outputs=pow_out)
    box_op = matcher.Operation(name='box', inputs=pow_out, outputs=box_out)
    pow2_op = matcher.Operation(name='pow', inputs=(box_out, q), outputs=output)

    def scalar(tensor):
        # Python scalar held by a rank-0 tensor with data.
        return tensor.get_numpy_array().item()

    def is_constant_scalar(tensor):
        return tensor.rank == 0 and tensor.data is not None

    def replacement(m):
        return NNEFOperation(graph=g,
                             name="_lp_pool",
                             inputs=m[input_],
                             outputs=m[output],
                             attribs=utils.dict_union(m[box_op].attribs,
                                                      dict(p=float(scalar(m[p])))))

    def condition(m):
        p_tensor, q_tensor = m[p], m[q]
        # p must be non-zero because 1.0 / p is computed below.
        if not is_constant_scalar(p_tensor) or scalar(p_tensor) == 0:
            return False
        # q must be non-zero too: pow(x, 0) is not a root, even when 1/p is
        # within allclose tolerance of 0 (very large p).
        if not is_constant_scalar(q_tensor) or scalar(q_tensor) == 0:
            return False
        return np.allclose(1.0 / scalar(p_tensor), scalar(q_tensor))

    matcher.replace(g, pow2_op, replacement, condition)
Ejemplo n.º 4
0
 def replacement(m):
     """Create the CaffeOperation standing in for the matched ``op1``/``op2``.

     The new operation is named after ``m[op2]``, reads ``m[input]``, writes
     ``m[output2]`` and carries the union of ``op2``'s and ``op1``'s
     attributes. ``g``, ``op1``, ``op2``, ``input`` and ``output2`` come from
     the enclosing scope. Nothing is returned; presumably constructing the
     operation with ``graph=g`` is what adds it to the graph (TODO confirm).
     """
     merged_attribs = utils.dict_union(m[op2].attribs, m[op1].attribs)
     CaffeOperation(graph=g,
                    name=m[op2].name,
                    inputs=m[input],
                    outputs=m[output2],
                    attribs=merged_attribs)
Ejemplo n.º 5
0
 def replacement(m):
     """Create one CaffeOperation standing in for the matched scale/shift/power ops.

     The new operation is named after ``m[power_op]``, reads ``m[input]``,
     writes ``m[powered]`` and carries the union of the scale, shift and
     power operations' attributes (in that order). ``g`` and the pattern
     names come from the enclosing scope. Nothing is returned; presumably
     constructing the operation with ``graph=g`` adds it to the graph (TODO
     confirm).
     """
     merged_attribs = utils.dict_union(m[scale_op].attribs,
                                       m[shift_op].attribs,
                                       m[power_op].attribs)
     CaffeOperation(graph=g,
                    name=m[power_op].name,
                    inputs=m[input],
                    outputs=m[powered],
                    attribs=merged_attribs)
Ejemplo n.º 6
0
 def _match(self, value, settings):
     """Match the wrapped pattern with a local multi-consumer policy.

     ``self._pattern`` is matched with ``allow_multi_consumer_inside``
     overridden by ``self._allow_multi_consumer_inside``. On success the
     wrapped match's bindings are returned plus a binding of this pattern to
     the match root; on failure the wrapped (falsy) match is returned as is.
     """
     inner_settings = settings.copy(allow_multi_consumer_inside=self._allow_multi_consumer_inside)
     inner = self._pattern._match(value, inner_settings)
     if inner:
         bindings = utils.dict_union(inner.dict, {self: inner.root})
         return Match(did_match=True, root=inner.root, dict_=bindings)
     return inner
Ejemplo n.º 7
0
    def _match(self, value, settings):
        """Match this tensor pattern against ``value``.

        The match fails if this pattern is already bound in
        ``settings.dict_so_far`` to a different tensor. When a producer
        pattern is set and ``settings.follow_producer`` is true, the tensor
        must have a producer that matches that pattern, and the producer's
        bindings are merged into the result.

        :param value: the candidate tensor; must be a ``BaseTensor``.
        :param settings: matcher settings.
        :return: a successful ``Match`` rooted at ``value`` that maps this
            pattern to it, or an empty (falsy) ``Match``.
        """
        assert isinstance(value, BaseTensor)

        # A pattern used in several places must bind to the same tensor
        # everywhere. Return a Match *instance* here: the bare ``Match`` class
        # object is always truthy, so callers testing ``if not match_`` would
        # treat the conflict as a success.
        if settings.dict_so_far.get(self, value) != value:
            return Match()

        if not (self._producer_pattern and settings.follow_producer):
            return Match(did_match=True, root=value, dict_={self: value})

        if value.producer is None:
            return Match()

        # Make this binding visible while matching the producer subtree.
        producer_settings = settings.copy(
            dict_so_far=utils.dict_union(settings.dict_so_far, {self: value}))
        producer_match = self._producer_pattern._match(value.producer, producer_settings)
        if not producer_match:
            return Match()

        return Match(did_match=True, root=value,
                     dict_=utils.dict_union(producer_match.dict, {self: value}))
Ejemplo n.º 8
0
    def forward(self, *inputs):
        """Execute the NNEF graph on the given torch tensors.

        Operations run in the order of ``self._nnef_graph.operations``.
        Tensors stored as attributes of this module (under
        ``self._safe_name(name)``) are read from there; all other values live
        in a reference-counted activation store. Each entry is expected to be
        read once per consuming input slot, plus once more for graph outputs.
        Registered tensor hooks are called with ``(nnef_tensor, torch_tensor)``
        for graph inputs, for constants/variables (only when hooks exist) and
        for every operation output.

        :param inputs: one torch tensor per graph input, in graph-input order.
        :return: a tuple with one tensor per graph output.
        :raises utils.NNEFToolsException: if the graph contains an operation
            with no implementation in ``self._operations``.
        """
        # Expected number of reads per tensor: one per (consumer, input slot)
        # that references it, plus one for the final read of graph outputs.
        activation_tensors = _RefCountedDict({
            t.name: (sum(consumer_input is t
                         for consumer in t.consumers
                         for consumer_input in consumer.inputs)
                     + (1 if t in self._nnef_graph.outputs else 0))
            for t in self._nnef_graph.tensors})

        def get_tensor(name):
            # Module attributes take precedence over the activation store.
            if hasattr(self, self._safe_name(name)):
                return getattr(self, self._safe_name(name))
            else:
                return activation_tensors[name]

        def has_tensor(name):
            return hasattr(self, self._safe_name(name)) or name in activation_tensors

        assert len(inputs) == len(self._nnef_graph.inputs)
        for torch_tensor, nnef_tensor in zip(inputs, self._nnef_graph.inputs):
            activation_tensors.ready(nnef_tensor.name, torch_tensor)
            utils.call_each(self._tensor_hooks, nnef_tensor, torch_tensor)

        if self._tensor_hooks:
            for nnef_tensor in self._nnef_graph.tensors:
                if nnef_tensor.is_constant or nnef_tensor.is_variable:
                    utils.call_each(self._tensor_hooks, nnef_tensor, get_tensor(nnef_tensor.name))

        for op in self._nnef_graph.operations:
            if op.name not in self._operations:
                raise utils.NNEFToolsException("Unsupported operation: {}".format(op.name))
            fun = self._operations[op.name]
            assert all(has_tensor(t.name) for t in op.inputs)
            # A tuple of inputs maps to positional arguments; a list (variadic
            # input) is passed as one list argument.
            if isinstance(op.inputs, tuple):
                op_args = tuple(get_tensor(t.name) for t in op.inputs)
            else:
                op_args = ([get_tensor(t.name) for t in op.inputs],)
            op_results = fun(*op_args, **utils.dict_union(op.attribs, self._get_extra_attributes(op.name)))
            if not isinstance(op_results, (list, tuple)):
                op_results = (op_results,)
            for t, result in zip(op.outputs, op_results):
                activation_tensors.ready(t.name, result)
                utils.call_each(self._tensor_hooks, t, result)

            # This op has consumed its activation inputs; give back one read each.
            for t in op.inputs:
                if not t.is_constant and not t.is_variable:
                    activation_tensors.release(t.name)

        graph_outputs = [get_tensor(t.name) for t in self._nnef_graph.outputs]
        for t in self._nnef_graph.outputs:
            activation_tensors.release(t.name)

        assert not activation_tensors, "Reference counting error in PyTorch NNEF Backend"
        return tuple(graph_outputs)
Ejemplo n.º 9
0
    def _match(self, value, settings):
        """Match ``value`` against the alternatives in ``self.patterns``, in order.

        The first alternative that matches wins. Its bindings are returned
        together with a binding of this pattern to the match root. If no
        alternative matches, the last (falsy) attempt is returned, or an
        empty ``Match`` when there are no alternatives.
        """
        assert self not in settings.dict_so_far, "OrPattern cannot be matched multiple times"

        attempt = Match()
        for alternative in self.patterns:
            attempt = alternative._match(value, settings)
            if attempt:
                bindings = utils.dict_union(attempt.dict, {self: attempt.root})
                return Match(did_match=True, root=attempt.root, dict_=bindings)
        return attempt
Ejemplo n.º 10
0
    def _match_inputs(self, op, settings, input_patterns):
        """Match every input tensor of ``op`` against its positional pattern.

        :param op: the operation whose ``inputs`` are matched.
        :param settings: matcher settings of the enclosing match.
        :param input_patterns: one pattern per input; the counts must agree.
        :return: a ``Match`` rooted at ``op`` that maps this pattern to ``op``
            plus all input bindings, or an empty (falsy) ``Match``.
        """
        if len(op.inputs) != len(input_patterns):
            return Match()

        bindings = {self: op}
        for actual, expected in zip(op.inputs, input_patterns):
            # Earlier bindings are visible to later inputs, and the "inside"
            # multi-consumer policy applies below this operation.
            inner_settings = settings.copy(allow_multi_consumer=settings.allow_multi_consumer_inside,
                                           dict_so_far=utils.dict_union(settings.dict_so_far, bindings))
            # noinspection PyProtectedMember
            sub_match = expected._match(actual, inner_settings)
            if not sub_match:
                return Match()
            bindings.update(sub_match.dict)

        return Match(did_match=True, root=op, dict_=bindings)
Ejemplo n.º 11
0
    def _match_attribs(self, op, settings, attrib_patterns):
        """Match attributes of ``op`` against ``attrib_patterns``.

        :param op: the operation whose ``attribs`` are inspected.
        :param settings: matcher settings of the enclosing match.
        :param attrib_patterns: maps attribute names to a ``Pattern`` or to a
            plain value; values that are not ``Pattern`` instances are wrapped
            in ``_Const`` via ``utils.recursive_transform``.
        :return: a ``Match`` rooted at ``op`` that maps this pattern to ``op``
            plus all attribute bindings, or an empty (falsy) ``Match`` if a
            requested attribute is missing or does not match.
        """
        def trafo(arg):
            return arg if isinstance(arg, Pattern) else _Const(arg)

        attrib_patterns = utils.recursive_transform(attrib_patterns, trafo)  # type: typing.Dict[str, Pattern]

        dict_ = {self: op}
        for attrib_name, attrib_pattern in six.iteritems(attrib_patterns):
            # An op lacking a requested attribute simply does not match;
            # indexing it directly would raise KeyError and abort the matcher.
            if attrib_name not in op.attribs:
                return Match()
            attrib_value = op.attribs[attrib_name]
            # noinspection PyProtectedMember
            match_ = attrib_pattern._match(attrib_value,
                                           settings.copy(allow_multi_consumer=settings.allow_multi_consumer_inside,
                                                         dict_so_far=utils.dict_union(settings.dict_so_far, dict_)))
            if not match_:
                return Match()
            dict_.update(match_.dict)

        return Match(did_match=True, root=op, dict_=dict_)