def validate(val_loader, model, criterion):
    """Evaluate the model on the validation set.

    Computes the criterion loss against a one-hot encoding of the integer
    targets and tracks a running accuracy meter.

    Args:
        val_loader: iterable yielding (features, target) batches; target is
            a 1-D tensor of class indices.
        model: network to evaluate (switched to eval mode here).
        criterion: loss callable taking (output, one_hot_target).

    Returns:
        Average accuracy over the validation set (``acc.avg``).
    """
    losses = Meter(ptag='Loss')
    acc = Meter(ptag='Accuracy')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (features, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)

            # compute output first so the one-hot width tracks the model's
            # class count instead of a hard-coded 1000 (ImageNet) constant
            output = model(features)

            # create one-hot vector from target
            kl_target = torch.zeros(
                target.shape[0], output.size(1),
                device='cuda').scatter_(1, target.view(-1, 1), 1)

            loss = criterion(output, kl_target)

            # measure accuracy and record loss, weighted by batch size
            acc_val = accuracy(output, target)
            losses.update(loss.item(), features.size(0))
            acc.update(acc_val, features.size(0))

    log.info(' * Accuracy {acc.avg:.3f}'.format(acc=acc))

    return acc.avg
def _setup_misc(self):
    """Misc setup components that were in gossip_sgd.

    Builds the initial checkpointable training state, wires up the
    ClusterManager used for checkpointing and requeue-on-signal, restores
    the timing meters from state, writes the CSV log header if the log file
    does not exist yet, and caches bookkeeping values (start epoch/itr,
    elapsed time) on self.
    """
    config = self.config
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'elapsed_time': 0,
            # meters are stored by their __dict__ so the state is picklable
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })
    self.state = state

    # module used to relaunch jobs and handle external termination signals
    ClusterManager.set_checkpoint_dir(config['checkpoint_dir'])
    self.cmanager = ClusterManager(rank=config['rank'],
                                   world_size=config['world_size'],
                                   model_tag=config['tag'],
                                   state=state,
                                   all_workers=config['checkpoint_all'])

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats, rebuilt from their dict form
    self.batch_meter = Meter(state['batch_meter'])
    self.data_meter = Meter(state['data_meter'])
    self.nn_meter = Meter(state['nn_meter'])

    # initialize log file: write the CSV header only on first creation
    if not os.path.exists(config['out_fname']):
        with open(config['out_fname'], 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'.
                  format(ws=config['world_size'],
                         nw=config['num_dataloader_workers'],
                         bs=config['batch_size']),
                  file=f)

    # begin_time is shifted back so time.time() - begin_time always equals
    # total elapsed training time, even after resuming from a checkpoint
    self.start_itr = state['itr']
    self.start_epoch = state['epoch']
    self.elapsed_time = state['elapsed_time']
    self.begin_time = time.time() - state['elapsed_time']
    self.best_val_prec1 = 0
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time, num_itr_ignore):
    """Run one training epoch, resuming mid-epoch from iteration ``itr``.

    Appends per-iteration timing/loss/accuracy rows to args.out_fname every
    args.print_freq iterations and once more after the loop. The first
    ``num_itr_ignore`` iterations are excluded from the timing meters
    (warm-up). Uses module-level globals ``args`` and ``log``.
    """
    losses = Meter(ptag='Loss')
    acc = Meter(ptag="Accuracy")
    # top1 = Meter(ptag='Prec@1')
    # top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    # NOTE(review): relies on the loader-iterator exposing `sample_iter` —
    # an implementation detail of the (old) DataLoader iterator; verify
    # against the installed torch version.
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    batch_time = time.time()
    # enumerate starts at `itr` so logged iteration indices stay global
    for i, (batch, target) in enumerate(_train_loader, start=itr):

        target = target.cuda(non_blocking=True)
        # create one-hot vector from target
        # kl_target = torch.zeros(target.shape[0], 1000, device='cuda').scatter_(
        #     1, target.view(-1, 1), 1)

        if num_itr_ignore == 0:
            data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        # learning rate is only adjusted every 100 iterations
        if i % 100 == 0:
            update_learning_rate(optimizer, epoch, itr=i,
                                 itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        # without overlap, parameters must be explicitly pushed to peers
        if not args.overlap and not args.all_reduce:
            log.debug('Transferring params')
            model.transfer_params()
        if num_itr_ignore == 0:
            nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        if num_itr_ignore == 0:
            batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        acc_val = accuracy(output, target)
        losses.update(loss.item(), batch.size(0))
        acc.update(acc_val, batch.size(0))
        # top1.update(prec1.item(), batch.size(0))
        # top5.update(prec5.item(), batch.size(0))
        if i % args.print_freq == 0:
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{acc.val:.3f},{acc.avg:.3f},'
                      '-1'.format(ep=epoch,
                                  itr=i,
                                  bt=batch_meter,
                                  dt=data_meter,
                                  nt=nn_meter,
                                  loss=losses,
                                  acc=acc),
                      file=f)
        # count down the warm-up iterations excluded from timing stats
        if num_itr_ignore > 0:
            num_itr_ignore -= 1
        log_time = time.time() - log_time
        log.debug(log_time)

        # optional early cut-off of the epoch (-1 disables it)
        if (args.num_iterations_per_training_epoch != -1
                and i + 1 == args.num_iterations_per_training_epoch):
            break

    # final summary row for the epoch
    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{acc.val:.3f},{acc.avg:.3f},'
              '-1'.format(ep=epoch,
                          itr=i,
                          bt=batch_meter,
                          dt=data_meter,
                          nt=nn_meter,
                          loss=losses,
                          acc=acc),
              file=f)
