Example #1
0
    def setUp(self):
        """Prepare parsed options, a conservative strategy and a device/dtype setup.

        Sets ``self.args`` (the parsed command-line options), ``self.defs``
        (a ``forest.ConservativeStrategy`` built from them) and ``self.setup``
        (a dict with ``device`` and ``dtype`` keys).
        """
        # NOTE(review): parse_args() with no argument list reads sys.argv of the
        # running process (e.g. the test runner) — confirm this is intended.
        parsed = forest.options().parse_args()
        self.defs = forest.ConservativeStrategy(parsed)

        # Prefer the first CUDA device when one is present, else fall back to CPU.
        if torch.cuda.is_available():
            target = torch.device('cuda:0')
        else:
            target = torch.device('cpu')

        self.args = parsed
        self.setup = dict(device=target, dtype=torch.float)
"""General interface script to launch distributed poisoning jobs. Launch only through the pytorch launch utility."""

import socket
import datetime
import time

import torch
import forest
# Process-wide torch settings, taken from the project's constants module.
torch.backends.cudnn.benchmark = forest.consts.BENCHMARK
torch.multiprocessing.set_sharing_strategy(forest.consts.SHARING_STRATEGY)

# Parse input arguments
args = forest.options().parse_args()
# Parse training strategy
defs = forest.training_strategy(args)
# 100% reproducibility?
if args.deterministic:
    forest.utils.set_deterministic()

# Without a local rank this process cannot be mapped to a GPU; presumably the
# pytorch launch utility is what supplies it.
if args.local_rank is None:
    raise ValueError(
        'This script should only be launched via the pytorch launch utility!')

if __name__ == "__main__":

    # local_rank is passed to torch.cuda.set_device as a 0-based device index,
    # so the valid ranks are 0 .. device_count() - 1. Using `<=` (not `<`)
    # rejects rank == device_count(), including the no-GPU case where
    # device_count() is 0. Without it, set_device fails with a CUDA error
    # instead of this message.
    if torch.cuda.device_count() <= args.local_rank:
        raise ValueError(
            'Process invalid, oversubscribing to GPUs is not possible in this mode.'
        )
    torch.cuda.set_device(args.local_rank)