def test_simple_cnnmodel(self):
    """Export a small CNN (including gradients) and compare it to EXPECTED.

    Builds conv -> relu -> maxpool -> fc -> softmax -> label cross-entropy ->
    averaged loss with ``cnn.CNNModelHelper``, calls ``RunAllOnGPU`` on both
    the main net and the param-init net, then adds gradient operators.
    The model is exported with ``tb.cnn_to_graph_def``. The test checks two
    things:

    * the ``GRADIENTS/``-prefixed name of a gradient blob maps back to its
      original blob name in ``track_blob_names``;
    * the exported graph text matches ``EXPECTED``, ignoring node order.
    """
    helper = cnn.CNNModelHelper("NCHW", name="overfeat")
    data, label = helper.ImageInput(["db"], ["data", "label"], is_test=0)

    with core.NameScope("conv1"):
        conv_out = helper.Conv(data, "conv1", 3, 96, 11, stride=4)
        # Relu's output blob is its own input, i.e. it runs in place.
        relu_out = helper.Relu(conv_out, conv_out)
        pooled = helper.MaxPool(relu_out, "pool1", kernel=2, stride=2)

    with core.NameScope("classifier"):
        logits = helper.FC(pooled, "fc", 4096, 1000)
        probs = helper.Softmax(logits, "pred")
        xent = helper.LabelCrossEntropy([probs, label], "xent")
        loss = helper.AveragedLoss(xent, "loss")

    helper.net.RunAllOnGPU()
    helper.param_init_net.RunAllOnGPU()
    helper.AddGradientOperators([loss], skip=1)

    # cnn_to_graph_def fills this dict with the names it rewrote.
    blob_name_map = {}
    graph_def = tb.cnn_to_graph_def(
        helper,
        track_blob_names=blob_name_map,
        shapes={},
    )
    self.assertEqual(
        blob_name_map['GRADIENTS/conv1/conv1_b_grad'],
        'conv1/conv1_b_grad',
    )

    self.maxDiff = None

    # The exporter does not guarantee node order, so both texts are cut into
    # "node {" chunks, and each chunk list is sorted before comparing.
    sep = "node {"

    def _sorted_nodes(text):
        chunks = (chunk.strip() for chunk in text.strip().split(sep))
        return "\n".join(
            sorted(sep + "\n " + chunk for chunk in chunks if chunk))

    expected = _sorted_nodes(EXPECTED)
    actual = _sorted_nodes(str(graph_def))
    self.assertMultiLineEqual(actual, expected)
def test_simple_cnnmodel(self):
    """Check the TensorBoard export of a small CNN against EXPECTED.

    The model is conv -> relu -> maxpool -> fc -> softmax -> label
    cross-entropy -> averaged loss, built with ``cnn.CNNModelHelper``.
    ``RunAllOnGPU`` is called on both nets before gradient operators are
    added. The exported ``track_blob_names`` must map the ``GRADIENTS/``
    gradient name back to the plain blob name. The exported graph text must
    equal ``EXPECTED`` once both sides' nodes are put in sorted order.
    """
    net_helper = cnn.CNNModelHelper("NCHW", name="overfeat")
    data, label = net_helper.ImageInput(["db"], ["data", "label"], is_test=0)

    with core.NameScope("conv1"):
        conv_blob = net_helper.Conv(data, "conv1", 3, 96, 11, stride=4)
        # Same blob passed as input and output: an in-place Relu.
        relu_blob = net_helper.Relu(conv_blob, conv_blob)
        pool_blob = net_helper.MaxPool(relu_blob, "pool1", kernel=2, stride=2)

    with core.NameScope("classifier"):
        fc_blob = net_helper.FC(pool_blob, "fc", 4096, 1000)
        pred_blob = net_helper.Softmax(fc_blob, "pred")
        xent_blob = net_helper.LabelCrossEntropy([pred_blob, label], "xent")
        loss_blob = net_helper.AveragedLoss(xent_blob, "loss")

    net_helper.net.RunAllOnGPU()
    net_helper.param_init_net.RunAllOnGPU()
    net_helper.AddGradientOperators([loss_blob], skip=1)

    renamed = {}
    exported = tb.cnn_to_graph_def(
        net_helper,
        track_blob_names=renamed,
        shapes={},
    )
    self.assertEqual(
        renamed['GRADIENTS/conv1/conv1_b_grad'],
        'conv1/conv1_b_grad',
    )

    self.maxDiff = None

    # Node order is not deterministic, so compare each side as a sorted list
    # of "node {" blocks instead of as raw text.
    sep = "node {"

    def _canonical(text):
        blocks = []
        for piece in text.strip().split(sep):
            piece = piece.strip()
            if piece:
                blocks.append(sep + "\n " + piece)
        blocks.sort()
        return "\n".join(blocks)

    expected = _canonical(EXPECTED)
    actual = _canonical(str(exported))
    self.assertMultiLineEqual(actual, expected)
def visualize_cnn(cnn, **kwargs):
    """Render ``cnn`` as a graph.

    Converts the model with ``tb_exporter.cnn_to_graph_def`` and passes the
    result to ``_show_graph``. All keyword arguments go unchanged to the
    exporter. Nothing is returned.
    """
    _show_graph(tb_exporter.cnn_to_graph_def(cnn, **kwargs))