def main():
    """Entry point: set up distributed training and run the epoch loop.

    Parses args, seeds RNGs deterministically, wraps the model in either
    DistributedDataParallel (all-reduce) or GossipDataParallel, optionally
    resumes from a ClusterManager checkpoint, then trains for
    args.num_epochs epochs with per-epoch validation and checkpointing
    (unless args.train_fast, in which case validation happens once at the
    end).
    """
    global args, state, log

    args = parse_args()
    print("Successfully parsed the args")

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()
    print("Model has been initialized")

    if args.all_reduce:
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        model = GossipDataParallel(model,
                                   graph=args.graph,
                                   mixing=args.mixing,
                                   comm_device=args.comm_device,
                                   push_sum=args.push_sum,
                                   overlap=args.overlap,
                                   synch_freq=args.synch_freq,
                                   verbose=args.verbose,
                                   use_streams=not args.no_cuda_streams)

    core_criterion = nn.CrossEntropyLoss(
    )  #nn.KLDivLoss(reduction='batchmean').cuda()
    log_softmax = nn.LogSoftmax(dim=1)

    def criterion(input, kl_target):
        # guard: expects soft/one-hot targets, not integer class indices
        assert kl_target.dtype != torch.int64
        loss = core_criterion(log_softmax(input), kl_target)
        return loss

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file (CSV header written only once)
    if not os.path.exists(args.out_fname):
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Accuracy,avg:Accuracy,val'.format(
                      ws=args.world_size,
                      nw=args.num_dataloader_workers,
                      bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    # shift begin_time so (time.time() - begin_time) == total elapsed time
    begin_time = time.time() - state['elapsed_time']
    best_val_prec1 = 0

    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        if not args.all_reduce:
            model.block()

        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time, args.num_itr_ignore)
        # mid-epoch resume only applies to the first epoch after restart
        start_itr = 0

        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)
            if prec1 > best_val_prec1:
                update_state(state, {'is_best': True})
                best_val_prec1 = prec1

            epoch_id = epoch if not args.overwrite_checkpoints else None
            # don't requeue after the final epoch's checkpoint
            cmanager.save_checkpoint(
                epoch_id, requeue_on_signal=(epoch != args.num_epochs - 1))

    if args.train_fast:
        # validation was skipped during training; run it once at the end
        val_loader = make_dataloader(args, train=False)
        acc = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(acc))

    log.info('elapsed_time {0}'.format(elapsed_time))
