Example No. 1
def check_make_dir(path):
    path = pathlib.Path(path)
    if path.exists():
        if not path.is_dir():
            tools.fatal(
                f"Given path {str(path)!r} must point to a directory")
    else:
        path.mkdir(mode=0o755, parents=True)
    return path
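
These excerpts call into an internal `tools` module that is not shown. A minimal sketch of a stand-in for `tools.fatal`, assuming its job is to print an error and abort the process (an assumption; the real helper may log differently):

import sys

def fatal(message):
    # Hypothetical stand-in: report the error and abort with a non-zero status.
    print(f"FATAL: {message}", file=sys.stderr)
    sys.exit(1)

# With such a helper bound as tools.fatal, check_make_dir("results") would
# create ./results (and any missing parents) or abort if the path is a file.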
Example No. 2
        def generate_defense(gars):
            # Preprocess given configuration
            freq_sum = 0.
            defenses = list()
            for info in gars.split(";"):
                # Parse GAR info
                info = info.split(",", maxsplit=2)
                name = info[0].strip()
                if len(info) >= 2:
                    freq = info[1].strip()
                    if freq == "-":
                        freq = 1.
                    else:
                        freq = float(freq)
                else:
                    freq = 1.
                if len(info) >= 3:
                    try:
                        conf = json.loads(info[2].strip())
                        if not isinstance(conf, dict):
                            tools.fatal(
                                f"Invalid GAR arguments for GAR {name!r}: expected a dictionary, got {getattr(type(conf), '__qualname__', '<unknown>')!r}"
                            )
                    except json.decoder.JSONDecodeError as err:
                        tools.fatal(
                            f"Invalid GAR arguments for GAR {name!r}: {str(err).lower()}"
                        )
                else:
                    conf = dict()
                # Recover the associated GAR function
                defense = aggregators.gars.get(name)
                if defense is None:
                    tools.fatal_unavailable(aggregators.gars,
                                            name,
                                            what="aggregation rule")
                # Store parsed defense
                freq_sum += freq
                defenses.append((defense, freq_sum, conf))
            # Return closure
            def unchecked(**kwargs):
                sel = random.random() * freq_sum
                for func, freq, conf in defenses:
                    if sel < freq:
                        return func.unchecked(**kwargs, **conf)
                return func.unchecked(
                    **kwargs, **conf)  # Gracefully handle numeric imprecision

            def check(**kwargs):
                for defense, _, conf in defenses:
                    message = defense.check(**kwargs, **conf)
                    if message is not None:
                        return message

            return aggregators.make_gar(unchecked, check)
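
The `unchecked` closure above performs weighted random selection over the parsed defenses: frequencies are accumulated into running bounds, and a uniform draw in [0, freq_sum) picks the first rule whose bound exceeds it. A self-contained sketch of the same technique, with hypothetical rule names:

import random

rules = [("krum", 2.0), ("median", 1.0)]  # hypothetical (name, frequency) pairs
cumulative = []
bound = 0.0
for name, freq in rules:
    bound += freq
    cumulative.append((name, bound))

def pick():
    sel = random.random() * bound
    for name, limit in cumulative:
        if sel < limit:
            return name
    return cumulative[-1][0]  # guard against floating-point imprecision

print(pick())  # "krum" roughly two times out of three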
Example No. 3
def _transform_schedule(schedule):
    itr = iter(schedule)
    yield (0, next(itr))
    last = 0
    while True:
        try:
            step = next(itr)
        except StopIteration:
            return
        if step <= last:
            tools.fatal(
                "Invalid arguments: learning rate schedule step numbers must be strictly increasing"
            )
        yield (step, next(itr))
        last = step
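
A standalone sketch of the same transformation, with a plain ValueError in place of `tools.fatal`: a flat schedule [lr0, step1, lr1, step2, lr2, ...] becomes (step, learning_rate) pairs starting at step 0.

def transform_schedule(schedule):
    itr = iter(schedule)
    yield (0, next(itr))  # the first value is the rate in effect from step 0
    last = 0
    for step in itr:
        if step <= last:
            raise ValueError("schedule step numbers must be strictly increasing")
        yield (step, next(itr))
        last = step

print(list(transform_schedule([0.1, 100, 0.01, 200, 0.001])))
# [(0, 0.1), (100, 0.01), (200, 0.001)]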
Example No. 4
    parser.add_argument(
        "--supercharge",
        type=int,
        default=1,
        help=
        "How many experiments are run in parallel per device, must be positive"
    )
    # Parse command line
    return parser.parse_args(sys.argv[1:])


with tools.Context("cmdline", "info"):
    args = process_commandline()
    # Check the "supercharge" parameter
    if args.supercharge < 1:
        tools.fatal(
            f"Expected a positive supercharge value, got {args.supercharge}")
    # Make the result directories
    def check_make_dir(path):
        path = pathlib.Path(path)
        if path.exists():
            if not path.is_dir():
                tools.fatal(
                    f"Given path {str(path)!r} must point to a directory")
        else:
            path.mkdir(mode=0o755, parents=True)
        return path

    args.data_directory = check_make_dir(args.data_directory)
    args.plot_directory = check_make_dir(args.plot_directory)
    # Preprocess/resolve the devices to use
    if args.devices == "auto":
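
The excerpt stops at the "auto" device check. A sketch of one way such a resolution step could look (an assumption, not the original logic): use every visible CUDA device and fall back to the CPU.

import torch

def resolve_devices(spec):
    # Hypothetical helper: "auto" expands to all CUDA devices, else split a
    # comma-separated list of explicit device names.
    if spec == "auto":
        count = torch.cuda.device_count()
        return [f"cuda:{i}" for i in range(count)] if count > 0 else ["cpu"]
    return [name.strip() for name in spec.split(",")]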
Example No. 5
    parser.add_argument(
        "--supercharge",
        type=int,
        default=1,
        help=
        "How many experiments are run in parallel per device, must be positive"
    )
    # Parse command line
    return parser.parse_args(sys.argv[1:])


with tools.Context("cmdline", "info"):
    args = process_commandline()
    # Check the "supercharge" parameter
    if args.supercharge < 1:
        tools.fatal("Expected a positive supercharge value, got %d" %
                    args.supercharge)
    # Make the result directories
    def check_make_dir(path):
        path = pathlib.Path(path)
        if path.exists():
            if not path.is_dir():
                tools.fatal("Given path %r must point to a directory" %
                            (str(path), ))
        else:
            path.mkdir(mode=0o755, parents=True)
        return path

    args.data_directory = check_make_dir(args.data_directory)
    args.plot_directory = check_make_dir(args.plot_directory)
    # Preprocess/resolve the devices to use
    if args.devices == "auto":
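
Examples No. 4 and No. 5 validate --supercharge after parsing. An alternative sketch that moves the check into argparse itself via a custom type (`positive_int` is a hypothetical helper, not part of the original code):

import argparse

def positive_int(text):
    # Reject non-positive values at parse time with a proper argparse error.
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value

parser = argparse.ArgumentParser()
parser.add_argument(
    "--supercharge",
    type=positive_int,
    default=1,
    help="How many experiments are run in parallel per device, must be positive")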
Example No. 6
def fatal_unavailable(*args, **kwargs):
  """ Helper forwarding the 'UnavailableException' explanatory string to 'fatal'.
  Args:
    ... Forward (keyword-)arguments to 'make_unavailable_exception_text'
  """
  tools.fatal(make_unavailable_exception_text(*args, **kwargs))
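
`make_unavailable_exception_text` is internal and not shown; a hypothetical sketch of what such a message builder could look like (the real signature and wording may differ): name the missing entry, list the registered ones, and suggest a close match.

import difflib

def make_unavailable_exception_text(registry, name, what="entry"):
    # Hypothetical reconstruction; the actual helper may format differently.
    available = sorted(registry.keys())
    text = f"unknown {what} {name!r}; available: {', '.join(available)}"
    close = difflib.get_close_matches(name, available, n=1)
    if close:
        text += f" (did you mean {close[0]!r}?)"
    return text

print(make_unavailable_exception_text({"krum": None, "median": None}, "kurm",
                                      what="aggregation rule"))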
Example No. 7
    default=0,
    help="How many training steps between two prompts for user command inputs, 0 for no user input")
  # Parse command line
  return parser.parse_args(sys.argv[1:])

with tools.Context("cmdline", "info"):
  args = process_commandline()
  # Parse additional arguments
  for name in ("gar", "attack", "model", "dataset", "loss", "criterion"):
    name = f"{name}_args"
    keyval = getattr(args, name)
    setattr(args, name, dict() if keyval is None else tools.parse_keyval(keyval))
  # Count the number of real honest workers
  args.nb_honests = args.nb_workers - args.nb_real_byz
  if args.nb_honests < 0:
    tools.fatal(f"Invalid arguments: there are more real Byzantine workers ({args.nb_real_byz}) than total workers ({args.nb_workers})")
  # Check general training parameters
  if args.momentum < 0.:
    tools.fatal(f"Invalid arguments: negative momentum factor {args.momentum}")
  if args.dampening < 0.:
    tools.fatal(f"Invalid arguments: negative dampening factor {args.dampening}")
  if args.weight_decay < 0.:
    tools.fatal(f"Invalid arguments: negative weight decay factor {args.weight_decay}")
  # Check the learning rate and associated options
  if args.learning_rate <= 0:
    tools.fatal(f"Invalid arguments: non-positive learning rate {args.learning_rate}")
  if args.learning_rate_decay < 0:
    tools.fatal(f"Invalid arguments: negative learning rate decay {args.learning_rate_decay}")
  if args.learning_rate_decay_delta <= 0:
    tools.fatal(f"Invalid arguments: non-positive learning rate decay delta {args.learning_rate_decay_delta}")
  # Check the privacy-related metrics
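
The block above repeats the same sign checks for several hyper-parameters. A sketch of a helper that consolidates the pattern (hypothetical; it raises ValueError where the original would call `tools.fatal`):

def check_bounds(args, non_negative=(), positive=()):
    # Validate each named attribute of an argparse namespace in one pass.
    for name in non_negative:
        value = getattr(args, name)
        if value < 0:
            raise ValueError(f"Invalid arguments: negative {name.replace('_', ' ')} {value}")
    for name in positive:
        value = getattr(args, name)
        if value <= 0:
            raise ValueError(f"Invalid arguments: non-positive {name.replace('_', ' ')} {value}")

# check_bounds(args,
#              non_negative=("momentum", "dampening", "weight_decay"),
#              positive=("learning_rate", "learning_rate_decay_delta"))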
Example No. 8

with tools.Context("cmdline", "info"):
    args = process_commandline()
    # Parse additional arguments
    for name in ("init_multi", "init_mono", "gar", "attack", "model", "loss",
                 "criterion"):
        name = f"{name}_args"
        keyval = getattr(args, name)
        setattr(args, name,
                dict() if keyval is None else tools.parse_keyval(keyval))
    # Count the number of real honest workers
    args.nb_honests = args.nb_workers - args.nb_real_byz
    if args.nb_honests < 0:
        tools.fatal(
            f"Invalid arguments: there are more real Byzantine workers ({args.nb_real_byz}) than total workers ({args.nb_workers})"
        )
    # Check the learning rate and associated options
    if args.learning_rate_schedule is None:
        if args.learning_rate <= 0:
            tools.fatal(
                f"Invalid arguments: non-positive learning rate {args.learning_rate}"
            )
        if args.learning_rate_decay < 0:
            tools.fatal(
                f"Invalid arguments: negative learning rate decay {args.learning_rate_decay}"
            )
        if args.learning_rate_decay_delta <= 0:
            tools.fatal(
                f"Invalid arguments: non-positive learning rate decay delta {args.learning_rate_decay_delta}"
            )
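
An assumption about how the (step, rate) pairs produced by `_transform_schedule` in Example No. 3 might be consumed: a piecewise-constant lookup of the rate in effect at a given step, sketched with the standard `bisect` module.

import bisect

def rate_at(pairs, step):
    # pairs is a list of (start_step, learning_rate), sorted by start_step.
    steps = [s for s, _ in pairs]
    rates = [r for _, r in pairs]
    return rates[bisect.bisect_right(steps, step) - 1]

pairs = [(0, 0.1), (100, 0.01), (200, 0.001)]
print(rate_at(pairs, 150))  # 0.01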
Example No. 9
def _build_and_load():
    """ Incrementally rebuild all libraries and bind all local modules in the global.
  """
    glob = globals()
    # Standard imports
    import os
    import pathlib
    import traceback
    import warnings
    # External imports
    import torch
    import torch.utils.cpp_extension
    # Internal imports
    import tools
    # Constants
    base_directory = pathlib.Path(__file__).parent.resolve()
    dependencies_file = ".deps"
    debug_mode_envname = "NATIVE_OPT"
    debug_mode_in_env = debug_mode_envname in os.environ
    if debug_mode_in_env:
        raw = os.environ[debug_mode_envname]
        value = raw.lower()
        if value in ["0", "n", "no", "false"]:
            debug_mode = True
        elif value in ["1", "y", "yes", "true"]:
            debug_mode = False
        else:
            tools.fatal(
                "%r defined in the environment, but with unexpected soft-boolean %r"
                % (debug_mode_envname, "%s=%s" % (debug_mode_envname, raw)))
    else:
        debug_mode = __debug__
    cpp_std_envname = "NATIVE_STD"
    cpp_std = os.environ.get(cpp_std_envname, "c++14")
    ident_to_is_python = {"so_": False, "py_": True}
    source_suffixes = {".cpp", ".cc", ".C", ".cxx", ".c++"}
    extra_cflags = ["-Wall", "-Wextra", "-Wfatal-errors", "-std=%s" % cpp_std]
    if torch.cuda.is_available():
        source_suffixes.update(
            set((".cu" + suffix) for suffix in source_suffixes))
        source_suffixes.add(".cu")
        extra_cflags.append("-DTORCH_CUDA_AVAILABLE")
    extra_cuda_cflags = [
        "-DTORCH_CUDA_AVAILABLE", "--expt-relaxed-constexpr",
        "-std=%s" % cpp_std
    ]
    extra_ldflags = ["-Wl,-L" + base_directory.root]
    extra_include_path = base_directory / "include"
    try:
        extra_include_paths = [str(extra_include_path.resolve())]
    except Exception:
        extra_include_paths = None
        warnings.warn("Not found include directory: " +
                      repr(str(extra_include_path)))
    # Print configuration information
    cpp_std_message = "Native modules compiled with %s standard; (re)define %r in the environment to compile with another standard" % (
        cpp_std, "%s=<standard>" % cpp_std_envname)
    if debug_mode:
        tools.warning(cpp_std_message)
        tools.warning(
            "Native modules compiled in debug mode; %sdefine %r in the environment or%s run python with -O/-OO options to compile in release mode"
            % ("re" if debug_mode_in_env else "", "%s=1" % debug_mode_envname,
               " undefine it and" if debug_mode_in_env else ""))
        extra_cflags += ["-O0", "-g"]
    else:
        quiet_envname = "NATIVE_QUIET"
        if quiet_envname not in os.environ:
            tools.trace(cpp_std_message)
            tools.trace(
                "Native modules compiled in release mode; %sdefine %r in the environment or%s run python without -O/-OO options to compile in debug mode"
                % ("re" if debug_mode_in_env else "",
                   "%s=0" % debug_mode_envname,
                   " undefine it and" if debug_mode_in_env else ""))
            tools.trace(
                "Define %r in the environment to hide these messages in release mode"
                % quiet_envname)
        extra_cflags += ["-O3", "-DNDEBUG"]
    # Variables
    done_modules = []
    fail_modules = []

    # Local procedures
    def build_and_load_one(path, deps=[]):
        """ Check if the given directory is a module to build and load, and if yes recursively build and load its dependencies before it.
        Args:
          path Given directory path
          deps Dependent module paths
        Returns:
          True on success, False on failure, None if not a module
        """
        nonlocal done_modules
        nonlocal fail_modules
        with tools.Context(path.name, "info"):
            ident = path.name[:3]
            if ident in ident_to_is_python.keys():
                # Is a module directory
                if len(path.name) <= 3 or path.name[3] == "_":
                    tools.warning("Skipped invalid module directory name " +
                                  repr(path.name))
                    return None
                if not path.exists():
                    tools.warning("Unable to build and load " +
                                  repr(str(path.name)) +
                                  ": module does not exist")
                    fail_modules.append(path)  # Mark as failed
                    return False
                is_python_module = ident_to_is_python[ident]
                # Check if already built and loaded, or failed
                if path in done_modules:
                    if len(deps) == 0 and debug_mode:
                        tools.info("Already built and loaded " +
                                   repr(str(path.name)))
                    return True
                if path in fail_modules:
                    if len(deps) == 0:
                        tools.warning("Was unable to build and load " +
                                      repr(str(path.name)))
                    return False
                # Check for dependency cycle (disallowed as they may mess with the linker)
                if path in deps:
                    tools.warning("Unable to build and load " +
                                  repr(str(path.name)) +
                                  ": dependency cycle found")
                    fail_modules.append(path)  # Mark as failed
                    return False
                # Build and load dependencies
                this_ldflags = list(extra_ldflags)
                depsfile = path / dependencies_file
                if depsfile.exists():
                    for modname in depsfile.read_text().splitlines():
                        res = build_and_load_one(base_directory / modname,
                                                 deps + [path])
                        if res is False:  # Unable to build a dependency
                            if len(deps) == 0:
                                tools.warning("Unable to build and load " +
                                              repr(str(path.name)) +
                                              ": dependency " + repr(modname) +
                                              " build and load failed")
                            fail_modules.append(path)  # Mark as failed
                            return False
                        elif res is True:  # The module and its dependencies were built and loaded successfully
                            this_ldflags.append("-Wl,--library=:" + str(
                                (base_directory / modname /
                                 (modname + ".so")).resolve()))
                # List sources
                sources = []
                for subpath in path.iterdir():
                    if subpath.is_file() and ("").join(
                            subpath.suffixes) in source_suffixes:
                        sources.append(str(subpath))
                # Build and load this module
                try:
                    res = torch.utils.cpp_extension.load(
                        name=path.name,
                        sources=sources,
                        extra_cflags=extra_cflags,
                        extra_cuda_cflags=extra_cuda_cflags,
                        extra_ldflags=this_ldflags,
                        extra_include_paths=extra_include_paths,
                        build_directory=str(path),
                        verbose=debug_mode,
                        is_python_module=is_python_module)
                    if is_python_module:
                        glob[path.name[3:]] = res
                except Exception as err:
                    tools.warning("Unable to build and load " +
                                  repr(str(path.name)) + ": " + str(err))
                    fail_modules.append(path)  # Mark as failed
                    return False
                done_modules.append(path)  # Mark as built and loaded
                return True

    # Main loop
    for path in base_directory.iterdir():
        if path.is_dir():
            try:
                build_and_load_one(path)
            except Exception as err:
                tools.warning("Exception while processing " + repr(str(path)) +
                              ": " + str(err))
                with tools.Context("traceback", "trace"):
                    traceback.print_exc()
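
A minimal, self-contained sketch of the same JIT-compilation mechanism used by `_build_and_load`, here through `torch.utils.cpp_extension.load_inline` (requires a working C++ toolchain; the function name and sources are illustrative):

import torch
import torch.utils.cpp_extension

module = torch.utils.cpp_extension.load_inline(
    name="demo_ext",
    cpp_sources="torch::Tensor double_it(torch::Tensor t) { return t * 2; }",
    functions=["double_it"],  # generates the Python binding for this symbol
    verbose=False)

print(module.double_it(torch.ones(3)))  # tensor([2., 2., 2.])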
Example No. 10
    return parser.parse_args(sys.argv[1:])


with tools.Context("cmdline", "info"):
    args = process_commandline()
    # Parse additional arguments
    for name in ("gar", "attack", "model", "loss", "criterion"):
        name = "%s_args" % name
        keyval = getattr(args, name)
        setattr(args, name,
                dict() if keyval is None else tools.parse_keyval(keyval))
    # Count the number of real honest workers
    args.nb_honests = args.nb_workers - args.nb_real_byz
    if args.nb_honests < 0:
        tools.fatal(
            "Invalid arguments: there are more real Byzantine workers (%d) than total workers (%d)"
            % (args.nb_real_byz, args.nb_workers))
    # Check the learning rate and associated options
    if args.learning_rate <= 0:
        tools.fatal("Invalid arguments: non-positive learning rate %s" %
                    args.learning_rate)
    if args.learning_rate_decay < 0:
        tools.fatal("Invalid arguments: negative learning rate decay %s" %
                    args.learning_rate_decay)
    if args.learning_rate_decay_delta <= 0:
        tools.fatal(
            "Invalid arguments: non-positive learning rate decay delta %s" %
            args.learning_rate_decay_delta)
    # Check the momentum position
    momentum_at_values = ("update", "server", "worker")
    if args.momentum_at not in momentum_at_values:
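
The membership test above can also be enforced by argparse itself; a sketch using `choices` with the same three values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--momentum-at",
    choices=("update", "server", "worker"),  # same values as momentum_at_values
    default="update",
    help="Where the momentum is applied")

args = parser.parse_args(["--momentum-at", "server"])
print(args.momentum_at)  # server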