Пример #1
0
 def test_transformer_FuseNNPACKConvReluInplaceFollowedByMultipleInputOp(
         self):
     """Fuse two Conv ops, each followed by an in-place Relu.

     Builds Conv -> Relu (in place on "Y") -> Conv -> Relu (in place on
     "Y2"), runs AddNNPACK and FuseNNPACKConvRelu, then asserts that two
     ops remain, that op[0] carries an "activation" argument whose value
     compares equal to "Relu", and that for both remaining ops the first
     output differs from the first input.
     """
     conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], **conv_kwargs)
     net.Relu(["Y"], ["Y"])
     net.Conv(["Y", "w", "b"], ["Y2"], **conv_kwargs)
     net.Relu(["Y2"], ["Y2"])

     # AddNNPACK is expected to put the NNPACK engine on the first op.
     transformer.AddNNPACK(net)
     assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")

     transformer.FuseNNPACKConvRelu(net)
     ops = net.Proto().op
     assert len(ops) == 2

     # Every "activation" argument on the first op must be "Relu", and
     # at least one such argument must exist.
     activation_values = [
         arg.s for arg in ops[0].arg
         if tu.str_compare(arg.name, "activation")
     ]
     for value in activation_values:
         assert tu.str_compare(value, "Relu")
     assert activation_values

     # Neither remaining op may write its first output onto its first input.
     for op in (ops[0], ops[1]):
         assert op.output[0] != op.input[0]
 def test_transformer_SinkMaxPool(self):
     """Check the op order after SinkMaxPool on base net + MaxPool + Relu.

     Appends MaxPool ("Y" -> "Y1") and an in-place Relu on "Y1" to the
     net returned by ``self._base_test_net()``, runs SinkMaxPool, and
     asserts that op[1] is a Relu and op[2] is a MaxPool.
     """
     net = self._base_test_net()
     net.MaxPool(["Y"], ["Y1"], kernel=3)
     net.Relu(["Y1"], ["Y1"])

     transformer.SinkMaxPool(net)

     # (op index, expected op type) after the transformation.
     expected_layout = ((1, "Relu"), (2, "MaxPool"))
     for index, expected_type in expected_layout:
         assert tu.str_compare(net.Proto().op[index].type, expected_type)
Пример #3
0
 def test_transformer_SinkMaxPool(self):
     """Check the op order after SinkMaxPool on Conv -> MaxPool -> Relu.

     Builds a Conv producing "Y", a MaxPool ("Y" -> "Y1") and an in-place
     Relu on "Y1", runs SinkMaxPool, and asserts that op[1] is a Relu and
     op[2] is a MaxPool.
     """
     conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], **conv_kwargs)
     net.MaxPool(["Y"], ["Y1"], kernel=3)
     net.Relu(["Y1"], ["Y1"])

     transformer.SinkMaxPool(net)

     # (op index, expected op type) after the transformation.
     expected_layout = ((1, "Relu"), (2, "MaxPool"))
     for index, expected_type in expected_layout:
         assert tu.str_compare(net.Proto().op[index].type, expected_type)
 def _fuse_nnpack_convrelu(self, net, expected_result_num_ops,
                           expected_activation_arg=True):
     """Run the NNPACK Conv+Relu fusion on ``net`` and verify the result.

     Args:
         net: The net to transform; it is modified by the transformer
             calls below.
         expected_result_num_ops: Number of ops ``tu.numOps(net)`` must
             report after fusion.
         expected_activation_arg: If truthy, op[0] must carry an
             "activation" argument; if falsy, it must not.

     Raises:
         AssertionError: If the op count, the activation value, or the
             presence/absence of the activation argument is not as
             expected.
     """
     self._add_nnpack(net)
     transformer.FuseNNPACKConvRelu(net)
     # assertEqual, not the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(tu.numOps(net), expected_result_num_ops)

     has_activation_arg = False
     for arg in net.Proto().op[0].arg:
         if tu.str_compare(arg.name, "activation"):
             # Any activation argument present must be "Relu".
             assert tu.str_compare(arg.s, "Relu")
             has_activation_arg = True

     if expected_activation_arg:
         assert has_activation_arg
     else:
         assert not has_activation_arg
Пример #5
0
 def test_noFuseNNPACKConvRelu(self):
     """Check that a Conv whose output feeds two Relu ops is not fused.

     Builds a Conv producing "Y" and two Relu ops that both read "Y"
     (writing "Y2" and "Y3"), runs AddNNPACK and FuseNNPACKConvRelu, and
     asserts that all three ops remain and that op[0] has no "activation"
     argument equal to "Relu".
     """
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"],
              stride=1, pad=0, kernel=3, order="NCHW")
     for relu_output in ("Y2", "Y3"):
         net.Relu(["Y"], [relu_output])

     # AddNNPACK is expected to put the NNPACK engine on the first op.
     transformer.AddNNPACK(net)
     assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")

     transformer.FuseNNPACKConvRelu(net)
     assert len(net.Proto().op) == 3

     relu_was_fused = any(
         tu.str_compare(arg.name, "activation")
         and tu.str_compare(arg.s, "Relu")
         for arg in net.Proto().op[0].arg
     )
     assert not relu_was_fused
Пример #6
0
 def test_transformer_AddNNPACK(self):
     """Check that AddNNPACK sets the NNPACK engine on the first op.

     Builds a Conv producing "Y" followed by a Relu ("Y" -> "Y2"), runs
     AddNNPACK, and asserts op[0]'s engine compares equal to "NNPACK".
     """
     conv_kwargs = dict(stride=1, pad=0, kernel=3, order="NCHW")
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], **conv_kwargs)
     net.Relu(["Y"], ["Y2"])

     transformer.AddNNPACK(net)

     first_op = net.Proto().op[0]
     assert tu.str_compare(first_op.engine, "NNPACK")
 def _add_nnpack(self, net):
     """Run AddNNPACK on ``net`` and assert op[0]'s engine is "NNPACK"."""
     transformer.AddNNPACK(net)
     first_op = net.Proto().op[0]
     assert tu.str_compare(first_op.engine, "NNPACK")