def setUp(self):
    """Build the shared test fixture.

    Stores the parsed forest options on ``self.args``, a
    ``forest.ConservativeStrategy`` built from them on ``self.defs``, and a
    ``self.setup`` dict holding the target ``device`` (first CUDA device when
    available, otherwise CPU) and ``dtype`` (``torch.float``).
    """
    # NOTE(review): parse_args() is called with no explicit argument list;
    # if forest.options() is an argparse parser (presumably), values come
    # from sys.argv, which under a test runner may contain runner flags.
    parsed_options = forest.options().parse_args()
    self.defs = forest.ConservativeStrategy(parsed_options)

    if torch.cuda.is_available():
        target_device = torch.device('cuda:0')
    else:
        target_device = torch.device('cpu')

    self.args = parsed_options
    self.setup = {'device': target_device, 'dtype': torch.float}
"""General interface script to launch distributed poisoning jobs.

Launch only through the PyTorch launch utility: this script refuses to run
when ``--local_rank`` was not supplied, and each process binds itself to the
CUDA device whose index equals its local rank.
"""

import socket
import datetime
import time

import torch

import forest

torch.backends.cudnn.benchmark = forest.consts.BENCHMARK
torch.multiprocessing.set_sharing_strategy(forest.consts.SHARING_STRATEGY)

# Parse input arguments
args = forest.options().parse_args()
# Parse training strategy
defs = forest.training_strategy(args)
# 100% reproducibility?
if args.deterministic:
    forest.utils.set_deterministic()

# A missing local_rank means this process was not spawned by the launch utility.
if args.local_rank is None:
    raise ValueError(
        'This script should only be launched via the pytorch launch utility!')


if __name__ == "__main__":
    # CUDA device ordinals are 0-based, so the valid ranks are
    # 0 .. device_count() - 1; local_rank == device_count() is already
    # out of range and must be rejected too (hence `<=`, not `<`).
    if torch.cuda.device_count() <= args.local_rank:
        raise ValueError(
            'Process invalid, oversubscribing to GPUs is not possible in this mode.'
        )
    torch.cuda.set_device(args.local_rank)