def test_registry_invalid(self, input_dim, output_dim, batch_size):
    """Check that requesting an unknown transform name raises RuntimeError.

    Builds a model containing a single FC layer ("data" -> "fc1") and asks
    ``workspace.ApplyTransform`` to apply a transform whose name is
    deliberately bogus; the test passes only if RuntimeError is raised.

    Args:
        input_dim: ``dim_in`` passed to the FC layer.
        output_dim: ``dim_out`` passed to the FC layer.
        batch_size: Not used in the body; kept as part of the signature.
    """
    model = model_helper.ModelHelper()
    brew.fc(model, "data", "fc1", dim_in=input_dim, dim_out=output_dim)

    bogus_transform = "definitely_not_a_real_transform"
    with self.assertRaises(RuntimeError):
        # The proto is fetched inside the context so that exactly the same
        # calls are covered by the assertion as before.
        net_proto = model.net.Proto()
        workspace.ApplyTransform(bogus_transform, net_proto)
def test_simple_transform(self, input_dim, output_dim, batch_size):
    """Check that the "ConvToNNPack" transform switches the conv engine.

    Builds a net of two FC layers followed by a conv created with
    ``engine="CUDNN"``, then appends Relu, Softmax, LabelCrossEntropy and
    AveragedLoss operators. After applying the "ConvToNNPack" transform
    to the net proto, the operator at index 2 must report engine "NNPACK".

    Args:
        input_dim: ``dim_in`` of the first FC layer.
        output_dim: ``dim_out`` of every layer, and ``dim_in`` of the
            second FC layer and of the conv.
        batch_size: Not used in the body; kept as part of the signature.
    """
    model = model_helper.ModelHelper()

    first_fc = brew.fc(
        model, "data", "fc1", dim_in=input_dim, dim_out=output_dim
    )
    second_fc = brew.fc(
        model, first_fc, "fc2", dim_in=output_dim, dim_out=output_dim
    )
    conv_blob = brew.conv(
        model,
        second_fc,
        "conv",
        dim_in=output_dim,
        dim_out=output_dim,
        use_cudnn=True,
        engine="CUDNN",
        kernel=3
    )

    # Same operator chain as a fluent call sequence, one step per line.
    # Relu is given the conv blob itself as its output argument.
    relu_out = conv_blob.Relu([], conv_blob)
    prediction = relu_out.Softmax([], "pred")
    cross_entropy = prediction.LabelCrossEntropy(["label"], ["xent"])
    cross_entropy.AveragedLoss([], "loss")

    transformed = workspace.ApplyTransform("ConvToNNPack", model.net.Proto())

    # NOTE(review): index 2 is presumably the conv op (after the two FC
    # ops) -- this assumes each brew.fc call adds exactly one operator.
    self.assertEqual(transformed.op[2].engine, "NNPACK")