Ejemplo n.º 1
0
def neuron_groups(model, img, layer, n_groups=6, attr_classes=None,
                  reduction_alg="PCA"):
    """Factorize a layer's activations on ``img`` into groups and render them.

    The activations of ``layer`` for ``img`` are factorized with
    ``ChannelReducer`` into ``n_groups`` spatial/channel factor pairs. The
    groups are ordered by the horizontal position of their spatial peak, and a
    feature-visualization icon is rendered per group. The result is passed to
    ``generate_html('neuron_groups', data)``.

    Args:
        model: model passed to ``render.import_model`` and ``render.render_vis``.
        img: input image, fed through a ``[None, None, 3]`` placeholder.
        layer: name of the layer whose activations are factorized.
        n_groups: number of factorization components (groups) to extract.
        attr_classes: classes for which group attributions are computed.
            ``None`` (the default) means no classes. The attributions are
            computed but not included in the rendered data.
        reduction_alg: algorithm name passed to ``ChannelReducer``. It defaults
            to ``"PCA"``, the value this function always used. ``"NMF"`` gives
            non-negative factors.
    """
    # A fresh list per call; a mutable default would be shared across calls.
    if attr_classes is None:
        attr_classes = []

    # Compute activations
    with tf.Graph().as_default(), tf.Session():
        t_input = tf.placeholder_with_default(img, [None, None, 3])
        T = render.import_model(model, t_input, t_input)
        acts = T(layer).eval()

    # ChannelReducer (a wrapper around scikit-learn's factorization tools)
    # factorizes the activations with the algorithm named by ``reduction_alg``.
    reducer = ChannelReducer(n_groups, reduction_alg)
    print(layer, n_groups)
    spatial_factors = reducer.fit_transform(acts)[0].transpose(2, 0, 1).astype("float32")
    channel_factors = reducer._reducer.components_.astype("float32")

    # Organize the groups based on the horizontal position of their peak in the image.
    x_peak = np.argmax(spatial_factors.max(1), 1)
    ns_sorted = np.argsort(x_peak)
    spatial_factors = spatial_factors[ns_sorted]
    channel_factors = channel_factors[ns_sorted]

    # Create a feature visualization of each group.
    param_f = lambda: param.image(80, batch=n_groups)
    obj = sum(objectives.direction(layer, channel_factors[i], batch=i)
              for i in range(n_groups))
    group_icons = render.render_vis(model, obj, param_f, verbose=False)[-1]

    # Attribution: first, turn each group into a vector over activations.
    group_vecs = [spatial_factors[i, ..., None] * channel_factors[i]
                  for i in range(n_groups)]

    # NOTE(review): ``attrs`` is computed but not used in ``data`` below.
    attrs = np.asarray([raw_class_group_attr(img, layer, attr_class, model, group_vecs)
                        for attr_class in attr_classes])

    gray_scale_groups = [skimage.color.rgb2gray(icon) for icon in group_icons]

    # The 99th percentile over all factors is the same for every factor, so
    # compute it once instead of once per factor.
    scale = np.percentile(spatial_factors, 99)

    # Render the visualization.
    data = {
        "img": _image_url(img),
        "n_groups": n_groups,
        "spatial_factors": [_image_url(factor[..., None] / scale * [1, 0, 0])
                            for factor in spatial_factors],
        "group_icons": [_image_url(icon) for icon in gray_scale_groups],
    }

    generate_html('neuron_groups', data)
Ejemplo n.º 2
0
def spatial_spatial_attr(model,
                         img,
                         layer1,
                         layer2,
                         hint_label_1=None,
                         hint_label_2=None,
                         override=None):
    """Render spatial-to-spatial attribution between ``layer1`` and ``layer2``.

    An orange/blue hint image is built for each layer. It contrasts the
    spatial attribution of ``hint_label_1`` with that of ``hint_label_2``.
    The layer1-to-layer2 attribution tensor is divided by its maximum, and
    everything is passed to ``generate_html('spatial_attr', ...)``.
    """

    def layer_hint(layer):
        # Contrast the two labels' spatial attribution maps at one layer.
        first = raw_class_spatial_attr(model, img, layer, hint_label_1,
                                       override=override)
        second = raw_class_spatial_attr(model, img, layer, hint_label_2,
                                        override=override)
        return orange_blue(first, second, clip=True)

    hint1 = layer_hint(layer1)
    hint2 = layer_hint(layer2)

    raw_attrs = raw_spatial_spatial_attr(model, img, layer1, layer2,
                                         override=override)
    attrs = raw_attrs / raw_attrs.max()

    data = {}
    data["spritemap1"] = image_url_grid(attrs)
    data["spritemap2"] = image_url_grid(attrs.transpose(2, 3, 0, 1))
    data["size1"] = attrs.shape[3]
    data["layer1"] = layer1
    data["size2"] = attrs.shape[0]
    data["layer2"] = layer2
    data["img"] = _image_url(img)
    data["hint1"] = _image_url(hint1)
    data["hint2"] = _image_url(hint2)
    generate_html('spatial_attr', data)
