Exemplo n.º 1
0
 def test_fuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
     """Check fusion of two chained Conv ops, each followed by an in-place Relu.

     The net is Conv(X->Y), Relu(Y->Y), Conv(Y->Y2), Relu(Y2->Y2). After
     ``addNNPACK`` and ``fuseNNPACKConvRelu`` the test expects two ops to
     remain, the first one carrying an ``activation`` argument equal to
     ``Relu``, and neither remaining op writing its first output onto its
     first input.
     """
     conv_settings = dict(stride=1, pad=0, kernel=3, order="NCHW")

     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], **conv_settings)
     net.Relu(["Y"], ["Y"])
     net.Conv(["Y", "w", "b"], ["Y2"], **conv_settings)
     net.Relu(["Y2"], ["Y2"])

     # Switch the ops to the NNPACK engine before attempting the fusion.
     addNNPACK(net)
     assert str_compare(net.Proto().op[0].engine, "NNPACK")

     fuseNNPACKConvRelu(net)
     fused_ops = net.Proto().op
     assert len(fused_ops) == 2

     # Every "activation" argument on the first op must be "Relu", and at
     # least one such argument has to be present.
     activation_values = [
         candidate.s
         for candidate in fused_ops[0].arg
         if str_compare(candidate.name, "activation")
     ]
     assert all(str_compare(value, "Relu") for value in activation_values)
     assert activation_values

     for op_index in (0, 1):
         fused = fused_ops[op_index]
         assert fused.output[0] != fused.input[0]
Exemplo n.º 2
0
 def test_addNNPACK(self):
     """Check that ``addNNPACK`` sets the first op's engine to NNPACK.

     The net under test is a single Conv(X, w, b -> Y) followed by a
     Relu(Y -> Y2).
     """
     net = core.Net("net")
     conv_inputs = ["X", "w", "b"]
     conv_outputs = ["Y"]
     net.Conv(conv_inputs, conv_outputs, stride=1, pad=0, kernel=3, order="NCHW")
     net.Relu(["Y"], ["Y2"])

     addNNPACK(net)

     first_op = net.Proto().op[0]
     assert str_compare(first_op.engine, "NNPACK")
Exemplo n.º 3
0
 def test_addNNPACK(self):
     """Check that the first op reports the NNPACK engine after ``addNNPACK``.

     Builds Conv(X, w, b -> Y) followed by Relu(Y -> Y2), applies
     ``addNNPACK`` and inspects the engine field of op 0.
     """
     conv_settings = dict(stride=1, pad=0, kernel=3, order="NCHW")

     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], **conv_settings)
     net.Relu(["Y"], ["Y2"])

     addNNPACK(net)

     engine_name = net.Proto().op[0].engine
     assert str_compare(engine_name, "NNPACK")
Exemplo n.º 4
0
 def test_noFuseNNPACKConvRelu(self):
     """Check that a Conv whose output feeds two Relu ops is left unfused.

     The net is Conv(X->Y), Relu(Y->Y2), Relu(Y->Y3). After ``addNNPACK``
     and ``fuseNNPACKConvRelu`` the test expects all three ops to remain
     and the first op to carry no ``activation`` argument equal to ``Relu``.
     """
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
     net.Relu(["Y"], ["Y2"])
     net.Relu(["Y"], ["Y3"])

     # Switch the ops to the NNPACK engine before attempting the fusion.
     addNNPACK(net)
     assert str_compare(net.Proto().op[0].engine, "NNPACK")

     fuseNNPACKConvRelu(net)
     ops_after = net.Proto().op
     assert len(ops_after) == 3

     relu_fused = any(
         str_compare(candidate.name, "activation")
         and str_compare(candidate.s, "Relu")
         for candidate in ops_after[0].arg
     )
     assert not relu_fused
Exemplo n.º 5
0
 def test_fuseNNPACKConvReluNoInplace(self):
     """Check fusion when the Relu writes back to the Conv's input blob.

     The net is Conv(X->Y) followed by Relu(Y->X). After ``addNNPACK`` and
     ``fuseNNPACKConvRelu`` the test expects a single op that carries an
     ``activation`` argument equal to ``Relu`` and whose first output
     differs from its first input.
     """
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
     net.Relu(["Y"], ["X"])

     # Switch the ops to the NNPACK engine before attempting the fusion.
     addNNPACK(net)
     assert str_compare(net.Proto().op[0].engine, "NNPACK")

     fuseNNPACKConvRelu(net)
     remaining_ops = net.Proto().op
     assert len(remaining_ops) == 1

     fused = remaining_ops[0]
     # Every "activation" argument must be "Relu", and one must exist.
     activation_values = [
         candidate.s
         for candidate in fused.arg
         if str_compare(candidate.name, "activation")
     ]
     assert all(str_compare(value, "Relu") for value in activation_values)
     assert activation_values

     assert fused.output[0] != fused.input[0]
Exemplo n.º 6
0
 def test_noFuseNNPACKConvRelu(self):
     """Check that a Conv whose output feeds two Relu ops is left unfused.

     The net is Conv(X->Y), Relu(Y->Y2), Relu(Y->Y3). After ``addNNPACK``
     and ``fuseNNPACKConvRelu`` the test expects all three ops to remain
     and the first op to carry no ``activation`` argument equal to ``Relu``.
     """
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"],
              stride=1,
              pad=0,
              kernel=3,
              order="NCHW")
     net.Relu(["Y"], ["Y2"])
     net.Relu(["Y"], ["Y3"])
     addNNPACK(net)  # get the NNPACK engine
     assert net.Proto().op[0].engine == "NNPACK"
     fuseNNPACKConvRelu(net)
     assert len(net.Proto().op) == 3
     # The argument's ``s`` value is a protobuf bytes field, so it is compared
     # against a bytes literal. A str literal never equals bytes on Python 3,
     # which would make the negative assertion below pass vacuously.
     has_activation_arg = any(
         arg.name == "activation" and arg.s == b"Relu"
         for arg in net.Proto().op[0].arg
     )
     assert not has_activation_arg
Exemplo n.º 7
0
 def test_fuseNNPACKConvReluInplaceRelu(self):
     """Check fusion of a Conv followed by an in-place Relu.

     The net is Conv(X->Y) followed by Relu(Y->Y). After ``addNNPACK`` and
     ``fuseNNPACKConvRelu`` the test expects a single op that carries an
     ``activation`` argument equal to ``Relu`` and whose first output
     differs from its first input.
     """
     net = core.Net("net")
     net.Conv(["X", "w", "b"], ["Y"],
              stride=1,
              pad=0,
              kernel=3,
              order="NCHW")
     net.Relu(["Y"], ["Y"])
     addNNPACK(net)  # get the NNPACK engine
     assert net.Proto().op[0].engine == "NNPACK"
     fuseNNPACKConvRelu(net)
     assert len(net.Proto().op) == 1
     has_activation_arg = False
     for arg in net.Proto().op[0].arg:
         if arg.name == "activation":
             # The argument's ``s`` value is a protobuf bytes field; comparing
             # it with a str literal is always False on Python 3, so a bytes
             # literal is used (equal to the str form on Python 2).
             assert arg.s == b"Relu"
             has_activation_arg = True
     assert has_activation_arg
     assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]