def maybe_download(dataset_root):
    """
    Download and extract the CARLA imitation learning dataset unless it is already present.

    The dataset is considered present when ``<dataset_root>/AgentHuman`` exists. Otherwise the
    archive is fetched from Google Drive with wget and unpacked with tar.

    :param dataset_root: The directory that should contain the dataset. When it is None or empty,
                         the archive is always downloaded and no target directory is passed to tar.
    :return: None
    """
    if dataset_root and os.path.exists(
            os.path.join(dataset_root, "AgentHuman")):
        # the dataset was already extracted, nothing to do
        return

    screen.log_title(
        "Downloading the CARLA dataset. This might take a while.")

    google_drive_download_id = "1hloAeyamYn-H6MfV1dRtY1gJPhkR55sY"
    filename_to_save = "datasets/CORL2017ImitationLearningData.tar.gz"

    # wget -O does not create missing parent directories, so make sure the target exists
    os.makedirs(os.path.dirname(filename_to_save), exist_ok=True)

    # the inner wget stores the session cookies and extracts the confirmation token from the
    # returned page, the outer wget uses both to download the actual archive
    download_command = 'wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=' \
                       '$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies ' \
                       '--no-check-certificate \"https://docs.google.com/uc?export=download&id={}\" -O- | ' \
                       'sed -rn \'s/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p\')&id={}" -O {} && rm -rf /tmp/cookies.txt'\
                       .format(google_drive_download_id, google_drive_download_id, filename_to_save)

    # start downloading and wait for it to finish
    start_shell_command_and_wait(download_command)

    screen.log_title("Unzipping the dataset")
    unzip_command = 'tar -xzf {} --checkpoint=.10000'.format(
        filename_to_save)

    # only redirect the extraction when a target directory was actually given. a None or empty
    # dataset_root previously crashed in os.path.exists / os.makedirs or produced a dangling "-C"
    if dataset_root:
        unzip_command += " -C {}".format(dataset_root)
        os.makedirs(dataset_root, exist_ok=True)

    start_shell_command_and_wait(unzip_command)
# Example no. 2
# 0
def save_onnx_graph(input_nodes, output_nodes,
                    checkpoint_save_dir: str) -> None:
    """
    Export the TF graph as an ONNX model by running the freeze_graph and tf2onnx command line tools.

    The graph definition ("graphdef.pb") and the weights checkpoint are read from
    checkpoint_save_dir. They are first merged into "frozen_graph.pb", which is then converted
    into "model.onnx". Both files are written to the same directory.

    :param input_nodes: A list of input node names of the TF graph
    :param output_nodes: A list of output node names of the TF graph
    :param checkpoint_save_dir: The directory holding the graph and checkpoint, which also receives
                                the frozen graph and the ONNX graph
    :return: None
    """
    import tf2onnx  # not used directly - imported so a missing tf2onnx installation fails early

    # step 1 - freeze the graph
    graphdef_path = os.path.join(checkpoint_save_dir, "graphdef.pb")
    frozen_graph_path = os.path.join(checkpoint_save_dir, "frozen_graph.pb")

    # freeze_graph takes node names, so everything from the first ":" onwards is dropped
    output_op_names = []
    for tensor_name in output_nodes:
        output_op_names.append(tensor_name.split(":")[0])

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_save_dir)

    freeze_template = ("python -m tensorflow.python.tools.freeze_graph"
                       " --input_graph={graphdef}"
                       " --input_binary=true"
                       " --output_node_names='{names}'"
                       " --input_checkpoint={checkpoint}"
                       " --output_graph={frozen}")
    start_shell_command_and_wait(
        freeze_template.format(graphdef=graphdef_path,
                               names=','.join(output_op_names),
                               checkpoint=latest_checkpoint,
                               frozen=frozen_graph_path))

    # step 2 - convert the frozen graph to onnx
    onnx_graph_path = os.path.join(checkpoint_save_dir, "model.onnx")
    joined_inputs = ','.join(input_nodes)
    joined_outputs = ','.join(output_nodes)

    convert_template = ("python -m tf2onnx.convert"
                        " --input {frozen}"
                        " --inputs '{inputs}'"
                        " --outputs '{outputs}'"
                        " --output {onnx}"
                        " --verbose")
    start_shell_command_and_wait(
        convert_template.format(frozen=frozen_graph_path,
                                inputs=joined_inputs,
                                outputs=joined_outputs,
                                onnx=onnx_graph_path))