Example #1
def _read_and_cache(url, mode="rb"):
    local_path = local_cache_path(url)
    lock = FileLock(local_path + ".lockfile")
    with lock:
        if os.path.exists(local_path):
            log.debug("Found cached file '%s'.", local_path)
            return _handle_gfile(local_path)
        log.debug("Caching URL '%s' locally at '%s'.", url, local_path)
        try:
            with write_handle(local_path, "wb") as output_handle, read_handle(
                url, cache=False, mode="rb"
            ) as input_handle:
                for chunk in _file_chunk_iterator(input_handle):
                    output_handle.write(chunk)
            gc.collect()
            return _handle_gfile(local_path, mode=mode)
        except tf.errors.NotFoundError:
            raise
        except Exception as e:  # any other failure: log it, remove the partial cache file, re-raise
            log.warning("Caching (%s -> %s) failed: %s", url, local_path, e)
            try:
                os.remove(local_path)
            except OSError:
                pass
            raise
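A minimal usage sketch (the URL is hypothetical; local_cache_path, read_handle, write_handle, _handle_gfile and FileLock are assumed to come from the surrounding module and its dependencies):

handle = _read_and_cache("gs://some-bucket/model.pb")  # hypothetical URL
data = handle.read()  # later calls for the same URL are served from the local cache
handle.close()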
Example #2
def save(thing, url_or_handle, **kwargs):
    """Save object to file on CNS.

    File format is inferred from path. Use save_img(), save_npy(), or save_json()
    if you need to force a particular format.

    Args:
      thing: object to save.
      url_or_handle: CNS path or open file handle to save to.

    Raises:
      RuntimeError: If no file extension can be inferred from the path.
      ValueError: If the file extension is not supported.
    """
    is_handle = hasattr(url_or_handle, "write") and hasattr(
        url_or_handle, "name")
    if is_handle:
        _, ext = os.path.splitext(url_or_handle.name)
    else:
        _, ext = os.path.splitext(url_or_handle)
    if not ext:
        raise RuntimeError("No extension in URL: " + url_or_handle)

    if ext in savers:
        saver = savers[ext]
        if is_handle:
            saver(thing, url_or_handle, **kwargs)
        else:
            with write_handle(url_or_handle) as handle:
                saver(thing, handle, **kwargs)
    else:
        saver_names = [(key, fn.__name__) for (key, fn) in savers.items()]
        message = "Unknown extension '{}', supports {}."
        raise ValueError(message.format(ext, saver_names))
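A short usage sketch, assuming savers for .npy and .json are registered in `savers` (the paths are hypothetical):

import numpy as np

save(np.zeros((4, 4)), "gs://some-bucket/array.npy")       # dispatched by extension
save({"accuracy": 0.93}, "gs://some-bucket/metrics.json")  # same entry point, JSON saver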
Example #3
def save_joblib(value, url_or_handle, **kwargs):
    import joblib

    if hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name"):
        joblib.dump(value, url_or_handle, **kwargs)
    else:
        with write_handle(url_or_handle) as handle:
            joblib.dump(value, handle, **kwargs)
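Usage sketch (the path is hypothetical; joblib must be installed, and the value can be any picklable object):

state = {"weights": [0.1, 0.2, 0.3], "step": 42}
save_joblib(state, "checkpoints/state.joblib")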
Example #4
def test_write_handle_binary():
    path = "./tests/fixtures/bytes"

    with write_handle(path) as handle:
        handle.write(random_bytes)
    content = io.open(path, "rb").read()

    assert os.path.isfile(path)
    assert content == random_bytes
Example #5
def test_write_handle_text():
    text = u"The quick brown 🦊 jumps over the lazy dog"
    path = "./tests/fixtures/string2.txt"

    with write_handle(path, mode="w") as handle:
        handle.write(text)
    content = io.open(path, "rt").read()

    assert os.path.isfile(path)
    assert content == text
Example #6
def save_img(object, url, **kwargs):
    """Save numpy array as image file on CNS."""
    if isinstance(object, np.ndarray):
        normalized = _normalize_array(object)
        image = PIL.Image.fromarray(normalized)
    elif isinstance(object, PIL.Image.Image):
        image = object
    else:
        raise ValueError("Can only save_img for numpy arrays or PIL.Images!")

    with write_handle(url) as handle:
        image.save(handle,
                   **kwargs)  # will infer format from handle's url ext.
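Usage sketch, assuming _normalize_array maps float arrays into an image-friendly range (the output path is hypothetical):

import numpy as np

save_img(np.random.rand(64, 64, 3), "out/noise.png")  # format inferred from the .png extension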
Example #7
def _read_and_cache(url, mode='rb'):
    local_path = local_cache_path(url)
    if os.path.exists(local_path):
        log.info("Found cached file '%s'.", local_path)
        return _handle_gfile(local_path)
    else:
        log.info("Caching URL '%s' locally at '%s'.", url, local_path)
        with write_handle(local_path, 'wb') as output, \
                read_handle(url, cache=False, mode='rb') as input:
            for chunk in _file_chunk_iterator(input):
                output.write(chunk)
        gc.collect()
        return _handle_gfile(local_path, mode=mode)
Example #8
def _read_and_cache(url):
    local_name = RESERVED_PATH_CHARS.sub('_', url)
    local_path = os.path.join(gettempdir(), local_name)
    if os.path.exists(local_path):
        log.info("Found cached file '%s'.", local_path)
        return _handle_gfile(local_path)
    else:
        log.info("Caching URL '%s' locally at '%s'.", url, local_path)
        with write_handle(local_path, 'wb') as output, \
                read_handle(url, cache=False) as input:
            for chunk in _file_chunk_iterator(input):
                output.write(chunk)
        gc.collect()
        return _handle_gfile(local_path)
Example #9
def save(thing, url_or_handle, **kwargs):
    """Save object to file on CNS.

    File format is inferred from path. Use save_img(), save_npy(), or save_json()
    if you need to force a particular format.

    Args:
      thing: object to save.
      url_or_handle: CNS path or open file handle to save to.

    Raises:
      RuntimeError: If no file extension can be inferred from the path.
      ValueError: If the file extension is not supported and the object is not a string.
    """
    # Determine context
    # Is this a handle? What is the extension? Are we saving to GCS?
    is_handle = hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name")
    if is_handle:
        path = url_or_handle.name
    else:
        path = url_or_handle

    _, ext = os.path.splitext(path)
    is_gcs = path.startswith("gs://")

    if not ext:
        raise RuntimeError("No extension in URL: " + path)

    # Determine which saver should be used
    if ext in savers:
        saver = savers[ext]
    elif isinstance(thing, str):
        saver = save_str
    else:
        message = "Unknown extension '{}'. As a result, only strings can be saved, not {}. Supported extensions: {}"
        raise ValueError(
            message.format(ext, type(thing).__name__, list(savers.keys())))

    # Actually save
    if is_handle:
        saver(thing, url_or_handle, **kwargs)
    else:
        with write_handle(url_or_handle) as handle:
            saver(thing, handle, **kwargs)

    # Set MIME type on GCS if saving HTML -- usually, when one saves an HTML
    # file to GCS, they want it to be viewable as a website.
    if is_gcs and ext == ".html":
        subprocess.run(["gsutil", "setmeta", "-h", "Content-Type:text/html", path])
Example #10
def compile_html(input_path,
                 html_path=None,
                 *,
                 props=None,
                 precision=None,
                 title=None,
                 div_id=None,
                 inline_js=None,
                 svelte_to_js=None,
                 js_path=None,
                 js_name=None,
                 js_lint=None):
    """Compile Svelte or JavaScript to HTML.

    Arguments:
        input_path:   path to input Svelte or JavaScript file
        html_path:    path to output HTML file
                      defaults to input_path with a new .html suffix
        props:        JSON-serializable object to pass to Svelte script
                      defaults to an empty object
        precision:    number of significant figures to round numpy arrays to
                      defaults to no rounding
        title:        title of HTML page
                      defaults to html_path filename without suffix
        div_id:       HTML id of div containing Svelte component
                      defaults to _default_div_id
        inline_js:    whether to insert the JavaScript into the HTML page inline
                      defaults to svelte_to_js
        svelte_to_js: whether to first compile from Svelte to JavaScript
                      defaults to whether input_path doesn't have a .js suffix
        js_path:      path to output JavaScript file if compiling from Svelte
                      and not inserting the JavaScript inline
                      defaults to compile_js default
        js_name:      name of JavaScript global variable
                      should match existing name if compiling from JavaScript
                      defaults to _default_js_name
        js_lint:      whether to use eslint if compiling from Svelte
                      defaults to compile_js default
    """
    if html_path is None:
        html_path = replace_file_extension(input_path, ".html")
    if props is None:
        props = {}
    if title is None:
        title = os.path.basename(html_path).rsplit(".", 1)[0]
    if div_id is None:
        div_id = _default_div_id
    if svelte_to_js is None:
        svelte_to_js = not input_path.endswith(".js")
    if inline_js is None:
        inline_js = svelte_to_js

    if svelte_to_js:
        if inline_js:
            if js_path is None:
                js_path = replace_file_extension(input_path, ".js")
            prefix = "svelte_" + os.path.basename(js_path)
            if prefix.endswith(".js"):
                prefix = prefix[:-3]
            _, js_path = tempfile.mkstemp(suffix=".js",
                                          prefix=prefix + "_",
                                          dir=_temp_config_dir,
                                          text=True)
        try:
            compile_js_result = compile_js(input_path,
                                           js_path,
                                           js_name=js_name,
                                           js_lint=js_lint)
        except CompileError as exn:
            raise CompileError(
                "Unable to compile Svelte source.\n"
                "See the above advice or try supplying pre-compiled JavaScript."
            ) from exn
        js_path = compile_js_result["js_path"]
        js_name = compile_js_result["js_name"]
        command_output = compile_js_result["command_output"]
    else:
        js_path = input_path
        if js_name is None:
            js_name = _default_js_name
        command_output = None

    if inline_js:
        with read_handle(js_path, cache=False, mode="r") as js_file:
            js_code = js_file.read().rstrip("\n")
            js_html = "<script>\n" + js_code + "\n  </script>"
        js_path = None
    else:
        js_relpath = os.path.relpath(js_path, start=os.path.dirname(html_path))
        js_html = '<script src="' + js_relpath + '"></script>'

    with write_handle(html_path, "w") as html_file:
        html_file.write("""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>""" + title + '''</title>
</head>
<body>
  <div id="''' + div_id + """"></div>
  """ + js_html + """
  <script>
  var app = new """ + js_name + """({
    target: document.querySelector("#""" + div_id + """"),
    props: """ + json.dumps(props, cls=encoder(precision=precision)) + """
  });
  </script>
</body>
</html>""")
    return {
        "html_path": html_path,
        "js_path": js_path if svelte_to_js else None,
        "title": title,
        "div_id": div_id,
        "js_name": js_name,
        "command_output": command_output,
    }
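A hypothetical invocation, compiling a Svelte component into a standalone HTML page with inlined JavaScript (the input path and props are placeholders):

result = compile_html(
    "components/Interface.svelte",                 # hypothetical input path
    props={"title": "Demo", "values": [1, 2, 3]},
    precision=4,
)
print(result["html_path"], result["js_name"])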
Example #11
def save_npy(object, url):
    """Save numpy array as npy file."""
    with write_handle(url, "w") as handle:
        np.save(handle, object)
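Usage sketch (the path is hypothetical):

import numpy as np

save_npy(np.arange(10), "arrays/range.npy")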
Example #12
def generate(
    *,
    output_dir,
    model_bytes,
    observations,
    observations_full=None,
    trajectories,
    policy_logits_name,
    value_function_name,
    env_name=None,
    numpy_precision=6,
    inline_js=True,
    inline_large_json=None,
    batch_size=512,
    action_combos=None,
    action_group_fns=[
        lambda combo: "RIGHT" in combo,
        lambda combo: "LEFT" in combo,
        lambda combo: "UP" in combo,
        lambda combo: "DOWN" in combo,
        lambda combo: "RIGHT" not in combo and "LEFT" not in combo and "UP"
        not in combo and "DOWN" not in combo,
    ],
    layer_kwargs={},
    input_layer_include=False,
    input_layer_name="input",
    gae_gamma=None,
    gae_lambda=None,
    trajectory_bookmarks=16,
    nmf_features=8,
    nmf_attr_opts=None,
    vis_subdiv_mults=[0.25, 0.5, 1, 2],
    vis_subdiv_mult_default=1,
    vis_expand_mults=[1, 2, 4, 8],
    vis_expand_mult_default=4,
    vis_thumbnail_num_mult=4,
    vis_thumbnail_expand_mult=4,
    scrub_range=(42 / 64, 44 / 64),
    attr_integrate_steps=10,
    attr_max_paths=None,
    attr_policy=False,
    attr_single_channels=True,
    observations_subdir="observations/",
    trajectories_subdir="trajectories/",
    trajectories_scrub_subdir="trajectories_scrub/",
    features_subdir="features/",
    thumbnails_subdir="thumbnails/",
    attribution_subdir="attribution/",
    attribution_scrub_subdir="attribution_scrub/",
    features_grids_subdir="features_grids/",
    attribution_totals_subdir="attribution_totals/",
    video_height="16em",
    video_width="16em",
    video_speed=12,
    policy_display_height="2em",
    policy_display_width="40em",
    navigator_width="24em",
    scrubber_height="4em",
    scrubber_width="48em",
    scrubber_visible_duration=256,
    legend_item_height="6em",
    legend_item_width="6em",
    feature_viewer_height="40em",
    feature_viewer_width="40em",
    attribution_weight=0.9,
    graph_colors={
        "v": "green",
        "action": "red",
        "action_group": "orange",
        "advantage": "blue",
    },
    trajectory_color="blue",
):
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    model = get_model(model_bytes)
    if rank == 0:
        js_source_path = get_compiled_js()

    if env_name is None:
        env_name = "unknown"
    if inline_large_json is None:
        inline_large_json = "://" not in output_dir
    layer_kwargs.setdefault("name_contains_one_of", None)
    layer_kwargs.setdefault("op_is_one_of", ["relu"])
    layer_kwargs.setdefault("bottleneck_only", True)
    layer_kwargs.setdefault("discard_first_n", 0)
    if observations_full is None:
        observations_full = observations
    if "observations_full" not in trajectories:
        trajectories["observations_full"] = trajectories["observations"]
    if np.issubdtype(observations.dtype, np.integer):
        observations = observations / np.float32(255)
    if np.issubdtype(observations_full.dtype, np.integer):
        observations_full = observations_full / np.float32(255)
    if np.issubdtype(trajectories["observations"].dtype, np.integer):
        trajectories[
            "observations"] = trajectories["observations"] / np.float32(255)
    if np.issubdtype(trajectories["observations_full"].dtype, np.integer):
        trajectories["observations_full"] = trajectories[
            "observations_full"] / np.float32(255)
    if action_combos is None:
        num_actions = get_shape(model, policy_logits_name)[-1]
        action_combos = list(map(lambda x: (str(x), ), range(num_actions)))
        if env_name == "coinrun_old":
            action_combos = [
                (),
                ("RIGHT", ),
                ("LEFT", ),
                ("UP", ),
                ("RIGHT", "UP"),
                ("LEFT", "UP"),
                ("DOWN", ),
                ("A", ),
                ("B", ),
            ][:num_actions]
    if gae_gamma is None:
        gae_gamma = 0.999
    if gae_lambda is None:
        gae_lambda = 0.95

    layer_names = get_layer_names(model,
                                  [policy_logits_name, value_function_name],
                                  **layer_kwargs)
    if not layer_names:
        raise RuntimeError("No appropriate layers found. "
                           "Please adapt layer_kwargs to your architecture")
    squash = lambda s: s.replace("/", "").replace("_", "")
    if len(set([squash(layer_key)
                for layer_key in layer_names.keys()])) < len(layer_names):
        raise RuntimeError("Error squashing abbreviated layer names. "
                           "Different substitutions must be used")
    mpi_enumerate = lambda l: (lambda indices: list(enumerate(l))[indices[
        rank]:indices[rank + 1]])(np.linspace(0, len(l),
                                              comm.Get_size() + 1).astype(int))
    save_image = lambda image, path: save(
        image, os.path.join(output_dir, path), domain=(0, 1))
    save_images = lambda images, path: save_image(
        concatenate_horizontally(images), path)
    json_preloaded = {}
    save_json = lambda data, path: (json_preloaded.update(
        {path: data}) if inline_large_json else save(
            data, os.path.join(output_dir, path), indent=None))
    get_scrub_slice = lambda width: slice(
        int(np.round(scrub_range[0] * width)),
        int(
            np.maximum(np.round(scrub_range[1] * width),
                       np.round(scrub_range[0] * width) + 1)),
    )
    action_groups = [[
        action for action, combo in enumerate(action_combos) if group_fn(combo)
    ] for group_fn in action_group_fns]
    action_groups = list(
        filter(lambda action_group: len(action_group) > 1, action_groups))

    for index, observation in mpi_enumerate(observations_full):
        observation_path = os.path.join(observations_subdir, f"{index}.png")
        save_image(observation, observation_path)
    for index, trajectory_observations in mpi_enumerate(
            trajectories["observations_full"]):
        trajectory_path = os.path.join(trajectories_subdir, f"{index}.png")
        save_images(trajectory_observations, trajectory_path)
        scrub_slice = get_scrub_slice(trajectory_observations.shape[2])
        scrub = trajectory_observations[:, :, scrub_slice, :]
        scrub_path = os.path.join(trajectories_scrub_subdir, f"{index}.png")
        save_images(scrub, scrub_path)

    trajectories["policy_logits"] = []
    trajectories["values"] = []
    for trajectory_observations in trajectories["observations"]:
        trajectories["policy_logits"].append(
            batched_get(
                trajectory_observations,
                batch_size,
                lambda minibatch: get_acts(model, policy_logits_name, minibatch
                                           ),
            ))
        trajectories["values"].append(
            batched_get(
                trajectory_observations,
                batch_size,
                lambda minibatch: get_acts(model, value_function_name,
                                           minibatch),
            ))
    trajectories["policy_logits"] = np.array(trajectories["policy_logits"])
    trajectories["values"] = np.array(trajectories["values"])
    trajectories["advantages"] = compute_gae(trajectories,
                                             gae_gamma=gae_gamma,
                                             gae_lambda=gae_lambda)
    if "dones" not in trajectories:
        trajectories["dones"] = np.concatenate(
            [
                trajectories["firsts"][:, 1:],
                np.zeros_like(trajectories["firsts"][:, :1]),
            ],
            axis=-1,
        )

    bookmarks = {
        "high": get_bookmarks(trajectories, sign=1, num=trajectory_bookmarks),
        "low": get_bookmarks(trajectories, sign=-1, num=trajectory_bookmarks),
    }

    nmf_kwargs = {"attr_layer_name": value_function_name}
    if nmf_attr_opts is not None:
        nmf_kwargs["attr_opts"] = nmf_attr_opts
    nmfs = {
        layer_key: LayerNMF(
            model,
            layer_name,
            observations,
            obses_full=observations_full,
            features=nmf_features,
            **nmf_kwargs,
        )
        for layer_key, layer_name in layer_names.items()
    }

    features = []
    attributions = []
    attribution_totals = []

    for layer_key, layer_name in layer_names.items():
        nmf = nmfs[layer_key]

        if rank == 0:
            thumbnails = []
            for number in range(nmf.features):
                thumbnail = nmf.vis_dataset_thumbnail(
                    number,
                    num_mult=vis_thumbnail_num_mult,
                    expand_mult=vis_thumbnail_expand_mult,
                )[0]
                thumbnail = rescale_opacity(thumbnail,
                                            max_scale=1,
                                            keep_zeros=True)
                thumbnails.append(thumbnail)
            thumbnails_path = os.path.join(thumbnails_subdir,
                                           f"{squash(layer_key)}.png")
            save_images(thumbnails, thumbnails_path)

        for _, number in mpi_enumerate(range(nmf.features)):
            feature = {
                "layer": layer_key,
                "number": number,
                "images": [],
                "overlay_grids": [],
                "metadata": {
                    "subdiv_mult": [],
                    "expand_mult": []
                },
            }
            for subdiv_mult in vis_subdiv_mults:
                for expand_mult in vis_expand_mults:
                    image, overlay_grid = nmf.vis_dataset(
                        number,
                        subdiv_mult=subdiv_mult,
                        expand_mult=expand_mult)
                    image = rescale_opacity(image)
                    filename_root = (f"{squash(layer_key)}_"
                                     f"feature{number}_"
                                     f"{number_to_string(subdiv_mult)}_"
                                     f"{number_to_string(expand_mult)}")
                    image_filename = filename_root + ".png"
                    overlay_grid_filename = filename_root + ".json"
                    image_path = os.path.join(features_subdir, image_filename)
                    overlay_grid_path = os.path.join(features_grids_subdir,
                                                     overlay_grid_filename)
                    save_image(image, image_path)
                    save_json(overlay_grid, overlay_grid_path)
                    feature["images"].append(image_filename)
                    feature["overlay_grids"].append(overlay_grid_filename)
                    feature["metadata"]["subdiv_mult"].append(subdiv_mult)
                    feature["metadata"]["expand_mult"].append(expand_mult)
            features.append(feature)

    for layer_key, layer_name in ([
        (input_layer_name, None)
    ] if input_layer_include else []) + list(layer_names.items()):
        if layer_name is None:
            nmf = None
        else:
            nmf = nmfs[layer_key]

        for index, trajectory_observations in mpi_enumerate(
                trajectories["observations"]):
            attribution = {
                "layer": layer_key,
                "trajectory": index,
                "images": [],
                "metadata": {
                    "type": [],
                    "data": [],
                    "direction": [],
                    "channel": []
                },
            }
            if layer_name is not None:
                totals = {
                    "layer": layer_key,
                    "trajectory": index,
                    "channels": [],
                    "residuals": [],
                    "metadata": {
                        "type": [],
                        "data": []
                    },
                }

            def get_attr_minibatch(minibatch,
                                   output_name,
                                   *,
                                   score_fn=default_score_fn):
                if layer_name is None:
                    return get_grad(model,
                                    output_name,
                                    minibatch,
                                    score_fn=score_fn)
                elif attr_max_paths is None:
                    return get_attr(
                        model,
                        output_name,
                        layer_name,
                        minibatch,
                        score_fn=score_fn,
                        integrate_steps=attr_integrate_steps,
                    )
                else:
                    return get_multi_path_attr(
                        model,
                        output_name,
                        layer_name,
                        minibatch,
                        nmf,
                        score_fn=score_fn,
                        integrate_steps=attr_integrate_steps,
                        max_paths=attr_max_paths,
                    )

            def get_attr_batched(output_name, *, score_fn=default_score_fn):
                return batched_get(
                    trajectory_observations,
                    batch_size,
                    lambda minibatch: get_attr_minibatch(
                        minibatch, output_name, score_fn=score_fn),
                )

            def transform_attr(attr):
                if layer_name is None:
                    return attr, None
                else:
                    attr_trans = nmf.transform(np.maximum(
                        attr, 0)) - nmf.transform(np.maximum(-attr, 0))
                    attr_res = (
                        attr -
                        (nmf.inverse_transform(np.maximum(attr_trans, 0)) -
                         nmf.inverse_transform(np.maximum(-attr_trans, 0)))
                    ).sum(-1, keepdims=True)
                    nmf_norms = nmf.channel_dirs.sum(-1)
                    return attr_trans * nmf_norms[None, None, None], attr_res

            def save_attr(attr, attr_res, *, type_, data):
                if attr_res is None:
                    attr_res = np.zeros_like(attr).sum(-1, keepdims=True)
                filename_root = f"{squash(layer_key)}_{index}_{type_}"
                if data is not None:
                    filename_root = f"{filename_root}_{data}"
                if layer_name is not None:
                    channels_filename = f"{filename_root}_channels.json"
                    residuals_filename = f"{filename_root}_residuals.json"
                    channels_path = os.path.join(attribution_totals_subdir,
                                                 channels_filename)
                    residuals_path = os.path.join(attribution_totals_subdir,
                                                  residuals_filename)
                    save_json(attr.sum(-2).sum(-2), channels_path)
                    save_json(attr_res[..., 0].sum(-1).sum(-1), residuals_path)
                    totals["channels"].append(channels_filename)
                    totals["residuals"].append(residuals_filename)
                    totals["metadata"]["type"].append(type_)
                    totals["metadata"]["data"].append(data)
                attr_scale = np.median(attr.max(axis=(-3, -2, -1)))
                if attr_scale == 0:
                    attr_scale = attr.max()
                if attr_scale == 0:
                    attr_scale = 1
                attr_scaled = attr / attr_scale
                attr_res_scaled = attr_res / attr_scale
                channels = ["prin", "all"]
                if attr_single_channels and layer_name is not None:
                    channels += list(range(nmf.features)) + ["res"]
                for direction in ["abs", "pos", "neg"]:
                    if direction == "abs":
                        attr = np.abs(attr_scaled)
                        attr_res = np.abs(attr_res_scaled)
                    elif direction == "pos":
                        attr = np.maximum(attr_scaled, 0)
                        attr_res = np.maximum(attr_res_scaled, 0)
                    elif direction == "neg":
                        attr = np.maximum(-attr_scaled, 0)
                        attr_res = np.maximum(-attr_res_scaled, 0)
                    for channel in channels:
                        if isinstance(channel, int):
                            attr_single = attr.copy()
                            attr_single[..., :channel] = 0
                            attr_single[..., (channel + 1):] = 0
                            images = channels_to_rgb(attr_single)
                        elif channel == "res":
                            images = attr_res.repeat(3, axis=-1)
                        else:
                            images = channels_to_rgb(attr)
                            if channel == "all":
                                images += attr_res.repeat(3, axis=-1)
                        images = brightness_to_opacity(
                            conv2d(images, filter_=norm_filter(15)))
                        suffix = f"{direction}_{channel}"
                        images_filename = f"{filename_root}_{suffix}.png"
                        images_path = os.path.join(attribution_subdir,
                                                   images_filename)
                        save_images(images, images_path)
                        scrub = images[:, :,
                                       get_scrub_slice(images.shape[2]), :]
                        scrub_path = os.path.join(attribution_scrub_subdir,
                                                  images_filename)
                        save_images(scrub, scrub_path)
                        attribution["images"].append(images_filename)
                        attribution["metadata"]["type"].append(type_)
                        attribution["metadata"]["data"].append(data)
                        attribution["metadata"]["direction"].append(direction)
                        attribution["metadata"]["channel"].append(channel)

            attr_v = get_attr_batched(value_function_name)
            attr_v_trans, attr_v_res = transform_attr(attr_v)
            save_attr(attr_v_trans, attr_v_res, type_="v", data=None)
            if attr_policy:
                attr_actions = np.array([
                    get_attr_batched(
                        policy_logits_name,
                        score_fn=lambda t: t[..., action],
                    ) for action in range(len(action_combos))
                ])
                # attr_pi = attr_actions.sum(axis=-1).transpose(
                #     (1, 2, 3, 0))
                # attr_pi = np.concatenate([
                #     attr_pi[..., group].sum(axis=-1, keepdims=True)
                #     for group in attr_action_groups
                # ],
                #                          axis=-1)
                # save_attr(attr_pi, None, type_='pi', data=None)
                for action, attr in enumerate(attr_actions):
                    attr_trans, attr_res = transform_attr(attr)
                    save_attr(attr_trans,
                              attr_res,
                              type_="action",
                              data=action)
                for action_group, actions in enumerate(action_groups):
                    attr = attr_actions[actions].sum(axis=0)
                    attr_trans, attr_res = transform_attr(attr)
                    save_attr(attr_trans,
                              attr_res,
                              type_="action_group",
                              data=action_group)
            attributions.append(attribution)
            if layer_name is not None:
                attribution_totals.append(totals)

    features = comm.gather(features, root=0)
    attributions = comm.gather(attributions, root=0)
    attribution_totals = comm.gather(attribution_totals, root=0)

    if rank == 0:
        features = [feature for l in features for feature in l]
        attributions = [attribution for l in attributions for attribution in l]
        attribution_totals = [
            totals for l in attribution_totals for totals in l
        ]
        layer_keys = ([input_layer_name] if input_layer_include else
                      []) + list(layer_names.keys())
        action_colors = get_html_colors(
            len(action_combos),
            grayscale=True,
            mix_with=np.array([0.75, 0.75, 0.75]),
            mix_weight=0.25,
        )
        props = {
            "input_layer": input_layer_name,
            "layers": layer_keys,
            "features": features,
            "attributions": attributions,
            "attribution_policy": attr_policy,
            "attribution_single_channels": attr_single_channels,
            "attribution_totals": attribution_totals,
            "colors": {
                "features": get_html_colors(nmf_features),
                "actions": action_colors,
                "graphs": graph_colors,
                "trajectory": trajectory_color,
            },
            "action_combos": action_combos,
            "action_groups": action_groups,
            "trajectories": {
                "actions": trajectories["actions"],
                "rewards": trajectories["rewards"],
                "dones": trajectories["dones"],
                "policy_logits": trajectories["policy_logits"],
                "values": trajectories["values"],
                "advantages": trajectories["advantages"],
            },
            "bookmarks": bookmarks,
            "vis_defaults": {
                "subdiv_mult": vis_subdiv_mult_default,
                "expand_mult": vis_expand_mult_default,
            },
            "subdirs": {
                "observations": observations_subdir,
                "trajectories": trajectories_subdir,
                "trajectories_scrub": trajectories_scrub_subdir,
                "features": features_subdir,
                "thumbnails": thumbnails_subdir,
                "attribution": attribution_subdir,
                "attribution_scrub": attribution_scrub_subdir,
                "features_grids": features_grids_subdir,
                "attribution_totals": attribution_totals_subdir,
            },
            "formatting": {
                "video_height": video_height,
                "video_width": video_width,
                "video_speed": video_speed,
                "policy_display_height": policy_display_height,
                "policy_display_width": policy_display_width,
                "navigator_width": navigator_width,
                "scrubber_height": scrubber_height,
                "scrubber_width": scrubber_width,
                "scrubber_visible_duration": scrubber_visible_duration,
                "legend_item_height": legend_item_height,
                "legend_item_width": legend_item_width,
                "feature_viewer_height": feature_viewer_height,
                "feature_viewer_width": feature_viewer_width,
                "attribution_weight": attribution_weight,
            },
            "json_preloaded": json_preloaded,
        }

        if inline_js:
            js_path = js_source_path
        else:
            with open(js_source_path, "r") as fp:
                js_code = fp.read()
            js_path = os.path.join(output_dir, "interface.js")
            with write_handle(js_path, "w") as fp:
                fp.write(js_code)
        html_path = os.path.join(output_dir, "interface.html")
        compile_html(
            js_path,
            html_path=html_path,
            props=props,
            precision=numpy_precision,
            inline_js=inline_js,
            svelte_to_js=False,
        )
        if output_dir.startswith("gs://"):
            if not inline_js:
                subprocess.run([
                    "gsutil",
                    "setmeta",
                    "-h",
                    "Content-Type: text/javascript",
                    js_path,
                ])
            subprocess.run([
                "gsutil", "setmeta", "-h", "Content-Type: text/html", html_path
            ])
        elif output_dir.startswith("https://"):
            output_dir_parsed = urllib.parse.urlparse(output_dir)
            az_account, az_hostname = output_dir_parsed.netloc.split(".", 1)
            if az_hostname == "blob.core.windows.net":
                az_container = removeprefix(output_dir_parsed.path,
                                            "/").split("/")[0]
                az_prefix = f"https://{az_account}.{az_hostname}/{az_container}/"
                if not inline_js:
                    js_az_name = removeprefix(js_path, az_prefix)
                    subprocess.run([
                        "az",
                        "storage",
                        "blob",
                        "update",
                        "--container-name",
                        az_container,
                        "--name",
                        js_az_name,
                        "--account-name",
                        az_account,
                        "--content-type",
                        "application/javascript",
                    ])
                html_az_name = removeprefix(html_path, az_prefix)
                subprocess.run([
                    "az",
                    "storage",
                    "blob",
                    "update",
                    "--container-name",
                    az_container,
                    "--name",
                    html_az_name,
                    "--account-name",
                    az_account,
                    "--content-type",
                    "text/html",
                ])
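A hypothetical invocation sketch; only the required keyword arguments from the signature above are shown, and the op names and output location are placeholders:

generate(
    output_dir="gs://some-bucket/interface/",   # placeholder output location
    model_bytes=model_bytes,                    # serialized policy model, loaded elsewhere
    observations=observations,                  # sampled observations, collected elsewhere
    trajectories=trajectories,                  # dict with "observations", "actions", "rewards", ...
    policy_logits_name="policy/logits",         # placeholder op names
    value_function_name="value/out",
    env_name="coinrun_old",
)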
Example #13
def save(thing, url_or_handle, **kwargs):
    """Save object to file on CNS.

    File format is inferred from path. Use save_img(), save_npy(), or save_json()
    if you need to force a particular format.

    Args:
      thing: object to save.
      url_or_handle: CNS path or open file handle to save to.

    Raises:
      RuntimeError: If no file extension can be inferred from the path.
      ValueError: If the file extension is not supported and the object is not a string.
    """

    # Determine context
    # Is this a handle? What is the extension? Are we saving to GCS?
    is_handle = hasattr(url_or_handle, "write") and hasattr(
        url_or_handle, "name")
    if is_handle:
        path = url_or_handle.name
    else:
        path = url_or_handle

    _, ext = os.path.splitext(path)
    is_gcs = path.startswith("gs://")

    if not ext:
        raise RuntimeError("No extension in URL: " + path)

    # Determine which saver should be used
    if ext in savers:
        saver = savers[ext]
    elif isinstance(thing, str):
        saver = save_str
    else:
        message = "Unknown extension '{}'. As a result, only strings can be saved, not {}. Supported extensions: {}"
        raise ValueError(
            message.format(ext,
                           type(thing).__name__, list(savers.keys())))

    # Actually save
    if is_handle:
        result = saver(thing, url_or_handle, **kwargs)
    else:
        with write_handle(url_or_handle) as handle:
            result = saver(thing, handle, **kwargs)

    # Set MIME type on GCS if saving HTML -- usually, when one saves an HTML
    # file to GCS, they want it to be viewable as a website.
    if is_gcs and ext == ".html":
        subprocess.run([
            "gsutil", "setmeta", "-h",
            "Content-Type: text/html; charset=utf-8", path
        ])
    if is_gcs and ext == ".json":
        subprocess.run([
            "gsutil", "setmeta", "-h", "Content-Type: application/json", path
        ])

    # capture save if a save context is available
    if this.save_contexts:
        log.debug(
            "capturing save: resulted in {} -> {} in save_context {}".format(
                result, path, this.save_contexts[-1]))
        this.save_contexts[-1].capture(result)

    if result is not None and "url" in result and result["url"].startswith(
            "gs://"):
        result["serve"] = "https://storage.googleapis.com/{}".format(
            result["url"][5:])

    return result
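Sketch of the GCS behavior described above (the bucket path is hypothetical): an .html save gets its Content-Type set via gsutil, and if the saver reports a gs:// URL the result gains a public "serve" URL.

result = save("<h1>report</h1>", "gs://some-bucket/report.html")
if result and "serve" in result:
    print(result["serve"])  # e.g. https://storage.googleapis.com/some-bucket/report.html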
Example #14
def run(model, ops):

    import numpy as np
    import tensorflow as tf
    import math
    import urllib.parse
    import sklearn

    from umap import UMAP

    from lucid.misc.io import load, show, save

    from clarity.dask.cluster import get_client

    import lucid.optvis.objectives as objectives
    import lucid.optvis.param as param
    import lucid.optvis.render as render
    import lucid.optvis.transform as transform

    from lucid.modelzoo.vision_models import InceptionV1, AlexNet
    import matplotlib.pyplot as plt
    from lucid.misc.io.writing import write_handle

    from clarity.utils.distribute import DistributeDask, DistributeMPI
    from lucid.modelzoo.nets_factory import models_map, get_model

    # Produced by the "collect_activations" notebook
    def load_activations(model,
                         op_name,
                         num_activations=100,
                         batch_size=4096,
                         num_activations_per_image=1):
        activations_collected_per_image = 16  # This is hardcoded from the collection process
        if num_activations_per_image > activations_collected_per_image:
            raise ValueError(
                "Attempting to use more activations than were collected per image."
            )
        activations = []
        coordinates = []
        for s in range(0,
                       math.ceil(num_activations / num_activations_per_image),
                       batch_size):
            e = s + batch_size
            # acts_per_image=16&end=1003520&model=AlexNet&sampling_strategy=random&split=train&start=999424
            loaded_activations = load(
                f"gs://openai-clarity/encyclopedia/collect_activations/acts_per_image=16&end={e}&model={model.name}&sampling_strategy=random&split=train&start={s}/{op_name}-activations.npy"
            )
            loaded_coordinates = load(
                f"gs://openai-clarity/encyclopedia/collect_activations/acts_per_image=16&end={e}&model={model.name}&sampling_strategy=random&split=train&start={s}/{op_name}-image_crops.npy"
            )

            activations.append(
                loaded_activations[:, 0:num_activations_per_image, :])
            coordinates.append(
                loaded_coordinates[:, 0:num_activations_per_image, :])
        acts = np.concatenate(activations)
        flattened_acts = acts.reshape(
            (acts.shape[0] * acts.shape[1], acts.shape[2]))

        coords = np.concatenate(coordinates)
        flattened_coords = coords.reshape(
            (coords.shape[0] * coords.shape[1], coords.shape[2]))
        return flattened_acts[:num_activations,
                              ], flattened_coords[:num_activations, ]

    def load_ops(model):

        # Load the metadata info so we can get a list of the ops
        metadata = load(
            f"gs://openai-clarity/encyclopedia/graph_metadata/model={model.name}/metadata.json"
        )
        # Filter the ops list to only the ones that we are interested in
        ops = [(op_key, op['channels'])
               for op_key, op in metadata['ops'].items()
               if op['op_type'] in ('Relu', 'Conv2D') and op['rank'] == 4]
        return ops

    def bin_laid_out_activations(layout,
                                 activations,
                                 partition,
                                 grid_size,
                                 threshold=5):

        n = activations.shape[0]

        assert layout.shape[0] == activations.shape[0]
        assert n % 2 == 0

        # calculate which grid cells each activation's layout position falls into
        # first bin stays empty because nothing should be < 0, so we add an extra bin
        bins = np.linspace(0, 1, num=grid_size + 1)
        bins[-1] = np.inf  # last bin should include all higher values
        indices = np.digitize(
            layout, bins) - 1  # subtract 1 to account for empty first bin

        means_x, means_y, coordinates, counts_x, counts_y = [], [], [], [], []

        grid_coordinates = np.indices(
            (grid_size, grid_size)).transpose().reshape(-1, 2)
        for xy_coordinates in grid_coordinates:
            mask = np.equal(xy_coordinates, indices).all(axis=1)
            count_x = np.count_nonzero(mask[0:n // 2])
            count_y = np.count_nonzero(mask[n // 2:])
            if (count_x + count_y) > threshold:
                counts_x.append(count_x)
                counts_y.append(count_y)
                coordinates.append(xy_coordinates)
                means_x.append(
                    np.average(activations[0:n // 2][mask[0:n // 2]],
                               axis=0)[0:partition])
                means_y.append(
                    np.average(activations[n // 2:][mask[n // 2:]],
                               axis=0)[partition:])

        return coordinates, means_x, means_y, counts_x, counts_y

    def get_optimal_maps(X, Y):

        Σ_XX = X.transpose() @ X
        Σ_XY = X.transpose() @ Y
        Σ_YY = Y.transpose() @ Y
        Σ_YX = Σ_XY.transpose()

        A_XY = Σ_XY @ np.linalg.inv(Σ_YY)
        A_YX = Σ_YX @ np.linalg.inv(Σ_XX)

        Xhat = Y @ A_XY.transpose()
        Yhat = X @ A_YX.transpose()

        errx = np.sqrt(np.mean((Y @ A_XY.transpose() - X)**2))
        erry = np.sqrt(np.mean((X @ A_YX.transpose() - Y)**2))

        err_baseline_x = np.sqrt(np.mean((X - np.mean(X, 0))**2))
        err_baseline_y = np.sqrt(np.mean((Y - np.mean(Y, 0))**2))

        print(errx, err_baseline_x)
        print(erry, err_baseline_y)

        return A_XY, A_YX, Xhat, Yhat, (errx, err_baseline_x), (erry,
                                                                err_baseline_y)

    def dim_reduce(Z, method="umap"):

        if method == "svd":
            U, S, V = np.linalg.svd(Z, full_matrices=False)
            layout = U[:, 0:2]
            return layout

        if method == "umap":
            umap_defaults = dict(n_components=2,
                                 n_neighbors=50,
                                 min_dist=0.05,
                                 verbose=True,
                                 metric="cosine")
            layout = UMAP(**umap_defaults).fit_transform(Z)
            return layout

    def get_atlas(model, ops):

        model_x, model_y = get_model(model[0]), get_model(model[1])

        model_x.load_graphdef()
        model_y.load_graphdef()

        X = np.concatenate([
            load_activations(model_x, op, num_activations=50000)[0]
            for op in ops[0]
        ], 1).astype(np.float32)
        Y = np.concatenate([
            load_activations(model_y, op, num_activations=50000)[0]
            for op in ops[1]
        ], 1).astype(np.float32)

        A_XY, A_YX, Xhat, Yhat, errx, erry = get_optimal_maps(X, Y)

        Xc = np.concatenate([X, Yhat], axis=-1)
        Yc = np.concatenate([Xhat, Y], axis=-1)

        Z = np.concatenate([Xc, Yc])

        layout = dim_reduce(Z, method="umap")

        layout_centered = (layout - np.min(layout, 0))
        layout_centered = layout_centered / np.max(layout_centered, 0)

        coordinates, means_x, means_y, counts_x, counts_y = bin_laid_out_activations(
            layout_centered, Z, X.shape[1], 20)

        coordinates = np.array(coordinates)
        counts_x = np.array(counts_x)
        counts_y = np.array(counts_y)
        means_x = np.array(means_x)
        means_y = np.array(means_y)

        return coordinates, means_x, means_y, counts_x, counts_y, errx, erry, A_XY, A_YX, layout_centered

    import json
    import hashlib
    identifier = hashlib.md5(json.dumps(
        (model, ops)).encode('utf-8')).hexdigest()

    def pre_relu(name):
        if "mixed" in name:
            return (f"{name}_1x1:0", f"{name}_3x3:0", f"{name}_5x5:0",
                    f"{name}_pool_reduce:0")
        else:
            return [name + ":0"]

    coordinates, means_x, means_y, counts_x, counts_y, errx, erry, A_XY, A_YX, layout = \
        get_atlas(model, [pre_relu(ops[0]), pre_relu(ops[1])])

    plt.figure(figsize=(10, 10))
    plt.scatter(layout[0:layout.shape[0] // 2, 0],
                layout[0:layout.shape[0] // 2, 1], 1, "b")
    plt.scatter(layout[layout.shape[0] // 2:, 0], layout[layout.shape[0] // 2:,
                                                         1], 1, "r")
    plt.show()
    with write_handle(f"gs://clarity-public/ggoh/Diff/{identifier}/scatter.png"
                      ) as handle:
        plt.savefig(handle)

    manifest = {
        "model_x": model[0],
        "model_y": model[1],
        "ops_x": ops[0],
        "ops_y": ops[1],
        "coordinates": coordinates,
        "counts_x": counts_x,
        "counts_y": counts_y,
        "means_x": means_x,
        "means_y": means_y,
        "err_x": errx,
        "err_y": erry,
        "layout": layout,
        "A_XY": A_XY,
        "A_YX": A_YX,
        "identifier": identifier
    }

    print("Identifier", identifier)
    print(
        save(manifest,
             f"gs://clarity-public/ggoh/Diff/{identifier}/manifest.json"))

    del manifest["means_x"]
    del manifest["means_y"]
    del manifest["A_XY"]
    del manifest["A_YX"]

    manifest["layout"] = np.concatenate([
        layout[0:5000],
        layout[layout.shape[0] // 2:layout.shape[0] // 2 + 5000]
    ]).astype(np.float16)

    print(
        save(manifest,
             f"gs://clarity-public/ggoh/Diff/{identifier}/manifest_slim.json"))