def parse_args():
    """
    Set env-vars and global args
        rank <-- $SLURM_PROCID
        world_size <-- $SLURM_NTASKS
        master address <-- $SLURM_NODENAME of rank 0 process (or HOSTNAME)
        master port <-- any free port (doesn't really matter)
    """

    class DataStore():
        def __init__(self):
            self.all_reduce = 'False'
            self.batch_size = 32
            self.lr = 0.1
            self.num_dataloader_workers = 10
            self.num_epochs = 90
            self.num_iterations_per_training_epoch = None
            self.momentum = 0.9
            self.weight_decay = 1e-4
            self.push_sum = 'True'
            self.graph_type = 5
            self.mixing_strategy = 0
            self.schedule = None
            self.peers_per_itr_schedule = None
            self.overlap = 'False'
            self.synch_freq = 0
            self.warmup = 'False'
            self.seed = 47
            self.print_freq = 10
            self.checkpoint_all = 'False'
            self.overwrite_checkpoints = 'True'
            self.master_port = '40100'
            self.checkpoint_dir = "./"
            self.network_interface_type = 'infiniband'
            self.num_itr_ignore = 10
            # self.dataset_dir = "./data/"
            self.no_cuda_streams = None
            self.master_addr = None
            self.backend = 'nccl'
            self.rank = 1
            self.world_size = 5
            self.tag = ''
            self.out_fname = ''
            self.resume = 'False'
            self.verbose = 'True'
            self.train_fast = 'False'
            self.nesterov = 'False'

    args = DataStore()  # parser.parse_args()
    ClusterManager.set_checkpoint_dir(args.checkpoint_dir)

    # rank and world_size need to be changed depending on the scheduler being
    # used to run the distributed jobs
    args.master_addr = os.environ['HOSTNAME']
    if args.backend == 'mpi':
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_UNIVERSE_SIZE'])
    else:
        args.rank = 1  # int(os.environ['SLURM_PROCID'])
        args.world_size = 5  # int(os.environ['SLURM_NTASKS'])
    args.out_fname = ClusterManager.CHECKPOINT_DIR \
        + args.tag \
        + 'out_r' + str(args.rank) \
        + '_n' + str(args.world_size) \
        + '.csv'

    args.resume = True if args.resume == 'True' else False
    args.verbose = True if args.verbose == 'True' else False
    args.train_fast = True if args.train_fast == 'True' else False
    args.nesterov = True if args.nesterov == 'True' else False
    args.checkpoint_all = True if args.checkpoint_all == 'True' else False
    args.warmup = True if args.warmup == 'True' else False
    args.overlap = True if args.overlap == 'True' else False
    args.push_sum = True if args.push_sum == 'True' else False
    args.all_reduce = True if args.all_reduce == 'True' else False
    args.cpu_comm = True if (args.backend == 'gloo' and not args.push_sum
                             and not args.all_reduce) else False
    args.comm_device = torch.device('cpu') if args.cpu_comm \
        else torch.device('cuda')
    args.overwrite_checkpoints = True \
        if args.overwrite_checkpoints == 'True' else False

    # parse lr schedule as (epoch, lr-decay) pairs
    args.lr_schedule = {}
    if args.schedule is None:
        args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    i, epoch = 0, None
    for v in args.schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.lr_schedule[epoch] = v
        i = (i + 1) % 2
    del args.schedule

    # parse peers-per-itr schedule as (epoch, num_peers) pairs
    args.ppi_schedule = {}
    if args.peers_per_itr_schedule is None:
        args.peers_per_itr_schedule = [0, 1]
    i, epoch = 0, None
    for v in args.peers_per_itr_schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.ppi_schedule[epoch] = v
        i = (i + 1) % 2
    del args.peers_per_itr_schedule
    # must specify how many peers to communicate with from the start of
    # training
    assert 0 in args.ppi_schedule

    if args.all_reduce:
        assert args.graph_type == -1

    if args.backend == 'gloo':
        assert args.network_interface_type == 'ethernet'
        os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type)
    elif args.network_interface_type == 'ethernet':
        if args.backend == 'nccl':
            os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
                network_interface_type=args.network_interface_type)
            os.environ['NCCL_IB_DISABLE'] = '1'
        else:
            raise NotImplementedError

    # initialize torch distributed backend
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port
    dist.init_process_group(backend=args.backend,
                            world_size=args.world_size,
                            rank=args.rank)

    args.graph, args.mixing = None, None
    graph_class = GRAPH_TOPOLOGIES[args.graph_type]
    if graph_class:
        # dist.barrier is called here to ensure the NCCL communicator is
        # created now; this prevents an error that can occur if the NCCL
        # communicator is created more than 5 minutes apart in different
        # processes
        dist.barrier()
        args.graph = graph_class(args.rank,
                                 args.world_size,
                                 peers_per_itr=args.ppi_schedule[0])

    mixing_class = MIXING_STRATEGIES[args.mixing_strategy]
    if mixing_class and args.graph:
        args.mixing = mixing_class(args.graph, args.comm_device)

    return args
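
# Both parsing loops above turn a flat [epoch, value, epoch, value, ...] list
# into an {epoch: value} dict. A minimal equivalent sketch of that convention
# (the helper name parse_pair_schedule is illustrative and not part of this
# codebase):
def parse_pair_schedule(flat_schedule):
    """Convert [e0, v0, e1, v1, ...] into {e0: v0, e1: v1, ...}.

    For example, the default lr schedule [30, 0.1, 60, 0.1, 80, 0.1] becomes
    {30: 0.1, 60: 0.1, 80: 0.1}, i.e., decay the lr by 10x at epochs 30, 60,
    and 80.
    """
    epochs = flat_schedule[0::2]  # even indices hold epochs
    values = flat_schedule[1::2]  # odd indices hold values
    return dict(zip(epochs, values))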
def main():
    global args, state, log
    args = parse_args()
    print("Successfully parsed the args")
    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()
    print("Model has been initialized")
    if args.all_reduce:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = GossipDataParallel(model,
                                   graph=args.graph,
                                   mixing=args.mixing,
                                   comm_device=args.comm_device,
                                   push_sum=args.push_sum,
                                   overlap=args.overlap,
                                   synch_freq=args.synch_freq,
                                   verbose=args.verbose,
                                   use_streams=not args.no_cuda_streams)

    core_criterion = nn.CrossEntropyLoss()
    # alternative: nn.KLDivLoss(reduction='batchmean').cuda()
    log_softmax = nn.LogSoftmax(dim=1)

    def criterion(input, kl_target):
        assert kl_target.dtype != torch.int64
        loss = core_criterion(log_softmax(input), kl_target)
        return loss

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not os.path.exists(args.out_fname):
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Accuracy,avg:Accuracy,val'.format(
                      ws=args.world_size,
                      nw=args.num_dataloader_workers,
                      bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    best_val_prec1 = 0
    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        if not args.all_reduce:
            model.block()
        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time, args.num_itr_ignore)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)

            if prec1 > best_val_prec1:
                update_state(state, {'is_best': True})
                best_val_prec1 = prec1

            epoch_id = epoch if not args.overwrite_checkpoints else None
            cmanager.save_checkpoint(
                epoch_id, requeue_on_signal=(epoch != args.num_epochs - 1))

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        acc = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(acc))

    log.info('elapsed_time {0}'.format(elapsed_time))
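
# The inline criterion defined in main() above rejects hard int64 class
# labels, so it expects soft probability-distribution targets (consistent
# with the commented-out KLDivLoss alternative). A minimal sketch of building
# such a target from hard labels; the helper name and the smoothing value are
# illustrative, not part of this codebase:
def soft_targets(labels, num_classes, smoothing=0.1):
    """One-hot encode hard int64 labels and apply label smoothing."""
    target = torch.full((labels.size(0), num_classes),
                        smoothing / (num_classes - 1))
    target.scatter_(1, labels.unsqueeze(1), 1.0 - smoothing)
    return target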
def parse_args():
    """
    Set env-vars and global args
        rank <-- $SLURM_PROCID
        world_size <-- $SLURM_NTASKS
        master address <-- $SLURM_NODENAME of rank 0 process (or HOSTNAME)
        master port <-- any free port (doesn't really matter)
    """
    args = parser.parse_args()
    ClusterManager.set_checkpoint_dir(args.checkpoint_dir)

    args.master_addr = os.environ['HOSTNAME']
    if args.backend == 'mpi':
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_UNIVERSE_SIZE'])
        args.device_id = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    else:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        args.device_id = int(os.environ['SLURM_LOCALID'])
    args.out_fname = ClusterManager.CHECKPOINT_DIR \
        + args.tag \
        + 'out_r' + str(args.rank) \
        + '_n' + str(args.world_size) \
        + '.csv'

    args.resume = True if args.resume == 'True' else False
    args.verbose = True if args.verbose == 'True' else False
    args.train_fast = True if args.train_fast == 'True' else False
    args.nesterov = True if args.nesterov == 'True' else False
    args.checkpoint_all = True if args.checkpoint_all == 'True' else False
    args.warmup = True if args.warmup == 'True' else False
    args.cpu_comm = True if args.backend == 'gloo' else False
    args.comm_device = torch.device('cpu') if args.cpu_comm \
        else torch.device('cuda')
    args.overlap = True if args.overlap == 'True' else False
    args.push_sum = True if args.push_sum == 'True' else False
    args.all_reduce = True if args.all_reduce == 'True' else False
    args.bilat = True if args.bilat == 'True' else False

    args.global_epoch = None
    args.global_itr = None
    if args.rank == 0 and os.path.isfile(args.shared_fpath):
        os.remove(args.shared_fpath)
    # spin until the stale shared file has been removed
    while os.path.isfile(args.shared_fpath):
        pass

    # parse lr schedule as (epoch, lr-decay) pairs
    args.lr_schedule = {}
    if args.schedule is None:
        args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    i, epoch = 0, None
    for v in args.schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.lr_schedule[epoch] = v
        i = (i + 1) % 2
    del args.schedule

    # parse peers-per-itr schedule as (epoch, num_peers) pairs
    args.ppi_schedule = {}
    if args.peers_per_itr_schedule is None:
        args.peers_per_itr_schedule = [0, 1]
    i, epoch = 0, None
    for v in args.peers_per_itr_schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.ppi_schedule[epoch] = v
        i = (i + 1) % 2
    del args.peers_per_itr_schedule
    # must specify how many peers to communicate with from the start of
    # training
    assert 0 in args.ppi_schedule

    if args.backend == 'gloo':
        assert args.network_interface_type == 'ethernet'
        os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type)
    elif args.network_interface_type == 'ethernet':
        if args.backend == 'nccl':
            os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
                network_interface_type=args.network_interface_type)
            os.environ['NCCL_IB_DISABLE'] = '1'
        else:
            raise NotImplementedError

    # initialize torch distributed backend
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = str(int(args.master_port) + 1)
    dist.init_process_group(backend=args.backend,
                            world_size=args.world_size,
                            rank=args.rank)

    args.graph_class = GRAPH_TOPOLOGIES[args.graph_type]
    args.mixing_class = MIXING_STRATEGIES[args.mixing_strategy]
    if args.graph_class is None:
        raise Exception('Incorrect arguments for graph_type')
    if args.mixing_class is None:
        raise Exception('Incorrect arguments for mixing_strategy')

    return args
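
# The long run of `x = True if x == 'True' else False` assignments above
# coerces flags that arrive as the literal strings 'True'/'False' into real
# booleans. A minimal equivalent sketch (the helper name str2bool is
# illustrative, not part of this codebase):
def str2bool(flag):
    """Map the string 'True' to True and any other value to False."""
    return flag == 'True'

# usage: args.resume = str2bool(args.resume)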
def main():
    global args, state, log
    args = parse_args()
    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()
    assert args.bilat and not args.all_reduce
    model = BilatGossipDataParallel(
        model,
        master_addr=args.master_addr,
        master_port=args.master_port,
        backend=args.backend,
        world_size=args.world_size,
        rank=args.rank,
        graph_class=args.graph_class,
        mixing_class=args.mixing_class,
        comm_device=args.comm_device,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        verbose=args.verbose,
        num_peers=args.ppi_schedule[0],
        network_interface_type=args.network_interface_type)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                  .format(ws=args.world_size,
                          nw=args.num_dataloader_workers,
                          bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    # start all agents' training loop at same time
    model.block()

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    epoch = start_epoch
    stopping_criterion = epoch >= args.num_epochs
    while not stopping_criterion:

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)
            cmanager.save_checkpoint()
            # synchronize models at the end of the validation run
            model.block()

        epoch += 1
        stopping_criterion = args.global_epoch >= args.num_epochs

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    log.info('elapsed_time {0}'.format(elapsed_time))
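
# Each training loop above calls sampler.set_epoch(...) before iterating so
# that the DistributedSampler reshuffles deterministically and every rank
# draws a disjoint shard each epoch. A minimal standalone sketch of that
# mechanism (assumes dist.init_process_group has already run; the dataset is
# a placeholder):
#
#   from torch.utils.data import DataLoader, TensorDataset
#   from torch.utils.data.distributed import DistributedSampler
#
#   dataset = TensorDataset(torch.arange(100))
#   sampler = DistributedSampler(dataset)  # rank/replicas from process group
#   loader = DataLoader(dataset, batch_size=8, sampler=sampler)
#   for epoch in range(3):
#       sampler.set_epoch(epoch)  # new deterministic shuffle each epoch
#       for (batch,) in loader:
#           pass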
class SGPRunner(object):
    def __init__(self,
                 model_creator,
                 data_creator,
                 optimizer_creator,
                 config=None):
        """Initializes the runner.

        Args:
            model_creator (dict -> torch.nn.Module): see pytorch_trainer.py.
            data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py.
            optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
                see pytorch_trainer.py.
            config (dict): see pytorch_trainer.py.
            batch_size (int): see pytorch_trainer.py.
        """
        self.model_creator = model_creator
        self.data_creator = data_creator
        self.optimizer_creator = optimizer_creator
        self.config = {} if config is None else config
        self.verbose = True
        self.epoch = 0
        self._timers = {
            k: utils.TimerStat(window_size=1)
            for k in [
                "setup_proc", "setup_model", "get_state", "set_state",
                "validation", "training"
            ]
        }

    def setup(self, url, world_rank, world_size):
        """Connects to the distributed PyTorch backend and initializes the
        model.

        Args:
            url (str): the URL used to connect to distributed PyTorch.
            world_rank (int): the index of the runner.
            world_size (int): the total number of runners.
        """
        # print('_setup_distributed_pytorch')
        # checking current dir
        # from subprocess import Popen, PIPE
        # process = Popen(['ls', './'], stdout=PIPE, stderr=PIPE)
        # stdout, stderr = process.communicate()
        # print(stdout)
        self._update_config(world_rank, world_size)
        self._setup_distributed_pytorch(url, world_rank, world_size)
        # print('_setup_gossip_related')
        self._setup_gossip_related()
        # print('_setup_training')
        self._setup_training()
        # print('_setup_misc')
        self._setup_misc()
        # print('setup done')

    def _update_config(self, world_rank, world_size):
        self.config['rank'] = world_rank
        self.config['world_size'] = world_size
        self.config['out_fname'] = \
            '/home/ubuntu/stochastic_gradient_push/ckpt/{tag}out_r{rank}_n{wsize}.csv'.format(
                tag=self.config['tag'],
                rank=self.config['rank'],
                wsize=self.config['world_size'])

    def _setup_distributed_pytorch(self, url, world_rank, world_size):
        # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        # the distributed pytorch runner has this, but it is unclear why
        with self._timers["setup_proc"]:
            self.world_rank = world_rank
            logger.debug(
                "Connecting to {} world_rank: {} world_size: {}".format(
                    url, world_rank, world_size))
            logger.debug("using {}".format(self.config['backend']))
            print("Connecting to {} world_rank: {} world_size: {}".format(
                url, world_rank, world_size))
            dist.init_process_group(backend=self.config['backend'],
                                    init_method=url,
                                    rank=world_rank,
                                    world_size=world_size)

    def _setup_gossip_related(self):
        config = self.config
        self.comm_device = torch.device('cpu') if config['cpu_comm'] \
            else torch.device('cuda')
        self.graph, self.mixing = None, None
        graph_class = GRAPH_TOPOLOGIES[config['graph_type']]
        if graph_class:
            # dist.barrier is called here to ensure the NCCL communicator is
            # created now.
            # This prevents an error that can occur if the NCCL communicator
            # is created more than 5 minutes apart in different processes.
            dist.barrier()
            self.graph = graph_class(config['rank'],
                                     config['world_size'],
                                     peers_per_itr=config['ppi_schedule'][0])

        mixing_class = MIXING_STRATEGIES[config['mixing_strategy']]
        if mixing_class and self.graph:
            self.mixing = mixing_class(self.graph, self.comm_device)

    def _setup_training(self):
        config = self.config
        # note: assume a gpu is available
        logger.debug("Creating model")
        # print('model')
        self.model = self.model_creator(self.config)
        if config['all_reduce']:
            # print('DistributedDataParallel')
            self.model = torch.nn.parallel.DistributedDataParallel(self.model)
        else:
            # print('GossipDataParallel')
            self.model = GossipDataParallel(
                self.model,
                graph=self.graph,
                mixing=self.mixing,
                comm_device=self.comm_device,
                push_sum=config['push_sum'],
                overlap=config['overlap'],
                synch_freq=config['synch_freq'],
                verbose=config['verbose'],
                use_streams=not config['no_cuda_streams'])

        logger.debug("Creating optimizer")
        # print('Optimizer')
        self.criterion, self.optimizer = self.optimizer_creator(
            self.model, self.config)
        # if torch.cuda.is_available():
        #     self.criterion = self.criterion.cuda()
        # the pytorch runner has this, but it does not work here

        logger.debug("Creating dataset")
        # print('DataSet')
        self.training_set, self.validation_set = self.data_creator(
            self.config)

        # TODO: make num_workers configurable
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset=self.training_set,
            num_replicas=config['world_size'],
            rank=config['rank'])
        self.train_loader = torch.utils.data.DataLoader(
            self.training_set,
            batch_size=config['batch_size'],
            shuffle=(self.train_sampler is None),
            num_workers=config['num_dataloader_workers'],
            pin_memory=True,
            sampler=self.train_sampler)

        self.validation_sampler = \
            torch.utils.data.distributed.DistributedSampler(
                dataset=self.validation_set,
                num_replicas=config['world_size'],
                rank=config['rank'])
        # pytorch runner version:
        # self.validation_loader = torch.utils.data.DataLoader(
        #     self.validation_set,
        #     batch_size=config['batch_size'],
        #     shuffle=(self.validation_sampler is None),
        #     num_workers=config['num_dataloader_workers'],
        #     pin_memory=True,
        #     sampler=self.validation_sampler)
        # sgp code version:
        self.validation_loader = torch.utils.data.DataLoader(
            self.validation_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_dataloader_workers'],
            pin_memory=True)

        self.optimizer.zero_grad()

    def _setup_misc(self):
        # misc setup components that were in gossip_sgd
        config = self.config
        state = {}
        update_state(
            state, {
                'epoch': 0,
                'itr': 0,
                'best_prec1': 0,
                'is_best': True,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'elapsed_time': 0,
                'batch_meter': Meter(ptag='Time').__dict__,
                'data_meter': Meter(ptag='Data').__dict__,
                'nn_meter': Meter(ptag='Forward/Backward').__dict__
            })
        self.state = state

        # module used to relaunch jobs and handle external termination signals
        ClusterManager.set_checkpoint_dir(config['checkpoint_dir'])
        self.cmanager = ClusterManager(rank=config['rank'],
                                       world_size=config['world_size'],
                                       model_tag=config['tag'],
                                       state=state,
                                       all_workers=config['checkpoint_all'])

        # enable low-level optimization of compute graph using cuDNN library?
        cudnn.benchmark = True

        self.batch_meter = Meter(state['batch_meter'])
        self.data_meter = Meter(state['data_meter'])
        self.nn_meter = Meter(state['nn_meter'])

        # initialize log file
        if not os.path.exists(config['out_fname']):
            with open(config['out_fname'], 'w') as f:
                print('BEGIN-TRAINING\n'
                      'World-Size,{ws}\n'
                      'Num-DLWorkers,{nw}\n'
                      'Batch-Size,{bs}\n'
                      'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                      'NT(s),avg:NT(s),std:NT(s),'
                      'DT(s),avg:DT(s),std:DT(s),'
                      'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                      .format(ws=config['world_size'],
                              nw=config['num_dataloader_workers'],
                              bs=config['batch_size']),
                      file=f)

        self.start_itr = state['itr']
        self.start_epoch = state['epoch']
        self.elapsed_time = state['elapsed_time']
        self.begin_time = time.time() - state['elapsed_time']
        self.best_val_prec1 = 0

    def resume(self):
        # TODO
        # need optimizer.zero_grad()
        pass

    def step(self):
        """Runs a training epoch and updates the model parameters."""
        config = self.config
        state = self.state
        # TODO: epoch setting?
        logger.debug("Starting Epoch {}".format(self.epoch))
        self.train_sampler.set_epoch(self.epoch + self.config['seed'] * 90)

        if not config['all_reduce']:
            # update the model's peers_per_itr attribute
            sgp_utils.update_peers_per_itr(self.config, self.model,
                                           self.epoch)

        # start all agents' training loop at same time
        if not config['all_reduce']:
            self.model.block()
        losses_avg, top1_avg, top5_avg = sgp_utils.train(
            self.config, self.model, self.criterion, self.optimizer,
            self.batch_meter, self.data_meter, self.nn_meter,
            self.train_loader, self.epoch, self.start_itr, self.begin_time,
            self.config['num_itr_ignore'], logger)
        train_stats = {
            "epoch": self.epoch,
            "train_loss": losses_avg,
            "train_top1": top1_avg,
            "train_top5": top5_avg,
        }

        start_itr = 0
        if not config['train_fast']:
            # update state after each epoch
            elapsed_time = time.time() - self.begin_time
            update_state(
                state, {
                    'epoch': self.epoch + 1,
                    'itr': self.start_itr,
                    'is_best': False,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': self.batch_meter.__dict__,
                    'data_meter': self.data_meter.__dict__,
                    'nn_meter': self.nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            loss1, prec1, prec5 = sgp_utils.validate(self.validation_loader,
                                                     self.model,
                                                     self.criterion, logger)
            with open(config['out_fname'], '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=self.epoch,
                                     itr=-1,
                                     bt=self.batch_meter,
                                     dt=self.data_meter,
                                     nt=self.nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)

            if prec1 > self.best_val_prec1:
                update_state(state, {'is_best': True})
                self.best_val_prec1 = prec1

            epoch_id = self.epoch if not config['overwrite_checkpoints'] \
                else None
            self.cmanager.save_checkpoint(
                epoch_id,
                requeue_on_signal=(self.epoch != config['num_epochs'] - 1))
            print('Finished Epoch {ep}, elapsed {tt:.3f}sec'.format(
                ep=self.epoch, tt=elapsed_time))
        else:
            elapsed_time = time.time() - self.begin_time
            print('Finished Epoch {ep}, elapsed {tt:.3f}sec'.format(
                ep=self.epoch, tt=elapsed_time))

        self.epoch += 1
        return train_stats

    # def validate(self):
    #     """Evaluates the model on the validation data set."""
    #     with self._timers["validation"]:
    #         validation_stats = pytorch_utils.validate(
    #             self.validation_loader, self.model, self.criterion)
    #         validation_stats.update(self.stats())
    #         return validation_stats

    # def stats(self):
    #     """Returns a dictionary of statistics collected."""
    #     stats = {"epoch": self.epoch}
    #     for k, t in self._timers.items():
"_time_mean"] = t.mean # stats[k + "_time_total"] = t.sum # t.reset() # return stats # def get_state(self): # """Returns the state of the runner.""" # return { # "epoch": self.epoch, # "model": self.model.state_dict(), # "optimizer": self.optimizer.state_dict(), # "stats": self.stats() # } # def set_state(self, state): # """Sets the state of the model.""" # # TODO: restore timer stats # self.model.load_state_dict(state["model"]) # self.optimizer.load_state_dict(state["optimizer"]) # self.epoch = state["stats"]["epoch"] def shutdown(self): """Attempts to shut down the worker.""" del self.validation_loader del self.validation_set del self.train_loader del self.training_set del self.criterion del self.optimizer del self.model if torch.cuda.is_available(): torch.cuda.empty_cache() dist.destroy_process_group() def get_node_ip(self): """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() def find_free_port(self): """Finds a free port on the current node.""" return utils.find_free_port()
def parse_args():
    """
    Set env-vars and global args
        rank <-- $SLURM_PROCID
        world_size <-- $SLURM_NTASKS
        master address <-- $SLURM_NODENAME of rank 0 process (or HOSTNAME)
        master port <-- any free port (doesn't really matter)
    """
    args = parser.parse_args()
    ClusterManager.set_user_name(args.user_name)

    args.distributed = True if args.distributed == 'True' else False
    if not args.distributed:
        args.rank = 0
        args.world_size = 1
        args.device_id = 0
        args.master_addr = 'localhost'
    else:
        args.master_addr = os.environ['HOSTNAME']
        if args.backend == 'mpi':
            args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            args.world_size = int(os.environ['OMPI_UNIVERSE_SIZE'])
            args.device_id = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        else:
            args.rank = int(os.environ['SLURM_PROCID'])
            args.world_size = int(os.environ['SLURM_NTASKS'])
            args.device_id = int(os.environ['SLURM_LOCALID'])
    args.out_fname = ClusterManager.CHECKPOINT_DIR \
        + args.tag \
        + 'out_r' + str(args.rank) \
        + '_n' + str(args.world_size) \
        + '.csv'

    args.resume = True if args.resume == 'True' else False
    args.verbose = True if args.verbose == 'True' else False
    args.train_fast = True if args.train_fast == 'True' else False
    args.nesterov = True if args.nesterov == 'True' else False
    args.checkpoint_all = True if args.checkpoint_all == 'True' else False
    args.warmup = True if args.warmup == 'True' else False
    args.cpu_comm = True if args.backend == 'tcp' else False
    args.comm_device = torch.device('cpu') if args.cpu_comm \
        else torch.device('cuda')
    args.overlap = True if args.overlap == 'True' else False
    args.single_threaded = True if args.single_threaded == 'True' else False
    args.data_preloaded = True if args.data_preloaded == 'True' else False
    args.push_sum = True if args.push_sum == 'True' else False
    args.all_reduce = True if args.all_reduce == 'True' else False
    args.overwrite_checkpoints = True \
        if args.overwrite_checkpoints == 'True' else False

    # parse lr schedule as (epoch, lr-decay) pairs
    args.lr_schedule = {}
    if args.schedule is None:
        args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    i, epoch = 0, None
    for v in args.schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.lr_schedule[epoch] = v
        i = (i + 1) % 2
    del args.schedule

    # parse peers-per-itr schedule as (epoch, num_peers) pairs
    args.ppi_schedule = {}
    if args.peers_per_itr_schedule is None:
        args.peers_per_itr_schedule = [0, 1]
    i, epoch = 0, None
    for v in args.peers_per_itr_schedule:
        if i == 0:
            epoch = v
        elif i == 1:
            args.ppi_schedule[epoch] = v
        i = (i + 1) % 2
    del args.peers_per_itr_schedule
    # must specify how many peers to communicate with from the start of
    # training
    assert 0 in args.ppi_schedule

    if args.distributed:
        try:
            args.graph = GRAPH_TOPOLOGIES[args.graph_type](
                args.rank, args.world_size,
                peers_per_itr=args.ppi_schedule[0])
        except Exception:
            args.graph = None
        try:
            args.mixing = MIXING_STRATEGIES[args.mixing_strategy](args.graph)
        except Exception:
            args.mixing = None
    else:
        args.graph, args.mixing = None, None

    return args
def main():
    global args, state, log
    args = parse_args()
    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.distributed:
        # initialize torch distributed backend
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        dist.init_process_group(backend=args.backend,
                                world_size=args.world_size,
                                rank=args.rank)

    # init model, loss, and optimizer
    model = init_model()
    if args.all_reduce:
        model = AllReduceDataParallel(model,
                                      distributed=args.distributed,
                                      comm_device=args.comm_device,
                                      verbose=args.verbose)
    elif args.single_threaded:
        model = SimpleGossipDataParallel(model,
                                         distributed=args.distributed,
                                         graph=args.graph,
                                         comm_device=args.comm_device,
                                         push_sum=args.push_sum,
                                         verbose=args.verbose)
    else:
        model = GossipDataParallel(model,
                                   distributed=args.distributed,
                                   graph=args.graph,
                                   mixing=args.mixing,
                                   comm_device=args.comm_device,
                                   push_sum=args.push_sum,
                                   overlap=args.overlap,
                                   synch_freq=args.synch_freq,
                                   verbose=args.verbose)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              bs_fname=args.bs_fpath,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        f_fpath = cmanager.checkpoint_fpath
        if os.path.isfile(f_fpath):
            log.info("=> loading checkpoint '{}'".format(f_fpath))
            checkpoint = torch.load(f_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                f_fpath, checkpoint['epoch'], checkpoint['itr']))
            # synch models that are loaded
            if not args.overlap or args.all_reduce:
                model.transfer_params()
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                  .format(ws=args.world_size,
                          nw=args.num_dataloader_workers,
                          bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        model.block()
        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)

        start_itr = 0
        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)
            epoch_id = epoch if not args.overwrite_checkpoints else None
            cmanager.save_checkpoint(epoch_id)

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    cmanager.halt = True
    log.info('elapsed_time {0}'.format(elapsed_time))
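
# All three entry points above funnel checkpoint bookkeeping through
# update_state(state, {...}) before handing `state` to ClusterManager.
# Judging from its call sites (keys written here are read back directly from
# `state`), a minimal compatible sketch of the helper would be an in-place
# dict merge; this is an assumption, not the repo's actual definition:
#
#   def update_state(state, update_dict):
#       """Merge update_dict into the shared training-state dict in place."""
#       for key, value in update_dict.items():
#           state[key] = value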