def validate(val_loader, model, criterion):
    """Run one pass over the validation set; report top-1/top-5 precision.

    Puts the model in eval mode, pauses gossip-based mixing, and returns
    the average top-1 precision.
    """
    loss_meter = Meter(ptag='Loss')
    prec1_meter = Meter(ptag='Prec@1')
    prec5_meter = Meter(ptag='Prec@5')

    # evaluation mode; gossip communication is suspended during validation
    model.eval()
    model.disable_gossip()

    with torch.no_grad():
        for features, target in val_loader:
            target = target.cuda(non_blocking=True)

            # forward pass and loss
            output = model(features)
            loss = criterion(output, target)

            # accumulate batch-weighted statistics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            n = features.size(0)
            loss_meter.update(loss.item(), n)
            prec1_meter.update(prec1.item(), n)
            prec5_meter.update(prec5.item(), n)

    log.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
        top1=prec1_meter, top5=prec5_meter))

    return prec1_meter.avg
def train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
          loader, epoch, itr, begin_time):
    """Run one training epoch with bilateral-gossip bookkeeping.

    Resumes mid-epoch from iteration ``itr``, updates the bilateral gossip
    learning rate / global iteration counter periodically, and appends
    timing/loss/precision rows to args.out_fname. Uses module-level
    globals ``args`` and ``log``.
    """
    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    # NOTE(review): relies on the loader-iterator's `sample_iter` attribute —
    # verify against the installed torch version.
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))
    model.enable_gossip()

    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):

        target = target.cuda(non_blocking=True)
        data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, target)

        # bilateral updates happen once at the start of the epoch and then
        # every bilat_freq iterations; the rank offset staggers the agents
        bilat_freq = 100
        if i == 0:
            update_global_iteration_counter(itr=1, itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))
        elif (i + args.rank) % (bilat_freq) == 0:
            update_global_iteration_counter(itr=bilat_freq,
                                            itr_per_epoch=len(loader))
            update_bilat_learning_rate(model, itr_per_epoch=len(loader))

        loss.backward()
        update_learning_rate(optimizer, epoch, itr=i,
                             itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
        if i % args.print_freq == 0:
            # NOTE(review): `itr` (the resume parameter) is rebound here for
            # logging purposes only — it is not used again after this point
            ep = args.global_epoch
            itr = args.global_itr % (len(loader) * args.world_size)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'.format(
                          ep=ep,
                          itr=itr,
                          bt=batch_meter,
                          dt=data_meter,
                          nt=nn_meter,
                          loss=losses,
                          top1=top1,
                          top5=top5),
                      file=f)
        log_time = time.time() - log_time
        log.debug(log_time)

    # final summary row for the epoch
    with open(args.out_fname, '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'.format(ep=epoch,
                                                        itr=i,
                                                        bt=batch_meter,
                                                        dt=data_meter,
                                                        nt=nn_meter,
                                                        loss=losses,
                                                        top1=top1,
                                                        top5=top5),
              file=f)