Ejemplo n.º 3
0
def show_image(image):
    """Write a single 100-wide ``<img>`` tag for ``image`` to img.html, then display it."""
    url = _image_url(image)
    tag = '<img width="100" style="margin: 10px" src="' + url + '">'
    with open("img.html", "w") as handle:
        handle.write(tag)
    _display_html(tag)
Ejemplo n.º 4
0
def image_url_grid(grid):
    """Map ``_image_url`` over a 2-D grid of images, returning a list of row lists."""
    return [list(map(_image_url, row)) for row in grid]
Ejemplo n.º 5
0
 def show_images(self, images):
     """Display all ``images`` in one HTML snippet, one 100-wide ``<img>`` tag each."""
     tags = ['<img width="100" style="margin: 10px" src="' + _image_url(picture) + '">'
             for picture in images]
     _display_html("".join(tags))
                text.append(line)
        else:
            if "<image" in line and "image/png" in line and (
                    len(line) > 10000 or
                (key != "images/rot-features" and len(line) > 10000)):
                image_str = line.split("xlink:href=\"")[1].split("\"")[0]
                image_content = image_str[len("data:image/png;base64,"):]
                image_hash = hex(hash(image_content))[4:20]
                with open(f"public/generated_images/{image_hash}.png",
                          "wb") as f:
                    f.write(base64.b64decode(image_content))
                line = line.replace(image_str,
                                    f"generated_images/{image_hash}.png")
                line = line.replace("/>", " style='image-rendering: auto;' />")
                print(len(image_content) // 1024, end=", ")
            elif "<image" in line:
                image_str = line.split("xlink:href=\"")[1].split("\"")[0]
                arr = png_url2im(image_str)
                if arr.shape[0] == 5 and arr.shape[1] == 5:
                    print(arr.shape)
                    arr = np.repeat(arr, 4, axis=0)
                    arr = np.repeat(arr, 4, axis=1)
                    new_image_str = _image_url(arr)
                    line = line.replace(image_str, new_image_str)
            text.append(line)
    figure_html[key] = "\n".join(text)

# Fill the page template with the generated figure markup and write the site index.
# Context managers close (and flush) both files deterministically instead of
# relying on garbage collection of the anonymous file objects.
with open("index_template.html", "r") as template_file:
    index_template = template_file.read()
index_html = index_template.format(**figure_html)
with open("public/index.html", "w") as index_file:
    index_file.write(index_html)
# End the line left open by the earlier print(..., end=", ") size log.
print("")
def spatial_spatial_attr(imglist,
                         filenamelist,
                         layer1,
                         layer2,
                         hint_label_1=None,
                         hint_label_2=None,
                         override=None):
    """Write an HTML page with one spatial-attribution widget per image.

    For each image, an orange/blue hint contrasting ``hint_label_1`` and
    ``hint_label_2`` is built for both layers. The layer1-to-layer2
    attribution tensor is divided by its maximum. A ``GroupWidget_1cb0e0d``
    widget with this data is then written to
    ``result/<concatenated filenamelist>.html``.

    Args:
        imglist: images to process.
        filenamelist: strings concatenated, in order, to form the output file name.
        layer1, layer2: names of the two layers being related.
        hint_label_1, hint_label_2: labels contrasted in the hint images.
        override: forwarded unchanged to the attribution helpers.
    """
    import json

    filename = ''.join(filenamelist)
    # 'w' rather than 'a': the file is a complete HTML document, so a re-run
    # must replace it instead of appending a second document.
    with open('result/' + filename + '.html', 'w') as out:
        out.write('''<!DOCTYPE html>
                          <html>
                          <head >
                            <title>%s</title>
                                <script src='GroupWidget_1cb0e0d.js'></script>
                          </head>
                          <body>''' % (filename))
        for key, img in enumerate(imglist):
            hint1 = orange_blue(raw_class_spatial_attr(img,
                                                       layer1,
                                                       hint_label_1,
                                                       override=override),
                                raw_class_spatial_attr(img,
                                                       layer1,
                                                       hint_label_2,
                                                       override=override),
                                clip=True)
            hint2 = orange_blue(raw_class_spatial_attr(img,
                                                       layer2,
                                                       hint_label_1,
                                                       override=override),
                                raw_class_spatial_attr(img,
                                                       layer2,
                                                       hint_label_2,
                                                       override=override),
                                clip=True)

            attrs = raw_spatial_spatial_attr(img,
                                             layer1,
                                             layer2,
                                             override=override)
            # NOTE(review): an all-zero attribution tensor makes this divide
            # by zero (NaNs in the sprites); confirm that cannot happen.
            attrs = attrs / attrs.max()

            widget_data = {
                "layer2": layer2,
                "layer1": layer1,
                "spritemap1": image_url_grid(attrs),
                "spritemap2": image_url_grid(attrs.transpose(2, 3, 0, 1)),
                "size1": attrs.shape[3],
                "size2": attrs.shape[0],
                "img": str(_image_url(img)),
                "hint1": str(_image_url(hint1)),
                "hint2": str(_image_url(hint2)),
            }

            out.write('''  <main%s></main%s>
                <script>
                  var app = new GroupWidget_1cb0e0d({
                    target: document.querySelector( 'main%s' ),''' %
                      (key, key, key))
            # json.dumps produces a properly escaped JS object literal, unlike
            # splicing Python str() reprs and raw layer names into the script.
            out.write('data: ' + json.dumps(widget_data) + ' });')
            out.write('</script>')

        out.write('''</body></html >''')
    print(filename)
def neuron_groups(imglist, filenamelist, layer, n_groups=6, attr_classes=None):
    """Write an HTML page with one neuron-group widget per image.

    For each image:
    - the activations of ``layer`` are factorized with NMF into ``n_groups``
      spatial/channel factor pairs;
    - the groups are ordered by the horizontal position of their spatial peak;
    - a feature-visualization icon is rendered per group;
    - a ``GroupWidget_1cb0e0d`` widget is written to
      ``result/<concatenated filenamelist>.html``.

    Args:
        imglist: images to process. Each is fed through a ``[None, None, 3]``
            placeholder.
        filenamelist: strings concatenated, in order, to form the output file name.
        layer: name of the layer whose activations are factorized.
        n_groups: number of NMF components (groups) per image.
        attr_classes: classes for which group attributions are computed and
            printed. ``None`` means no classes.

    NOTE(review): ``model`` is not a parameter. It is resolved from module
    scope, so confirm that it is defined there.
    """
    import json

    if attr_classes is None:
        attr_classes = []

    filename = ''.join(filenamelist)
    # 'w' rather than 'a': the file is a complete HTML document, so a re-run
    # must replace it instead of appending a second document.
    with open('result/' + filename + '.html', 'w') as out:
        out.write('''<!DOCTYPE html>
                        <html>
                        <head >
                          <title>%s</title>
                              <script src='GroupWidget_1cb0e0d.js'></script>
                        </head>
                        <body>''' % (filename))
        for key, img in enumerate(imglist):
            # Compute activations
            with tf.Graph().as_default(), tf.Session():
                t_input = tf.placeholder_with_default(img, [None, None, 3])
                T = render.import_model(model, t_input, t_input)
                acts = T(layer).eval()

            # ChannelReducer (a wrapper around scikit-learn's factorization
            # tools) applies Non-Negative Matrix Factorization (NMF).
            nmf = ChannelReducer(n_groups, "NMF")
            spatial_factors = nmf.fit_transform(acts)[0].transpose(
                2, 0, 1).astype("float32")
            channel_factors = nmf._reducer.components_.astype("float32")

            # Organize the groups based on the horizontal position of their
            # peak in the image.
            x_peak = np.argmax(spatial_factors.max(1), 1)
            ns_sorted = np.argsort(x_peak)
            spatial_factors = spatial_factors[ns_sorted]
            channel_factors = channel_factors[ns_sorted]

            # Create a feature visualization of each group.
            param_f = lambda: param.image(80, batch=n_groups)
            obj = sum(
                objectives.direction(layer, channel_factors[i], batch=i)
                for i in range(n_groups))
            group_icons = render.render_vis(model, obj, param_f,
                                            verbose=False)[-1]

            # Attribution: first, turn each group into a vector over activations.
            group_vecs = [
                spatial_factors[i, ..., None] * channel_factors[i]
                for i in range(n_groups)
            ]

            attrs = np.asarray([
                raw_class_group_attr(img, layer, attr_class, group_vecs)
                for attr_class in attr_classes
            ])

            # The attributions are only printed; they are not part of the widget.
            print(attrs)

            # The 99th percentile over all factors is the same for every
            # factor, so compute it once instead of once per factor.
            scale = np.percentile(spatial_factors, 99)
            widget_data = {
                "img": str(_image_url(img)),
                "n_groups": int(n_groups),
                "spatial_factors": [
                    _image_url(factor[..., None] / scale * [1, 0, 0])
                    for factor in spatial_factors
                ],
                "group_icons": [_image_url(icon) for icon in group_icons],
            }

            # Render the visualization.
            out.write('''  <main%s></main%s>
                          <script>
                            var app = new GroupWidget_1cb0e0d({
                              target: document.querySelector( 'main%s' ),''' %
                      (key, key, key))
            # json.dumps produces a properly escaped JS object literal, unlike
            # splicing Python str() reprs into the script.
            out.write('data: ' + json.dumps(widget_data) + ' });')
            out.write('</script>')

        out.write('''</body></html >''')
    print(filename)