def mode_allworkers_saveall(out_dir, mode):
    """Run the TF2 smdataparallel MNIST script with ``save_all`` on all workers.

    Builds a config with ``include_workers="all"`` and ``save_all=True`` and
    launches one worker per entry returned by ``get_available_gpus()``. It then
    opens a trial on ``out_dir`` and asserts that:

    * every launched worker appears in ``trial.workers()``;
    * exactly 35 tensor names were recorded;
    * the first tensor of the "weights" collection and the "loss" tensor
      each report ``num_workers`` workers for ``workers(0)``
      (presumably step 0 — confirm against the trial API).

    NOTE(review): another ``mode_allworkers_saveall`` (the PyTorch variant)
    is defined right after this one. If both live in the same module, the
    later definition shadows this one. Confirm they belong to separate test
    files.

    Args:
        out_dir: directory passed as the script's ``--model_dir`` and also
            handed to ``build_json`` and ``create_trial``.
        mode: forwarded unchanged to ``launch_smdataparallel_job``.
    """
    config_path = build_json(out_dir, include_workers="all", save_all=True)
    gpu_count = len(get_available_gpus())

    launch_smdataparallel_job(
        script_file_path=SMDATAPARALLEL_TF2_TEST_MNIST_SCRIPT,
        script_args=["--model_dir", out_dir],
        num_workers=gpu_count,
        config_file_path=config_path,
        mode=mode,
    )

    trial = create_trial(out_dir)
    assert len(trial.workers()) == gpu_count
    assert len(trial.tensor_names()) == 35
    # Looked up only after the count checks so failure order matches the original.
    first_weight = trial.tensor_names(collection="weights")[0]
    assert len(trial.tensor(first_weight).workers(0)) == gpu_count
    assert len(trial.tensor("loss").workers(0)) == gpu_count
def mode_allworkers_saveall(out_dir, mode):
    """Run the PyTorch smdataparallel MNIST script with ``save_all`` on all workers.

    Builds a config with ``include_workers="all"`` and ``save_all=True`` and
    launches one worker per device reported by ``device_count()``, or a
    single worker when it reports none. It then opens a trial on ``out_dir``
    and asserts that:

    * every launched worker appears in ``trial.workers()``;
    * more than 25 tensor names were recorded (a lower bound, not an exact
      count);
    * the first tensor of the "weights" collection and the first tensor of
      the "losses" collection each report ``num_workers`` workers for
      ``workers(0)`` (presumably step 0 — confirm against the trial API).

    NOTE(review): this shares its name with the TF2 variant defined just
    before it. If both live in the same module, this one shadows it.

    Args:
        out_dir: directory handed to ``build_json`` and ``create_trial``.
        mode: forwarded unchanged to ``launch_smdataparallel_job``.
    """
    path = build_json(out_dir, include_workers="all", save_all=True)
    # Query the device count once. Fall back to one worker when it reports
    # zero devices (presumably CPU-only runs; device_count is imported
    # elsewhere — likely torch.cuda.device_count).
    num_workers = device_count() or 1
    # Pass a copy so the module-level args constant is never shared with the
    # launcher (presumably guarding against in-place mutation).
    mode_args = list(SMDATAPARALLEL_PYTORCH_TEST_MNIST_ARGS)
    launch_smdataparallel_job(
        script_file_path=SMDATAPARALLEL_PYTORCH_TEST_MNIST_SCRIPT,
        script_args=mode_args,
        num_workers=num_workers,
        config_file_path=path,
        mode=mode,
    )
    tr = create_trial(out_dir)
    assert len(tr.workers()) == num_workers
    assert len(tr.tensor_names()) > 25
    assert len(tr.tensor(
        tr.tensor_names(collection="weights")[0]).workers(0)) == num_workers
    assert len(tr.tensor(
        tr.tensor_names(collection="losses")[0]).workers(0)) == num_workers
def mode_allworkers_default_collections(out_dir, mode):
    """Run the TF2 smdataparallel MNIST script saving only the default collections.

    Builds a config with ``include_workers="all"`` and
    ``include_collections=TF_DEFAULT_SAVED_COLLECTIONS``, then launches one
    worker per entry returned by ``get_available_gpus()``. It then opens a
    trial on ``out_dir`` and asserts that:

    * every launched worker appears in ``trial.workers()``;
    * exactly one tensor name was recorded;
    * the first tensor of the "losses" collection reports ``num_workers``
      workers for ``workers(0)`` (presumably step 0 — confirm).

    Args:
        out_dir: directory passed as the script's ``--model_dir`` and also
            handed to ``build_json`` and ``create_trial``.
        mode: forwarded unchanged to ``launch_smdataparallel_job``.
    """
    config_path = build_json(
        out_dir,
        include_workers="all",
        include_collections=TF_DEFAULT_SAVED_COLLECTIONS,
    )
    worker_count = len(get_available_gpus())

    launch_smdataparallel_job(
        script_file_path=SMDATAPARALLEL_TF2_TEST_MNIST_SCRIPT,
        script_args=["--model_dir", out_dir],
        num_workers=worker_count,
        config_file_path=config_path,
        mode=mode,
    )

    trial = create_trial(out_dir)
    assert len(trial.workers()) == worker_count
    assert len(trial.tensor_names()) == 1
    loss_name = trial.tensor_names(collection="losses")[0]
    assert len(trial.tensor(loss_name).workers(0)) == worker_count
def mode_allworkers(out_dir, mode):
    """Run the TF2 smdataparallel MNIST script saving weights and optimizer variables.

    Builds a config with ``include_workers="all"`` and
    ``include_collections=["weights", "optimizer_variables"]``, then launches
    one worker per entry returned by ``get_available_gpus()``. It then opens
    a trial on ``out_dir`` and asserts that:

    * every launched worker appears in ``trial.workers()``;
    * exactly 5 tensor names were recorded (the names are included in the
      failure message to aid debugging);
    * the first tensor of the "weights" collection reports ``num_workers``
      workers for ``workers(0)`` (presumably step 0 — confirm).

    Args:
        out_dir: directory passed as the script's ``--model_dir`` and also
            handed to ``build_json`` and ``create_trial``.
        mode: forwarded unchanged to ``launch_smdataparallel_job``.
    """
    path = build_json(out_dir,
                      include_workers="all",
                      include_collections=["weights", "optimizer_variables"])
    num_workers = len(get_available_gpus())
    mode_args = ["--model_dir", out_dir]
    launch_smdataparallel_job(
        script_file_path=SMDATAPARALLEL_TF2_TEST_MNIST_SCRIPT,
        script_args=mode_args,
        num_workers=num_workers,
        config_file_path=path,
        mode=mode,
    )
    tr = create_trial(out_dir)
    assert len(tr.workers()) == num_workers
    # Surface the recorded names only when the count check fails, instead of
    # printing them unconditionally on every run.
    tensor_names = tr.tensor_names()
    assert len(tensor_names) == 5, (
        f"expected 5 tensor names, got {len(tensor_names)}: {tensor_names}"
    )
    assert len(tr.tensor(
        tr.tensor_names(collection="weights")[0]).workers(0)) == num_workers