Beispiel #1
0
    def test_caffe2_simple_model(self):
        """Convert a LeNet-style Caffe2 ``ModelHelper`` to a GraphDef and compare it.

        The model is built with ``brew`` layers, gradient operators are added,
        and the result of ``c2_graph.model_to_graph_def`` is checked with
        ``compare_proto``. Nothing in this test executes the net.
        """
        model = ModelHelper(name="mnist")
        # The net is never run here (it is only converted to a GraphDef), which
        # is presumably why these blobs need not match the conv dims below.
        workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32))
        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the exact type
        # it aliased, so the fed array is unchanged.
        workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int))

        with core.NameScope("conv1"):
            # The layer sizes / image-size comments below appear to follow the
            # MNIST tutorial (1-channel 28 x 28 input), not the 3 x 64 x 64
            # blob fed above. Do not change them: compare_proto checks this graph.
            conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5)
            # Image size: 24 x 24 -> 12 x 12
            pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2)
            # Image size: 12 x 12 -> 8 x 8
            conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5)
            # Image size: 8 x 8 -> 4 x 4
            pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2)
        with core.NameScope("classifier"):
            # 100 * 4 * 4 is dim_out of conv2 (100 channels) multiplied by the
            # 4 x 4 image size after pool2.
            fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500)
            relu = brew.relu(model, fc3, fc3)
            pred = brew.fc(model, relu, 'pred', 500, 10)
            softmax = brew.softmax(model, pred, 'softmax')
            xent = model.LabelCrossEntropy([softmax, "label"], 'xent')
            # compute the expected loss
            loss = model.AveragedLoss(xent, "loss")
        model.net.RunAllOnMKL()
        model.param_init_net.RunAllOnMKL()
        model.AddGradientOperators([loss], skip=1)
        blob_name_tracker = {}
        graph = c2_graph.model_to_graph_def(
            model,
            blob_name_tracker=blob_name_tracker,
            shapes={},
            show_simplified=False,
        )
        compare_proto(graph, self)
Beispiel #2
0
        def test_simple_cnnmodel(self):
            """Convert a small ``cnn.CNNModelHelper`` model to a GraphDef and compare it.

            The model (conv -> relu -> max-pool -> fc -> softmax -> cross-entropy
            -> averaged loss) is only converted with
            ``c2_graph.model_to_graph_def`` and checked with ``compare_proto``;
            nothing in this test executes the net.
            """
            model = cnn.CNNModelHelper("NCHW", name="overfeat")
            workspace.FeedBlob(
                "data",
                np.random.randn(1, 3, 64, 64).astype(np.float32))
            # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the exact
            # type it aliased, so the fed array is unchanged.
            workspace.FeedBlob("label",
                               np.random.randn(1, 1000).astype(int))
            with core.NameScope("conv1"):
                conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4)
                relu1 = model.Relu(conv1, conv1)
                pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2)
            with core.NameScope("classifier"):
                fc = model.FC(pool1, "fc", 4096, 1000)
                pred = model.Softmax(fc, "pred")
                xent = model.LabelCrossEntropy([pred, "label"], "xent")
                loss = model.AveragedLoss(xent, "loss")

            blob_name_tracker = {}
            graph = c2_graph.model_to_graph_def(
                model,
                blob_name_tracker=blob_name_tracker,
                shapes={},
                show_simplified=False,
            )
            compare_proto(graph, self)
Beispiel #3
0
 def add_graph(self, model, input_to_model=None, verbose=False):
     """Add a model graph to the event file.

     Args:
         model: Either an object with a ``forward`` method (handled as a
             PyTorch model), a Caffe2 model helper (e.g.
             ``cnn.CNNModelHelper`` / ``model_helper.ModelHelper``), or a
             non-empty list of ``caffe2.python.core.Net`` or
             ``caffe2_pb2.NetDef`` objects.
         input_to_model: Passed to ``graph(...)`` together with ``model``;
             only used for PyTorch models.
         verbose: Passed to ``graph(...)``; only used for PyTorch models.

     Raises:
         ValueError: If ``model`` is an empty list.
         TypeError: If ``model`` is a list whose first element is neither a
             ``core.Net`` nor a ``caffe2_pb2.NetDef``.
     """
     # prohibit second call?
     # no, let tensorboard handle it and show its warning message.
     torch._C._log_api_usage_once("tensorboard.logging.add_graph")
     if hasattr(model, 'forward'):
         # A valid PyTorch model should have a 'forward' method
         self._get_file_writer().add_graph(
             graph(model, input_to_model, verbose))
         return

     # Caffe2 models do not have the 'forward' method. The Caffe2 imports are
     # done here so they are only needed when a Caffe2 model is passed.
     from caffe2.proto import caffe2_pb2
     from caffe2.python import core
     from torch.utils.tensorboard._caffe2_graph import (
         model_to_graph_def, nets_to_graph_def, protos_to_graph_def)
     if isinstance(model, list):
         if not model:
             raise ValueError("add_graph() received an empty list of Caffe2 nets")
         first = model[0]
         # The element type of the first entry decides how the list is converted.
         if isinstance(first, core.Net):
             current_graph = nets_to_graph_def(model)
         elif isinstance(first, caffe2_pb2.NetDef):
             current_graph = protos_to_graph_def(model)
         else:
             # Without this, current_graph would be unbound below and fail
             # with an opaque UnboundLocalError.
             raise TypeError(
                 "add_graph() expects a list of caffe2 core.Net or "
                 f"caffe2_pb2.NetDef objects, got a list of {type(first).__name__}")
     else:
         # Handles cnn.CNNModelHelper, model_helper.ModelHelper
         current_graph = model_to_graph_def(model)
     event = event_pb2.Event(
         graph_def=current_graph.SerializeToString())
     self._get_file_writer().add_event(event)