def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    num_classes = 1000

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](stem_type=args.stem_type,
                                       num_classes=num_classes,
                                       block_type=models.PreBasicBlock,
                                       activation=nn.PReLU)
    bchef = BinaryChef('recepies/imagenet-baseline.yaml')
    model = bchef.run_step(model, args.step)

    print(model)
    print('Num parameters: {}'.format(count_parameters(model)))

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    parameters = model.parameters()
    if args.optimizer == 'adamw':
        wd = args.weight_decay if args.step == 0 else 0
        optimizer = torch.optim.AdamW(parameters, args.lr, weight_decay=wd)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, args.lr)
    elif args.optimizer == 'sgd':
        wd = 0 if args.step > 0 else args.weight_decay
        optimizer = torch.optim.SGD(parameters, args.lr,
                                    momentum=args.momentum, weight_decay=wd)
    else:
        raise ValueError('Unknown optimizer selected: {}'.format(args.optimizer))

    if args.scheduler == 'multistep':
        milestone = [40, 70, 80, 100, 110]
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[x - args.warmup for x in milestone], gamma=0.1)
    elif args.scheduler == 'cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs - args.warmup), eta_min=0)
    else:
        raise ValueError('Unknown scheduler selected: {}'.format(args.scheduler))

    if args.warmup > 0:
        print('=> Applying warmup ({} epochs)'.format(args.warmup))
        lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
                                              total_epoch=args.warmup,
                                              after_scheduler=lr_scheduler)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            if args.resume_epoch:
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                if args.gpu is not None:
                    pass
                    # best_acc1 may be from a checkpoint from a different GPU
                    #best_acc1 = best_acc1.to(args.gpu)
            try:
                model.load_state_dict(checkpoint['state_dict'])
                if not ('adam' in args.optimizer and 'sgd' in args.resume):
                    print('=> Loading optimizer...')
                    #optimizer.load_state_dict(checkpoint['optimizer'])
            except:
                print('=> Warning: dict model mismatch, loading with strict = False')
                model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Reset learning rate
    for g in optimizer.param_groups:
        g['lr'] = args.lr

    if args.start_epoch > 0:
        print('Advancing the scheduler to epoch {}'.format(args.start_epoch))
        for i in range(args.start_epoch):
            lr_scheduler.step()

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'valid')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transforms_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transforms_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(traindir, transforms_train)
    val_dataset = datasets.ImageFolder(valdir, transforms_val)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    show_logs = (not args.multiprocessing_distributed) or (
        args.multiprocessing_distributed and args.rank % ngpus_per_node == 0)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if args.scheduler == 'cosine':
            lr_scheduler.step(epoch)
        else:
            lr_scheduler.step()
        if show_logs:
            print('New lr: {}'.format(lr_scheduler.get_last_lr()))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, show_logs)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, show_logs)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        print('Current best: {}'.format(best_acc1))

        if show_logs:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.output_dir)
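
# Hypothetical launch sketch (not part of the original script): how a main()
# entry point typically drives main_worker() above when
# --multiprocessing-distributed is set, following the standard PyTorch
# ImageNet-example pattern. Flag names simply mirror the args.* fields read by
# main_worker; the actual repository may wire this up differently.
import torch
import torch.multiprocessing as mp


def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes
        args.world_size = ngpus_per_node * args.world_size
        # spawn one training process per GPU; each runs main_worker(gpu, ...)
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # single-process path: main_worker uses args.gpu directly
        main_worker(args.gpu, ngpus_per_node, args)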
def main(pargs):
    # this should be global
    global have_wandb

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set up logging
    pargs.logging_frequency = max([pargs.logging_frequency, 1])
    log_file = os.path.normpath(
        os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    #set seed
    seed = 333
    logger.log_event(key="seed", value=seed)

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
        # TEST: allowed? Valuable?
        #torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency > 0) or (
        pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if not pargs.enable_wandb:
        have_wandb = False
    if have_wandb and (comm_rank == 0):
        # get wandb api token
        certfile = os.path.join(pargs.wandb_certdir, ".wandbirc")
        try:
            with open(certfile) as f:
                token = f.readlines()[0].replace("\n", "").split()
                wblogin = token[0]
                wbtoken = token[1]
        except IOError:
            print("Error, cannot open WandB certificate {}.".format(certfile))
            have_wandb = False

        if have_wandb:
            # log in: that call can be blocking, it should be quick
            sp.call(["wandb", "login", wbtoken])

            #init db and get config
            resume_flag = pargs.run_tag if pargs.resume_logging else False
            wandb.init(entity=wblogin, project='deepcam', name=pargs.run_tag,
                       id=pargs.run_tag, resume=resume_flag)
            config = wandb.config

            #set general parameters
            config.root_dir = root_dir
            config.output_dir = pargs.output_dir
            config.max_epochs = pargs.max_epochs
            config.local_batch_size = pargs.local_batch_size
            config.num_workers = comm_size
            config.channels = pargs.channels
            config.optimizer = pargs.optimizer
            config.start_lr = pargs.start_lr
            config.adam_eps = pargs.adam_eps
            config.weight_decay = pargs.weight_decay
            config.model_prefix = pargs.model_prefix
            config.amp_opt_level = pargs.amp_opt_level
            config.loss_weight_pow = pargs.loss_weight_pow
            config.lr_warmup_steps = pargs.lr_warmup_steps
            config.lr_warmup_factor = pargs.lr_warmup_factor

            # lr schedule if applicable
            if pargs.lr_schedule:
                for key in pargs.lr_schedule:
                    config.update({"lr_schedule_" + key: pargs.lr_schedule[key]},
                                  allow_val_change=True)

    # Logging hyperparameters
    logger.log_event(key="global_batch_size",
                     value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="opt_name", value=pargs.optimizer)
    logger.log_event(key="opt_base_learning_rate",
                     value=pargs.start_lr * pargs.lr_warmup_factor)
    logger.log_event(key="opt_learning_rate_warmup_steps",
                     value=pargs.lr_warmup_steps)
    logger.log_event(key="opt_learning_rate_warmup_factor",
                     value=pargs.lr_warmup_factor)
    logger.log_event(key="opt_epsilon", value=pargs.adam_eps)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels, os=16,
                                          pretrained=False, rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=pargs.start_lr,
                               eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(), lr=pargs.start_lr,
                                eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(), lr=pargs.start_lr,
                                     eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net, optimizer, opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr, pargs.lr_schedule,
                                             optimizer, last_step=start_step)

        # LR warmup
        if pargs.lr_warmup_steps > 0:
            if have_warmup_scheduler:
                scheduler = GradualWarmupScheduler(
                    optimizer, multiplier=pargs.lr_warmup_factor,
                    total_epoch=pargs.lr_warmup_steps,
                    after_scheduler=scheduler_after)
            # Throw an error if the package is not found
            else:
                raise Exception(
                    f'Requested {pargs.lr_warmup_steps} LR warmup steps '
                    'but warmup scheduler not found. Install it from '
                    'https://github.com/ildoonet/pytorch-gradual-warmup-lr')
        else:
            scheduler = scheduler_after

    #broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)
    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               allow_uneven_distribution=False,
                               shuffle=True, preprocess=True,
                               comm_size=comm_size, comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set, pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True, drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off
    # validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    allow_uneven_distribution=True,
                                    shuffle=(pargs.max_validation_steps is not None),
                                    preprocess=True,
                                    comm_size=comm_size, comm_rank=comm_rank)
    # use batch size = 1 here to make sure that we do not drop a sample
    validation_loader = DataLoader(
        validation_set, 1,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True, drop_last=True)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_set.global_size)
    if pargs.max_validation_steps is not None:
        val_size = min([
            validation_set.global_size,
            pargs.max_validation_steps * pargs.local_batch_size * comm_size
        ])
    else:
        val_size = validation_set.global_size
    logger.log_event(key="eval_samples", value=val_size)

    # do sanity check
    if pargs.max_validation_steps is not None:
        logger.log_event(key="invalid_submission")

    #for visualization
    #if visualize:
    #    viz = vizc.CamVisualizer()

    # Train network
    if have_wandb and (comm_rank == 0):
        wandb.watch(net)

    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    stop_training = False
    net.train()

    # start training
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start",
                         metadata={'epoch_num': epoch + 1, 'step_num': step},
                         sync=True)

        # epoch loop
        for inputs, label, filename in train_loader:

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs, label, weight=class_weights,
                             fpw_1=fpw_1, fpw_2=fpw_2)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #visualize if requested
            #if (step % pargs.training_visualization_frequency == 0) and (comm_rank == 0):
            #    # Compute predictions
            #    predictions = torch.max(outputs, 1)[1]
            #
            #    # extract sample id and data tensors
            #    sample_idx = np.random.randint(low=0, high=label.shape[0])
            #    plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
            #    plot_prediction = predictions.detach()[sample_idx, ...].cpu().numpy()
            #    plot_label = label.detach()[sample_idx, ...].cpu().numpy()
            #
            #    # create filenames
            #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "training-").replace(".h5", ".png")
            #    outputfile = os.path.join(plot_dir, outputfile)
            #
            #    # plot
            #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
            #
            #    #log if requested
            #    if have_wandb:
            #        img = Image.open(outputfile)
            #        wandb.log({"train_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step = step)

            #log if requested
            if (step % pargs.logging_frequency == 0):

                # allreduce for loss
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_train = loss_avg.item() / float(comm_size)

                # Compute score
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions, label,
                                          device_id=device, num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)
                iou_avg_train = iou_avg.item() / float(comm_size)

                logger.log_event(key="learning_rate", value=current_lr,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="train_accuracy", value=iou_avg_train,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="train_loss", value=loss_avg_train,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})

                if have_wandb and (comm_rank == 0):
                    wandb.log({"train_loss": loss_avg.item() / float(comm_size)}, step=step)
                    wandb.log({"train_accuracy": iou_avg.item() / float(comm_size)}, step=step)
                    wandb.log({"learning_rate": current_lr}, step=step)
                    wandb.log({"epoch": epoch + 1}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                logger.log_start(key="eval_start", metadata={'epoch_num': epoch + 1})

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        loss_val = criterion(outputs_val, label_val,
                                             weight=class_weights,
                                             fpw_1=fpw_1, fpw_2=fpw_2)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val, label_val,
                                                      device_id=device, num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        #if (step_val % pargs.validation_visualization_frequency == 0) and (not visualized) and (comm_rank == 0):
                        #    #extract sample id and data tensors
                        #    sample_idx = np.random.randint(low=0, high=label_val.shape[0])
                        #    plot_input = inputs_val.detach()[sample_idx, 0, ...].cpu().numpy()
                        #    plot_prediction = predictions_val.detach()[sample_idx, ...].cpu().numpy()
                        #    plot_label = label_val.detach()[sample_idx, ...].cpu().numpy()
                        #
                        #    #create filenames
                        #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "validation-").replace(".h5", ".png")
                        #    outputfile = os.path.join(plot_dir, outputfile)
                        #
                        #    #plot
                        #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
                        #    visualized = True
                        #
                        #    #log if requested
                        #    if have_wandb:
                        #        img = Image.open(outputfile)
                        #        wandb.log({"eval_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step = step)

                        #increase eval step counter
                        step_val += 1

                        if (pargs.max_validation_steps is not None) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.all_reduce(count_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(loss_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(iou_sum_val, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                logger.log_event(key="eval_accuracy", value=iou_avg_val,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="eval_loss", value=loss_avg_val,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})

                # log in wandb
                if have_wandb and (comm_rank == 0):
                    wandb.log({"eval_loss": loss_avg_val}, step=step)
                    wandb.log({"eval_accuracy": iou_avg_val}, step=step)

                if (iou_avg_val >= pargs.target_iou):
                    logger.log_event(key="target_accuracy_reached",
                                     value=pargs.target_iou,
                                     metadata={'epoch_num': epoch + 1, 'step_num': step})
                    stop_training = True

                # set to train
                net.train()

                logger.log_end(key="eval_stop", metadata={'epoch_num': epoch + 1})

            #save model if desired
            if (pargs.save_frequency > 0) and (step % pargs.save_frequency == 0):
                logger.log_start(key="save_start",
                                 metadata={'epoch_num': epoch + 1, 'step_num': step},
                                 sync=True)
                if comm_rank == 0:
                    checkpoint = {
                        'step': step,
                        'epoch': epoch,
                        'model': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if have_apex:
                        checkpoint['amp'] = amp.state_dict()
                    torch.save(checkpoint,
                               os.path.join(output_dir,
                                            pargs.model_prefix + "_step_" + str(step) + ".cpt"))
                logger.log_end(key="save_stop",
                               metadata={'epoch_num': epoch + 1, 'step_num': step},
                               sync=True)

            # Stop training?
            if stop_training:
                break

        # log the epoch
        logger.log_end(key="epoch_stop",
                       metadata={'epoch_num': epoch + 1, 'step_num': step},
                       sync=True)
        epoch += 1

        # are we done?
        if epoch >= pargs.max_epochs or stop_training:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})
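
# Minimal sketch (assumption, not the repository's actual utils/comm module) of
# what the comm wrapper used above (comm.init / get_rank / get_local_rank /
# get_size) typically does: initialize torch.distributed and expose rank/size
# helpers. It assumes the launcher exports RANK / WORLD_SIZE / MASTER_ADDR /
# MASTER_PORT / LOCAL_RANK; the real wireup method may differ.
import os
import torch.distributed as dist


def comm_init_sketch(wireup_method="dummy"):
    # env:// wireup reads rank and world size from the environment
    dist.init_process_group(backend="nccl", init_method="env://")


def comm_get_rank_sketch():
    return dist.get_rank()


def comm_get_size_sketch():
    return dist.get_world_size()


def comm_get_local_rank_sketch():
    # common convention: the per-node rank is provided by the launcher
    return int(os.environ.get("LOCAL_RANK", 0))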
def main(pargs):

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    #set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency > 0) or (
        pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        # get wandb api token
        with open(os.path.join(pargs.wandb_certdir, ".wandbirc")) as f:
            token = f.readlines()[0].replace("\n", "").split()
            wblogin = token[0]
            wbtoken = token[1]
        # log in: that call can be blocking, it should be quick
        sp.call(["wandb", "login", wbtoken])

        #init db and get config
        resume_flag = pargs.run_tag if pargs.resume_logging else False
        wandb.init(entity=wblogin, project='deepcam', name=pargs.run_tag,
                   id=pargs.run_tag, resume=resume_flag)
        config = wandb.config

        #set general parameters
        config.root_dir = root_dir
        config.output_dir = pargs.output_dir
        config.max_epochs = pargs.max_epochs
        config.local_batch_size = pargs.local_batch_size
        config.num_workers = comm_size
        config.channels = pargs.channels
        config.optimizer = pargs.optimizer
        config.start_lr = pargs.start_lr
        config.adam_eps = pargs.adam_eps
        config.weight_decay = pargs.weight_decay
        config.model_prefix = pargs.model_prefix
        config.amp_opt_level = pargs.amp_opt_level
        config.loss_weight_pow = pargs.loss_weight_pow
        config.lr_warmup_steps = pargs.lr_warmup_steps
        config.lr_warmup_factor = pargs.lr_warmup_factor

        # lr schedule if applicable
        if pargs.lr_schedule:
            for key in pargs.lr_schedule:
                config.update({"lr_schedule_" + key: pargs.lr_schedule[key]},
                              allow_val_change=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels, os=16,
                                          pretrained=False, rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=pargs.start_lr,
                               eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(), lr=pargs.start_lr,
                                eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(), lr=pargs.start_lr,
                                     eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net, optimizer, opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr, pargs.lr_schedule,
                                             optimizer, last_step=start_step)
        if pargs.lr_warmup_steps > 0:
            scheduler = GradualWarmupScheduler(optimizer,
                                               multiplier=pargs.lr_warmup_factor,
                                               total_epoch=pargs.lr_warmup_steps,
                                               after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    #broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)
    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels, shuffle=True,
                               preprocess=True, comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set, pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off
    # validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    shuffle=(pargs.max_validation_steps is not None),
                                    preprocess=True, comm_size=comm_size,
                                    comm_rank=comm_rank)
    validation_loader = DataLoader(
        validation_set, pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    #for visualization
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        wandb.watch(net)
    printr('{:14.4f} REPORT: starting training'.format(
        dt.datetime.now().timestamp()), 0)

    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    net.train()
    while True:
        printr('{:14.4f} REPORT: starting epoch {}'.format(
            dt.datetime.now().timestamp(), epoch), 0)

        #for inputs_raw, labels, source in train_loader:
        for inputs, label, filename in train_loader:

            #send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs, label, weight=class_weights,
                             fpw_1=fpw_1, fpw_2=fpw_2)

            # allreduce for loss
            loss_avg = loss.detach()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # Compute score
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions, label,
                                      device_id=device, num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            #step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #print some metrics
            printr('{:14.4f} REPORT training: step {} loss {} iou {} LR {}'.format(
                dt.datetime.now().timestamp(), step,
                loss_avg.item() / float(comm_size),
                iou_avg.item() / float(comm_size), current_lr), 0)

            #visualize if requested
            if (step % pargs.training_visualization_frequency == 0) and (comm_rank == 0):
                #extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx, ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                #create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace(
                    "data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                #plot
                viz.plot(filename[sample_idx], outputfile, plot_input,
                         plot_prediction, plot_label)

                #log if requested
                if pargs.logging_frequency > 0:
                    img = Image.open(outputfile)
                    wandb.log({"Training Examples": [
                        wandb.Image(img, caption="Prediction vs. Ground Truth")
                    ]}, step=step)

            #log if requested
            if (pargs.logging_frequency > 0) and (
                    step % pargs.logging_frequency == 0) and (comm_rank == 0):
                wandb.log({"Training Loss": loss_avg.item() / float(comm_size)}, step=step)
                wandb.log({"Training IoU": iou_avg.item() / float(comm_size)}, step=step)
                wandb.log({"Current Learning Rate": current_lr}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        loss_val = criterion(outputs_val, label_val, weight=class_weights)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val, label_val,
                                                      device_id=device, num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        if (step_val % pargs.validation_visualization_frequency == 0) and (
                                not visualized) and (comm_rank == 0):
                            #extract sample id and data tensors
                            sample_idx = np.random.randint(low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx, ...].cpu().numpy()

                            #create filenames
                            outputfile = os.path.basename(filename_val[sample_idx]).replace(
                                "data-", "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            #plot
                            viz.plot(filename_val[sample_idx], outputfile, plot_input,
                                     plot_prediction, plot_label)
                            visualized = True

                            #log if requested
                            if pargs.logging_frequency > 0:
                                img = Image.open(outputfile)
                                wandb.log({"Validation Examples": [
                                    wandb.Image(img, caption="Prediction vs. Ground Truth")
                                ]}, step=step)

                        #increase eval step counter
                        step_val += 1

                        if (pargs.max_validation_steps is not None) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                printr('{:14.4f} REPORT validation: step {} loss {} iou {}'.format(
                    dt.datetime.now().timestamp(), step, loss_avg_val, iou_avg_val), 0)

                # log in wandb
                if (pargs.logging_frequency > 0) and (comm_rank == 0):
                    wandb.log({"Validation Loss": loss_avg_val}, step=step)
                    wandb.log({"Validation IoU": iou_avg_val}, step=step)

                # set to train
                net.train()

            #save model if desired
            if (step % pargs.save_frequency == 0) and (comm_rank == 0):
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                if have_apex:
                    checkpoint['amp'] = amp.state_dict()
                torch.save(checkpoint,
                           os.path.join(output_dir,
                                        pargs.model_prefix + "_step_" + str(step) + ".cpt"))

        #do some after-epoch prep, just for the books
        epoch += 1
        if comm_rank == 0:
            # Save the model
            checkpoint = {
                'step': step,
                'epoch': epoch,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if have_apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(checkpoint,
                       os.path.join(output_dir,
                                    pargs.model_prefix + "_epoch_" + str(epoch) + ".cpt"))

        #are we done?
        if epoch >= pargs.max_epochs:
            break

    printr('{:14.4f} REPORT: finishing training'.format(
        dt.datetime.now().timestamp()), 0)
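
# Minimal sketch (assumption) of the rank-aware print helper used above: printr
# is not defined in this listing, but it is used as "print this message only on
# the given rank". The real utility module may differ.
def printr_sketch(msg, rank=0):
    if comm.get_rank() == rank:
        print(msg, flush=True)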