Exemplo n.º 1
0
    def visualize(self, layout_dir='LR', display=False):
        """
        Create and optionally display an image of the net structure.

        The net is serialized with ``self.to_pbuf_message()``, rendered to a
        temporary PNG with ``_draw.draw_net_to_file`` and read back with
        OpenCV.

        :param layout_dir: string in ['LR', 'TB', 'BT'].
          Short string for graph layout direction; passed as ``rankdir`` to
          the drawing function.

        :param display: bool.
          If set to ``True``, displays the graphic in a window titled
          ``self.name``. Press any key to close it.

        :returns: 3D numpy array.
          Graphic of the visualization as (H, W, C) image in BGR format.

        :raises Exception: if the drawing module or OpenCV is unavailable.
        :raises RuntimeError: if the rendered image cannot be read back.
        """
        if _draw is None or _cv2 is None:  # pragma: no cover
            raise Exception('Drawing is not available!')
        else:  # pragma: no cover
            with _NamedTemporaryFile(mode='w+b', suffix='.png') as tmpfile:
                _draw.draw_net_to_file(self.to_pbuf_message(),
                                       tmpfile.name,
                                       rankdir=layout_dir)
                result_image = _cv2.imread(tmpfile.name)
                # imread signals failure by returning None rather than raising.
                # An explicit check (unlike ``assert``) is not stripped when
                # Python runs with -O.
                if result_image is None:
                    raise RuntimeError('Could not read the rendered net image!')
            if display:  # pragma: no cover
                _cv2.imshow(self.name, result_image)
                # waitKey(0) blocks until a key is pressed.
                _cv2.waitKey(0)
                _cv2.destroyWindow(self.name)
            return result_image
Exemplo n.º 2
0
    def visualize(self,
                  layout_dir='LR',
                  display=False):
        """
        Create and optionally display an image of the net structure.

        The net is serialized with ``self.to_pbuf_message()``, rendered to a
        temporary PNG with ``_draw.draw_net_to_file`` and read back with
        OpenCV.

        :param layout_dir: string in ['LR', 'TB', 'BT'].
          Short string for graph layout direction; passed as ``rankdir`` to
          the drawing function.

        :param display: bool.
          If set to ``True``, displays the graphic in a window titled
          ``self.name``. Press any key to close it.

        :returns: 3D numpy array.
          Graphic of the visualization as (H, W, C) image in BGR format.

        :raises Exception: if the drawing module or OpenCV is unavailable.
        :raises RuntimeError: if the rendered image cannot be read back.
        """
        if _draw is None or _cv2 is None:
            raise Exception('Drawing is not available!')
        with _NamedTemporaryFile(mode='w+b', suffix='.png') as tmpfile:
            _draw.draw_net_to_file(self.to_pbuf_message(),
                                   tmpfile.name,
                                   rankdir=layout_dir)
            result_image = _cv2.imread(tmpfile.name)
            # imread signals failure by returning None rather than raising.
            # An explicit check (unlike ``assert``) is not stripped when
            # Python runs with -O.
            if result_image is None:
                raise RuntimeError('Could not read the rendered net image!')
        if display:
            _cv2.imshow(self.name, result_image)
            # waitKey(0) blocks until a key is pressed.
            _cv2.waitKey(0)
            _cv2.destroyWindow(self.name)
        return result_image
Exemplo n.º 3
0
def manipulate_train(ori, target_train, **kwargs):
    """
    Rewrite a training prototxt, save it and log a drawing of the net.

    The prototxt is parsed into a ``caffe_pb2.NetParameter``, then passed
    through ``_add_dimension_reduction`` and ``_apply_mult_lr``. The result
    is written to ``target_train``. It is also drawn to a ``.jpg`` with the
    same base name, which is added to ``tb.sess`` as the ``'train_net'``
    image.

    :param ori: path of the source prototxt. Ignored when
      ``cfg.MODEL.DIFFERENT_DILATION.ENABLE`` is set; in that case
      ``models/train_different_dilation_template.prototxt`` is used.
    :param target_train: path the modified prototxt is written to.
    :param kwargs: accepted for interface compatibility; unused.
    :returns: None.
    """
    import os.path

    # The dilation template only replaces the source file; every later step
    # is the same for both sources.
    if cfg.MODEL.DIFFERENT_DILATION.ENABLE:
        ori = 'models/train_different_dilation_template.prototxt'

    train_pb = caffe_pb2.NetParameter()
    with open(ori, 'r') as f:
        txtf.Merge(f.read(), train_pb)
    train_pb = _add_dimension_reduction(train_pb)
    train_pb = _apply_mult_lr(train_pb)

    # splitext strips only a real extension of the last path component, so a
    # target without one (e.g. './out/train') maps to './out/train.jpg'.
    train_vis_file = os.path.splitext(target_train)[0] + '.jpg'
    with open(target_train, 'w') as f:
        f.write(str(train_pb))
    # 'LR' and 0 are presumably rankdir and phase of draw_net_to_file - confirm.
    draw_net_to_file(train_pb, train_vis_file, 'LR', 0)
    tb.sess.add_image('train_net', train_vis_file, wall_time=0, step=0)
