Example #1
 def _get_models(self):
   """ Lazy-initialize and return the map '__models'.
   Returns:
     The 'name -> constructor' dictionary '__models'
   """
   # Fast-path already loaded
   if self.__models is not None:
     return self.__models
   # Initialize the dictionary
   self.__models = dict()
   # Populate this dictionary with TorchVision's models
   for name in dir(torchvision.models):
     if len(name) == 0 or name[0] == "_": # Ignore "protected" members
       continue
     builder = getattr(torchvision.models, name)
     if isinstance(builder, types.FunctionType): # Heuristic: TorchVision exposes its model builders as plain functions
       self.__models["torchvision-%s" % name.lower()] = builder
   # Dynamically add the custom models from subdirectory 'models/'
   def add_custom_models(name, module, _):
     nonlocal self
     # Check whether the module declares its exports; fall back to '__dict__' otherwise
     exports = getattr(module, "__all__", None)
     if exports is None:
       tools.warning("Model module %r does not provide '__all__'; falling back to '__dict__' for name discovery" % name)
       exports = (attr for attr in dir(module) if len(attr) > 0 and attr[0] != "_")
     # Register the association 'name -> constructor' for all the models
     exported = False
     for model in exports:
       # Check model name type
       if not isinstance(model, str):
         tools.warning("Model module %r exports non-string name %r; ignored" % (name, model))
         continue
       # Recover instance from name
       constructor = getattr(module, model, None)
       # Check that the instance is callable (it's only a heuristic...)
       if not callable(constructor):
         continue
       # Register callable with composite name
       exported = True
       fullname = "%s-%s" % (name, model)
       if fullname in self.__models:
         tools.warning("Unable to make available model %r from module %r, as the name %r already exists" % (model, name, fullname))
         continue
       self.__models[fullname] = constructor
     if not exported:
       tools.warning("Model module %r does not export any valid constructor name through '__all__'" % name)
   with tools.Context("models", None):
     tools.import_directory(pathlib.Path(__file__).parent / "models", {"__package__": "%s.models" % __package__}, post=add_custom_models)
   # Return the dictionary
   return self.__models
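A minimal consumption sketch for the map built above; 'registry' stands for a hypothetical instance of the enclosing class, and the key follows the "torchvision-<lowercased name>" scheme established by the loop:

# Hypothetical usage sketch (not part of the original module)
models = registry._get_models()
build = models.get("torchvision-resnet18")
if build is None:
  raise KeyError("model 'torchvision-resnet18' is not available")
model = build(num_classes=10) # keyword arguments are forwarded to the TorchVision builder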
Example #2
 def _get_datasets(self):
   """ Lazy-initialize and return the map '__datasets'.
   Returns:
     The 'name -> builder' dictionary '__datasets'
   """
   global transforms # Module-level map 'dataset name -> transform', extended lazily below
   # Fast-path already loaded
   if self.__datasets is not None:
     return self.__datasets
   # Initialize the dictionary
   self.__datasets = dict()
   # Populate this dictionary with TorchVision's datasets
   for name in dir(torchvision.datasets):
     if len(name) == 0 or name[0] == "_": # Ignore "protected" members
       continue
     constructor = getattr(torchvision.datasets, name)
     if isinstance(constructor, type): # Heuristic: TorchVision exposes its datasets as classes
       def make_builder(constructor, name): # Factory binding the loop variables (avoids late-binding closures)
         def builder(root, batch_size=None, shuffle=False, num_workers=1, *args, **kwargs):
           # Try to build the dataset instance
           data = constructor(root, *args, **kwargs)
           assert isinstance(data, torch.utils.data.Dataset), f"Internal heuristic failed: {name!r} was not a dataset name"
           # Ensure there is at least a tensor transformation for each torchvision dataset
           if name not in transforms:
             transforms[name] = torchvision.transforms.ToTensor()
           # Wrap into a loader
           batch_size = batch_size or len(data)
           loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
           # Wrap into an infinite batch sampler generator
           return make_sampler(loader)
         return builder
       self.__datasets[name.lower()] = make_builder(constructor, name)
   # Dynamically add the custom datasets from subdirectory 'datasets/'
   def add_custom_datasets(name, module, _):
     nonlocal self
     # Check whether the module declares its exports; fall back to '__dict__' otherwise
     exports = getattr(module, "__all__", None)
     if exports is None:
       tools.warning(f"Dataset module {name!r} does not provide '__all__'; falling back to '__dict__' for name discovery")
       exports = (attr for attr in dir(module) if len(attr) > 0 and attr[0] != "_")
     # Register the association 'name -> constructor' for all the datasets
     exported = False
     for dataset in exports:
       # Check dataset name type
       if not isinstance(dataset, str):
         tools.warning(f"Dataset module {name!r} exports non-string name {dataset!r}; ignored")
         continue
       # Recover instance from name
       constructor = getattr(module, dataset, None)
       # Check that the instance is callable (it's only a heuristic...)
       if not callable(constructor):
         continue
       # Register callable with composite name
       exported = True
       fullname = f"{name}-{dataset}"
       if fullname in self.__datasets:
         tools.warning(f"Unable to make available dataset {dataset!r} from module {name!r}, as the name {fullname!r} already exists")
         continue
       self.__datasets[fullname] = constructor
     if not exported:
       tools.warning(f"Dataset module {name!r} does not export any valid constructor name through '__all__'")
   with tools.Context("datasets", None):
     tools.import_directory(pathlib.Path(__file__).parent / "datasets", {"__package__": f"{__package__}.datasets"}, post=add_custom_datasets)
   # Return the dictionary
   return self.__datasets
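A usage sketch for one of the generated builders; 'registry' is again a hypothetical instance of the enclosing class, the "cifar10" key comes from 'name.lower()' applied to torchvision.datasets.CIFAR10, and the behavior of 'make_sampler' is an assumption based on the wrapping comments above:

# Hypothetical usage sketch (not part of the original module)
datasets = registry._get_datasets()
build_cifar10 = datasets["cifar10"]
sampler = build_cifar10("./data", batch_size=128, shuffle=True, download=True) # extra kwargs reach the CIFAR10 constructor
batch = next(sampler) # 'make_sampler' is assumed to yield mini-batches indefinitely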
Example #3
  # Return the selected function with the associated name
  return func

def register(name, unchecked, check, upper_bound=None, influence=None):
  """ Simple registration-wrapper helper.
  Args:
    name        GAR name
    unchecked   Associated function (see module description)
    check       Parameter validity check function
    upper_bound Function computing the theoretical upper bound on the ratio non-Byzantine standard deviation / norm required to use this aggregation rule: (n, f, d) -> float
    influence   Attack acceptance ratio function
  """
  global gars
  # Check if name already in use
  if name in gars:
    tools.warning("Unable to register %r GAR: name already in use" % name)
    return
  # Export the selected function with the associated name
  gars[name] = make_gar(unchecked, check, upper_bound=upper_bound, influence=influence)

# Registered rules (mapping name -> aggregation rule)
gars = dict()

# Load all local modules
with tools.Context("aggregators", None):
  tools.import_directory(pathlib.Path(__file__).parent, globals())

# Bind/overwrite the GAR names with their associated rules in globals()
for name, rule in gars.items():
  globals()[name] = rule
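A minimal sketch of what a rule submodule loaded by the import_directory call could do with this helper; the rule, its check function, and their signatures are illustrative assumptions (the actual contract is the one given in the module description referenced by the docstring):

# Hypothetical registration sketch (not part of the original module)
import torch

def example_average(gradients, f, **kwargs):
  # Plain coordinate-wise average of the submitted gradients (no Byzantine resilience)
  return torch.stack(gradients).mean(dim=0)

def example_check(gradients, f, **kwargs):
  # Return None when the parameters look valid, a reason string otherwise
  if not isinstance(gradients, (list, tuple)) or len(gradients) == 0:
    return "expected a non-empty list of gradients"
  if not isinstance(f, int) or not (0 <= f < len(gradients)):
    return "invalid number of Byzantine gradients"
  return None

register("example-average", example_average, example_check)
# The new rule is then reachable as gars["example-average"]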