Esempio n. 1
0
 def test_registry_invalid(self, input_dim, output_dim, batch_size):
     """ApplyTransform with a bogus transform name must raise RuntimeError.

     A minimal model (one ``fc`` op reading "data" and writing "fc1") provides
     a well-formed NetDef, so the only thing wrong is the transform name.
     ``batch_size`` is not used by this test; it is presumably supplied by a
     parametrizing decorator outside this view.
     """
     model = model_helper.ModelHelper()
     brew.fc(model, "data", "fc1", dim_in=input_dim, dim_out=output_dim)

     bogus_name = "definitely_not_a_real_transform"
     # Proto() is evaluated inside the context manager, exactly as before, so
     # any RuntimeError raised while building the call is also accepted.
     with self.assertRaises(RuntimeError):
         workspace.ApplyTransform(bogus_name, model.net.Proto())
Esempio n. 2
0
    def test_simple_transform(self, input_dim, output_dim, batch_size):
        """Check that the "ConvToNNPack" transform retargets the conv to NNPACK.

        Builds fc1 -> fc2 -> conv (the conv is created with the CUDNN engine),
        then appends an in-place Relu followed by Softmax, LabelCrossEntropy
        and AveragedLoss ops. After applying the transform to the NetDef, the
        op at index 2 must report engine "NNPACK". Index 2 is assumed to be
        the conv because each ``brew.fc`` call is expected to add a single op
        (TODO confirm if brew helpers change).
        ``batch_size`` is not used by this test; it is presumably supplied by
        a parametrizing decorator outside this view.
        """
        model = model_helper.ModelHelper()
        hidden1 = brew.fc(
            model, "data", "fc1", dim_in=input_dim, dim_out=output_dim)
        hidden2 = brew.fc(
            model, hidden1, "fc2", dim_in=output_dim, dim_out=output_dim)
        # Keyword order is kept as-is: kwargs may be forwarded to the operator
        # proto in insertion order.
        conv = brew.conv(
            model,
            hidden2,
            "conv",
            dim_in=output_dim,
            dim_out=output_dim,
            use_cudnn=True,
            engine="CUDNN",
            kernel=3,
        )

        # Relu writes its result back into the "conv" blob (in-place); the
        # following ops extend the chain through "pred" and "xent" to "loss".
        relu_out = conv.Relu([], conv)
        pred = relu_out.Softmax([], "pred")
        xent = pred.LabelCrossEntropy(["label"], ["xent"])
        xent.AveragedLoss([], "loss")

        net_proto = model.net.Proto()
        transformed = workspace.ApplyTransform("ConvToNNPack", net_proto)

        conv_op = transformed.op[2]
        self.assertEqual(conv_op.engine, "NNPACK")