def main():
    """Entry point for bilateral-gossip training.

    Parses args, seeds RNGs deterministically, wraps the model in
    BilatGossipDataParallel (the assertion requires --bilat and forbids
    --all_reduce), optionally resumes from checkpoint, then trains in a
    while-loop whose stopping condition is driven by the globally-tracked
    epoch counter (args.global_epoch) rather than the local loop variable.
    """
    global args, state, log

    args = parse_args()

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # init model, loss, and optimizer
    model = init_model()

    # this script only supports the bilateral-gossip configuration
    assert args.bilat and not args.all_reduce
    model = BilatGossipDataParallel(
        model,
        master_addr=args.master_addr,
        master_port=args.master_port,
        backend=args.backend,
        world_size=args.world_size,
        rank=args.rank,
        graph_class=args.graph_class,
        mixing_class=args.mixing_class,
        comm_device=args.comm_device,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        verbose=args.verbose,
        num_peers=args.ppi_schedule[0],
        network_interface_type=args.network_interface_type)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(
        state, {
            'epoch': 0,
            'itr': 0,
            'best_prec1': 0,
            'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
        })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'".format(
                cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(
                state, {
                    'epoch': checkpoint['epoch'],
                    'itr': checkpoint['itr'],
                    'best_prec1': checkpoint['best_prec1'],
                    'is_best': False,
                    'state_dict': checkpoint['state_dict'],
                    'optimizer': checkpoint['optimizer'],
                    'elapsed_time': checkpoint['elapsed_time'],
                    'batch_meter': checkpoint['batch_meter'],
                    'data_meter': checkpoint['data_meter'],
                    'nn_meter': checkpoint['nn_meter']
                })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})".format(
                cmanager.checkpoint_fpath, checkpoint['epoch'],
                checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'".format(
                cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file (fresh runs only; resumed runs keep appending)
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print(
                'BEGIN-TRAINING\n'
                'World-Size,{ws}\n'
                'Num-DLWorkers,{nw}\n'
                'Batch-Size,{bs}\n'
                'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                'NT(s),avg:NT(s),std:NT(s),'
                'DT(s),avg:DT(s),std:DT(s),'
                'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'.format(
                    ws=args.world_size,
                    nw=args.num_dataloader_workers,
                    bs=args.batch_size),
                file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    # start all agents' training loop at same time
    model.block()

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    # shift begin_time so (time.time() - begin_time) == total elapsed time
    begin_time = time.time() - state['elapsed_time']

    epoch = start_epoch
    stopping_criterion = epoch >= args.num_epochs
    while not stopping_criterion:

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)
        # mid-epoch resume only applies to the first epoch after restart
        start_itr = 0

        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(
                state, {
                    'epoch': epoch + 1,
                    'itr': start_itr,
                    'is_best': False,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'elapsed_time': elapsed_time,
                    'batch_meter': batch_meter.__dict__,
                    'data_meter': data_meter.__dict__,
                    'nn_meter': nn_meter.__dict__
                })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'.format(ep=epoch,
                                     itr=-1,
                                     bt=batch_meter,
                                     dt=data_meter,
                                     nt=nn_meter,
                                     filler=-1,
                                     val=prec1),
                      file=f)
            cmanager.save_checkpoint()

            # synchronize models at the end of validation run
            model.block()

        epoch += 1
        # stopping is decided by the globally-synchronized epoch counter,
        # not the local `epoch` variable
        stopping_criterion = args.global_epoch >= args.num_epochs

    if args.train_fast:
        # validation was skipped during training; run it once at the end
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    log.info('elapsed_time {0}'.format(elapsed_time))
def validate(val_loader, model, criterion):
    """Evaluate the model on the validation set.

    Computes the criterion loss against a one-hot encoding of the integer
    targets and tracks running loss and top-1/top-5 precision meters.

    Args:
        val_loader: iterable yielding (features, target) batches; target is
            a 1-D tensor of class indices.
        model: network to evaluate (switched to eval mode here).
        criterion: loss callable taking (output, one_hot_target).

    Returns:
        Average top-1 precision over the validation set (``top1.avg``).
    """
    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (features, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)

            # compute output first so the one-hot width tracks the model's
            # class count instead of a hard-coded 1000 (ImageNet) constant
            output = model(features)

            # create one-hot vector from target
            kl_target = torch.zeros(
                target.shape[0], output.size(1),
                device='cuda').scatter_(1, target.view(-1, 1), 1)

            loss = criterion(output, kl_target)

            # measure accuracy and record loss, weighted by batch size
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), features.size(0))
            top1.update(prec1.item(), features.size(0))
            top5.update(prec5.item(), features.size(0))

    log.info(
        ' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {losses.avg:.3f}'
        .format(top1=top1, top5=top5, losses=losses))

    return top1.avg