Exemplo n.º 4
0
def print_network(prototxt_filename):
    '''
    Draw the ANN architecture described by a prototxt file.

    Parses ``prototxt_filename`` into a ``caffe.proto.caffe_pb2.NetParameter``.
    The net is rendered with ``draw.draw_net_to_file`` to
    ``prototxt_filename + '.png'``, and a completion message is printed.
    '''
    _net = caffe.proto.caffe_pb2.NetParameter()
    # The context manager closes the prototxt as soon as it has been parsed.
    with open(prototxt_filename) as f:
        google.protobuf.text_format.Merge(f.read(), _net)
    draw.draw_net_to_file(_net, prototxt_filename + '.png')
    print('Draw ANN done!')
Exemplo n.º 5
0
def draw_from_proto(fname, wname):
    """
    Render a Caffe net prototxt, or a directory tree of them, to images.

    If ``fname`` is a directory, ``wname`` is created with ``mkdir -p`` and
    each entry is processed recursively. Sub-directories map to
    sub-directories of the same name inside ``wname``; files map to
    ``<entry name>.png``.

    A file is drawn to ``wname`` only if it ends with ``.pt`` or
    ``.prototxt`` and its own name does not contain ``"solver"``. Other files
    are skipped with a printed message.

    :returns: None.
    """
    if os.path.isdir(fname):
        sh.mkdir("-p", wname)
        for base in os.listdir(fname):
            sub_fname = os.path.join(fname, base)
            sub_wname = base if os.path.isdir(sub_fname) else base + ".png"
            draw_from_proto(sub_fname, os.path.join(wname, sub_wname))
        return

    if not fname.endswith((".pt", ".prototxt")):
        print("Ignore file %s according to its suffix, not `pt` or `prototxt`" % fname)
        return
    # Check only the file's own name, so nets stored under a directory whose
    # name contains "solver" are still drawn.
    if "solver" in os.path.basename(fname):
        print("Regard file %s as a solver prototxt, ignore!" % fname)
        return
    net_param = caffe_pb2.NetParameter()
    with open(fname, "r") as f:
        text_format.Merge(f.read(), net_param)
    draw.draw_net_to_file(net_param, wname)
Exemplo n.º 6
0
def manipulate_test(ori, target_test, **kwargs):  # TODO more elegant editing
    """
    Rewrite a test prototxt, save it and log a drawing of the net.

    The prototxt is parsed into a ``caffe_pb2.NetParameter`` and passed
    through ``_add_dimension_reduction``. The result is written to
    ``target_test``. It is also drawn to a ``.jpg`` with the same base name,
    which is added to ``tb.sess`` as the ``'test_net'`` image.

    :param ori: path of the source prototxt. Ignored when
      ``cfg.MODEL.DIFFERENT_DILATION.ENABLE`` is set; in that case
      ``models/test_different_dilation_template.prototxt`` is used.
    :param target_test: path the modified prototxt is written to.
    :param kwargs: accepted for interface compatibility; unused.
    :returns: None.
    """
    import os.path

    # The dilation template only replaces the source file; every later step
    # is the same for both sources.
    if cfg.MODEL.DIFFERENT_DILATION.ENABLE:
        ori = 'models/test_different_dilation_template.prototxt'

    test_pb = caffe_pb2.NetParameter()
    with open(ori, 'r') as f:
        txtf.Merge(f.read(), test_pb)
    test_pb = _add_dimension_reduction(test_pb)

    # splitext strips only a real extension of the last path component, so a
    # target without one (e.g. './out/test') maps to './out/test.jpg'.
    test_vis_file = os.path.splitext(target_test)[0] + '.jpg'
    with open(target_test, 'w') as f:
        f.write(str(test_pb))
    # 'LR' and 1 are presumably rankdir and phase of draw_net_to_file - confirm.
    draw_net_to_file(test_pb, test_vis_file, 'LR', 1)
    tb.sess.add_image('test_net', test_vis_file, wall_time=0, step=0)
Exemplo n.º 7
0
def save_cnn_graph(path_model, name_model, phase, phase_name):
    """
    Draw the net defined in a prototxt to a PNG in ``const.PATH_TO_IMAGE_DIR``.

    The image is written to
    ``<PATH_TO_IMAGE_DIR>/network_graph_<name_model>_<phase_name>.png`` with
    layout direction ``"BT"``.

    :param path_model: path of the net prototxt to parse.
    :param name_model: model name, used in the output file name.
    :param phase: forwarded as the fourth argument of ``draw_net_to_file``
      (presumably the Caffe phase to draw - confirm).
    :param phase_name: label for the phase, used in the output file name.
    """
    net_parameter = caffe_pb2.NetParameter()
    # The context manager closes the prototxt as soon as it has been parsed.
    with open(path_model) as f:
        text_format.Merge(f.read(), net_parameter)
    image_path = (const.PATH_TO_IMAGE_DIR + "/network_graph_" + name_model
                  + "_" + phase_name + ".png")
    draw_net_to_file(net_parameter, image_path, "BT", phase)