def main():
    """Entry point: set up the chosen parallelism backend and train.

    Parses args, seeds RNGs deterministically, optionally initializes the
    torch.distributed process group, wraps the model in one of
    AllReduceDataParallel / SimpleGossipDataParallel / GossipDataParallel,
    optionally resumes from checkpoint, then trains for args.num_epochs
    epochs with per-epoch validation and checkpointing (unless
    args.train_fast, in which case validation happens once at the end).
    """
    global args, state, log

    args = parse_args()

    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if args.distributed:
        # initialize torch distributed backend
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port
        dist.init_process_group(backend=args.backend,
                                world_size=args.world_size,
                                rank=args.rank)

    # init model, loss, and optimizer
    model = init_model()

    if args.all_reduce:
        model = AllReduceDataParallel(model,
                                      distributed=args.distributed,
                                      comm_device=args.comm_device,
                                      verbose=args.verbose)
    else:
        if args.single_threaded:
            model = SimpleGossipDataParallel(model,
                                             distributed=args.distributed,
                                             graph=args.graph,
                                             comm_device=args.comm_device,
                                             push_sum=args.push_sum,
                                             verbose=args.verbose)
        else:
            model = GossipDataParallel(model,
                                       distributed=args.distributed,
                                       graph=args.graph,
                                       mixing=args.mixing,
                                       comm_device=args.comm_device,
                                       push_sum=args.push_sum,
                                       overlap=args.overlap,
                                       synch_freq=args.synch_freq,
                                       verbose=args.verbose)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(state, {
        'epoch': 0,
        'itr': 0,
        'best_prec1': 0,
        'is_best': True,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'elapsed_time': 0,
        'batch_meter': Meter(ptag='Time').__dict__,
        'data_meter': Meter(ptag='Data').__dict__,
        'nn_meter': Meter(ptag='Forward/Backward').__dict__
    })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              bs_fname=args.bs_fpath,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        f_fpath = cmanager.checkpoint_fpath
        if os.path.isfile(f_fpath):
            log.info("=> loading checkpoint '{}'"
                     .format(f_fpath))
            checkpoint = torch.load(f_fpath)
            update_state(state, {
                'epoch': checkpoint['epoch'],
                'itr': checkpoint['itr'],
                'best_prec1': checkpoint['best_prec1'],
                'is_best': False,
                'state_dict': checkpoint['state_dict'],
                'optimizer': checkpoint['optimizer'],
                'elapsed_time': checkpoint['elapsed_time'],
                'batch_meter': checkpoint['batch_meter'],
                'data_meter': checkpoint['data_meter'],
                'nn_meter': checkpoint['nn_meter']
            })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})"
                     .format(f_fpath, checkpoint['epoch'], checkpoint['itr']))
            # synch models that are loaded
            if not args.overlap or args.all_reduce:
                model.transfer_params()
        else:
            log.info("=> no checkpoint found at '{}'"
                     .format(cmanager.checkpoint_fpath))

    # enable low-level optimization of compute graph using cuDNN library?
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file (fresh runs only; resumed runs keep appending)
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                  .format(ws=args.world_size,
                          nw=args.num_dataloader_workers,
                          bs=args.batch_size),
                  file=f)

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    # shift begin_time so (time.time() - begin_time) == total elapsed time
    begin_time = time.time() - state['elapsed_time']

    for epoch in range(start_epoch, args.num_epochs):

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)

        if not args.all_reduce:
            # update the model's peers_per_itr attribute
            update_peers_per_itr(model, epoch)

        # start all agents' training loop at same time
        model.block()
        train(model, criterion, optimizer, batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)
        # mid-epoch resume only applies to the first epoch after restart
        start_itr = 0

        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(state, {
                'epoch': epoch + 1,
                'itr': start_itr,
                'is_best': False,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'elapsed_time': elapsed_time,
                'batch_meter': batch_meter.__dict__,
                'data_meter': data_meter.__dict__,
                'nn_meter': nn_meter.__dict__
            })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'
                      .format(ep=epoch, itr=-1, bt=batch_meter,
                              dt=data_meter, nt=nn_meter, filler=-1,
                              val=prec1),
                      file=f)

            epoch_id = epoch if not args.overwrite_checkpoints else None
            cmanager.save_checkpoint(epoch_id)

    if args.train_fast:
        # validation was skipped during training; run it once at the end
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    # signal the cluster manager that this job is done
    cmanager.halt = True

    log.info('elapsed_time {0}'.format(elapsed_time))
def train(config, model, criterion, optimizer, batch_meter, data_meter,
          nn_meter, loader, epoch, itr, begin_time, num_itr_ignore, log):
    """Run one training epoch (config-dict variant), resuming from ``itr``.

    Trains against a one-hot encoding of the targets, appends per-iteration
    timing/loss/precision rows to config['out_fname'], and excludes the
    first ``num_itr_ignore`` iterations from the timing meters (warm-up).

    Returns:
        (losses.avg, top1.avg, top5.avg) for the epoch, or None if the
        sampler spoofing detects the epoch already finished.
    """
    losses = Meter(ptag='Loss')
    top1 = Meter(ptag='Prec@1')
    top5 = Meter(ptag='Prec@5')

    # switch to train mode
    model.train()

    # spoof sampler to continue from checkpoint w/o loading data all over again
    # NOTE(review): relies on the loader-iterator's `sample_iter` attribute —
    # verify against the installed torch version.
    _train_loader = loader.__iter__()
    for i in range(itr):
        try:
            next(_train_loader.sample_iter)
        except Exception:
            # finished epoch but preempted before state was updated
            log.info('Loader spoof error attempt {}/{}'.format(i, len(loader)))
            return

    log.debug('Training (epoch {})'.format(epoch))

    batch_time = time.time()
    for i, (batch, target) in enumerate(_train_loader, start=itr):
        # if args.fp16:
        #     batch = batch.cuda(non_blocking=True).half()
        target = target.cuda(non_blocking=True)

        # create one-hot vector from target
        # NOTE(review): class count 1000 is hard-coded (ImageNet) — confirm
        # before reusing with other datasets
        kl_target = torch.zeros(target.shape[0], 1000,
                                device='cuda').scatter_(
                                    1, target.view(-1, 1), 1)

        if num_itr_ignore == 0:
            data_meter.update(time.time() - batch_time)

        # ----------------------------------------------------------- #
        # Forward/Backward pass
        # ----------------------------------------------------------- #
        nn_time = time.time()
        output = model(batch)
        loss = criterion(output, kl_target)
        # if args.fp16:
        #     if args.amp:
        #         with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        #             scaled_loss.backward()
        #     else:
        #         optimizer.backward(loss)
        # else:
        #     loss.backward()
        loss.backward()
        # learning rate is only adjusted every 100 iterations
        if i % 100 == 0:
            update_learning_rate(config, optimizer, epoch, log, itr=i,
                                 itr_per_epoch=len(loader))
        optimizer.step()  # optimization update
        optimizer.zero_grad()
        # without overlap, parameters must be explicitly pushed to peers
        if not config['overlap'] and not config['all_reduce']:
            log.debug('Transferring params')
            model.transfer_params()
        if num_itr_ignore == 0:
            nn_meter.update(time.time() - nn_time)
        # ----------------------------------------------------------- #

        if num_itr_ignore == 0:
            batch_meter.update(time.time() - batch_time)
        batch_time = time.time()

        log_time = time.time()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch.size(0))
        top1.update(prec1.item(), batch.size(0))
        top5.update(prec5.item(), batch.size(0))
        if i % config['print_freq'] == 0:
            with open(config['out_fname'], '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{loss.val:.4f},{loss.avg:.4f},'
                      '{top1.val:.3f},{top1.avg:.3f},'
                      '{top5.val:.3f},{top5.avg:.3f},-1'
                      .format(ep=epoch, itr=i,
                              bt=batch_meter,
                              dt=data_meter,
                              nt=nn_meter,
                              loss=losses,
                              top1=top1,
                              top5=top5), file=f)
        # count down the warm-up iterations excluded from timing stats
        if num_itr_ignore > 0:
            num_itr_ignore -= 1
        log_time = time.time() - log_time
        log.debug(log_time)

        # optional early cut-off of the epoch (-1 disables it)
        if (config['num_iterations_per_training_epoch'] != -1
                and i+1 == config['num_iterations_per_training_epoch']):
            break

    # final summary row for the epoch
    with open(config['out_fname'], '+a') as f:
        print('{ep},{itr},{bt},{nt},{dt},'
              '{loss.val:.4f},{loss.avg:.4f},'
              '{top1.val:.3f},{top1.avg:.3f},'
              '{top5.val:.3f},{top5.avg:.3f},-1'
              .format(ep=epoch, itr=i,
                      bt=batch_meter,
                      dt=data_meter,
                      nt=nn_meter,
                      loss=losses,
                      top1=top1,
                      top5=top5), file=f)

    return losses.avg, top1.avg, top5.avg