def getOptimizer(model, hparams):
    if hparams.optimizer_fused:
        from apex import optimizers as apexopt
        if hparams.optimizer == "Adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=hparams.learning_rate)
        elif hparams.optimizer == "LAMB":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=hparams.learning_rate)
    else:
        if hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=hparams.learning_rate)
        elif hparams.optimizer == "LAMB":
            raise NotImplementedError  # PyTorch doesn't currently include LAMB optimizer.
    return optimizer

def get_optimizer(model, optimizer: str, fp16_run=True, optimizer_fused=True,
                  learning_rate=1e-4, max_grad_norm=200):
    optimizer = optimizer.lower()
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate,
                                          max_grad_norm=max_grad_norm)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)

    global amp
    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    return model, optimizer

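# Minimal usage sketch for get_optimizer() above. This is an illustration only:
# the toy Linear module is a placeholder for the real model, and the fused/fp16
# paths are disabled so that only stock PyTorch is required (no Apex, no lamb package).
import torch

toy_model = torch.nn.Linear(80, 80)
toy_model, toy_optimizer = get_optimizer(toy_model, optimizer="Adam",
                                         fp16_run=False, optimizer_fused=False,
                                         learning_rate=1e-4)
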
def main(pargs):
    # this should be global
    global have_wandb

    # init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set up logging
    pargs.logging_frequency = max([pargs.logging_frequency, 1])
    log_file = os.path.normpath(os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    # set seed
    seed = 333
    logger.log_event(key="seed", value=seed)

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        # necessary for AMP to work
        torch.cuda.set_device(device)
        # TEST: allowed? Valuable?
        #torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    # visualize?
    visualize = (pargs.training_visualization_frequency > 0) or (pargs.validation_visualization_frequency > 0)

    # set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if not pargs.enable_wandb:
        have_wandb = False
    if have_wandb and (comm_rank == 0):
        # get wandb api token
        certfile = os.path.join(pargs.wandb_certdir, ".wandbirc")
        try:
            with open(certfile) as f:
                token = f.readlines()[0].replace("\n", "").split()
                wblogin = token[0]
                wbtoken = token[1]
        except IOError:
            print("Error, cannot open WandB certificate {}.".format(certfile))
            have_wandb = False

        if have_wandb:
            # log in: that call can be blocking, it should be quick
            sp.call(["wandb", "login", wbtoken])

            # init db and get config
            resume_flag = pargs.run_tag if pargs.resume_logging else False
            wandb.init(entity=wblogin, project='deepcam',
                       name=pargs.run_tag, id=pargs.run_tag,
                       resume=resume_flag)
            config = wandb.config

            # set general parameters
            config.root_dir = root_dir
            config.output_dir = pargs.output_dir
            config.max_epochs = pargs.max_epochs
            config.local_batch_size = pargs.local_batch_size
            config.num_workers = comm_size
            config.channels = pargs.channels
            config.optimizer = pargs.optimizer
            config.start_lr = pargs.start_lr
            config.adam_eps = pargs.adam_eps
            config.weight_decay = pargs.weight_decay
            config.model_prefix = pargs.model_prefix
            config.amp_opt_level = pargs.amp_opt_level
            config.loss_weight_pow = pargs.loss_weight_pow
            config.lr_warmup_steps = pargs.lr_warmup_steps
            config.lr_warmup_factor = pargs.lr_warmup_factor

            # lr schedule if applicable
            if pargs.lr_schedule:
                for key in pargs.lr_schedule:
                    config.update({"lr_schedule_" + key: pargs.lr_schedule[key]}, allow_val_change=True)

    # Logging hyperparameters
    logger.log_event(key="global_batch_size", value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="opt_name", value=pargs.optimizer)
    logger.log_event(key="opt_base_learning_rate", value=pargs.start_lr * pargs.lr_warmup_factor)
    logger.log_event(key="opt_learning_rate_warmup_steps", value=pargs.lr_warmup_steps)
    logger.log_event(key="opt_learning_rate_warmup_factor", value=pargs.lr_warmup_factor)
    logger.log_event(key="opt_epsilon", value=pargs.adam_eps)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16, pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    # select loss
    loss_pow = pargs.loss_weight_pow
    # some magic numbers
    class_weights = [0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow, 0.01327431072255291**loss_pow]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    # select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(pargs.optimizer))

    if have_apex:
        # wrap model and opt into amp
        net, optimizer = amp.initialize(net, optimizer, opt_level=pargs.amp_opt_level)

    # make model distributed
    net = DDP(net)

    # restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    # load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    # select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr, pargs.lr_schedule, optimizer, last_step=start_step)

        # LR warmup
        if pargs.lr_warmup_steps > 0:
            if have_warmup_scheduler:
                scheduler = GradualWarmupScheduler(optimizer,
                                                   multiplier=pargs.lr_warmup_factor,
                                                   total_epoch=pargs.lr_warmup_steps,
                                                   after_scheduler=scheduler_after)
            # Throw an error if the package is not found
            else:
                raise Exception(f'Requested {pargs.lr_warmup_steps} LR warmup steps '
                                'but warmup scheduler not found. Install it from '
                                'https://github.com/ildoonet/pytorch-gradual-warmup-lr')
        else:
            scheduler = scheduler_after

    # broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]), requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)
    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    # unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               allow_uneven_distribution=False,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(train_set, pargs.local_batch_size,
                              num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
                              pin_memory=True,
                              drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    allow_uneven_distribution=True,
                                    shuffle=(pargs.max_validation_steps is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    # use batch size = 1 here to make sure that we do not drop a sample
    validation_loader = DataLoader(validation_set, 1,
                                   num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
                                   pin_memory=True,
                                   drop_last=True)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_set.global_size)
    if pargs.max_validation_steps is not None:
        val_size = min([validation_set.global_size, pargs.max_validation_steps * pargs.local_batch_size * comm_size])
    else:
        val_size = validation_set.global_size
    logger.log_event(key="eval_samples", value=val_size)

    # do sanity check
    if pargs.max_validation_steps is not None:
        logger.log_event(key="invalid_submission")

    # for visualization
    #if visualize:
    #    viz = vizc.CamVisualizer()

    # Train network
    if have_wandb and (comm_rank == 0):
        wandb.watch(net)

    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    stop_training = False
    net.train()

    # start training
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start", metadata={'epoch_num': epoch + 1, 'step_num': step}, sync=True)

        # epoch loop
        for inputs, label, filename in train_loader:

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs, label, weight=class_weights, fpw_1=fpw_1, fpw_2=fpw_2)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            # visualize if requested
            #if (step % pargs.training_visualization_frequency == 0) and (comm_rank == 0):
            #    # Compute predictions
            #    predictions = torch.max(outputs, 1)[1]
            #
            #    # extract sample id and data tensors
            #    sample_idx = np.random.randint(low=0, high=label.shape[0])
            #    plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
            #    plot_prediction = predictions.detach()[sample_idx, ...].cpu().numpy()
            #    plot_label = label.detach()[sample_idx, ...].cpu().numpy()
            #
            #    # create filenames
            #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "training-").replace(".h5", ".png")
            #    outputfile = os.path.join(plot_dir, outputfile)
            #
            #    # plot
            #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
            #
            #    # log if requested
            #    if have_wandb:
            #        img = Image.open(outputfile)
            #        wandb.log({"train_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step=step)

            # log if requested
            if (step % pargs.logging_frequency == 0):

                # allreduce for loss
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_train = loss_avg.item() / float(comm_size)

                # Compute score
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions, label, device_id=device, num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)
                iou_avg_train = iou_avg.item() / float(comm_size)

                logger.log_event(key="learning_rate", value=current_lr,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="train_accuracy", value=iou_avg_train,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="train_loss", value=loss_avg_train,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})

                if have_wandb and (comm_rank == 0):
                    wandb.log({"train_loss": loss_avg.item() / float(comm_size)}, step=step)
                    wandb.log({"train_accuracy": iou_avg.item() / float(comm_size)}, step=step)
                    wandb.log({"learning_rate": current_lr}, step=step)
                    wandb.log({"epoch": epoch + 1}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                logger.log_start(key="eval_start", metadata={'epoch_num': epoch + 1})

                # eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        # send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        loss_val = criterion(outputs_val, label_val, weight=class_weights, fpw_1=fpw_1, fpw_2=fpw_2)
                        loss_sum_val += loss_val

                        # increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val, label_val, device_id=device, num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        #if (step_val % pargs.validation_visualization_frequency == 0) and (not visualized) and (comm_rank == 0):
                        #    # extract sample id and data tensors
                        #    sample_idx = np.random.randint(low=0, high=label_val.shape[0])
                        #    plot_input = inputs_val.detach()[sample_idx, 0, ...].cpu().numpy()
                        #    plot_prediction = predictions_val.detach()[sample_idx, ...].cpu().numpy()
                        #    plot_label = label_val.detach()[sample_idx, ...].cpu().numpy()
                        #
                        #    # create filenames
                        #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "validation-").replace(".h5", ".png")
                        #    outputfile = os.path.join(plot_dir, outputfile)
                        #
                        #    # plot
                        #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
                        #    visualized = True
                        #
                        #    # log if requested
                        #    if have_wandb:
                        #        img = Image.open(outputfile)
                        #        wandb.log({"eval_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step=step)

                        # increase eval step counter
                        step_val += 1
                        if (pargs.max_validation_steps is not None) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.all_reduce(count_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(loss_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(iou_sum_val, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                logger.log_event(key="eval_accuracy", value=iou_avg_val,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})
                logger.log_event(key="eval_loss", value=loss_avg_val,
                                 metadata={'epoch_num': epoch + 1, 'step_num': step})

                # log in wandb
                if have_wandb and (comm_rank == 0):
                    wandb.log({"eval_loss": loss_avg_val}, step=step)
                    wandb.log({"eval_accuracy": iou_avg_val}, step=step)

                if (iou_avg_val >= pargs.target_iou):
                    logger.log_event(key="target_accuracy_reached", value=pargs.target_iou,
                                     metadata={'epoch_num': epoch + 1, 'step_num': step})
                    stop_training = True

                # set to train
                net.train()

                logger.log_end(key="eval_stop", metadata={'epoch_num': epoch + 1})

            # save model if desired
            if (pargs.save_frequency > 0) and (step % pargs.save_frequency == 0):
                logger.log_start(key="save_start", metadata={'epoch_num': epoch + 1, 'step_num': step}, sync=True)
                if comm_rank == 0:
                    checkpoint = {
                        'step': step,
                        'epoch': epoch,
                        'model': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if have_apex:
                        checkpoint['amp'] = amp.state_dict()
                    torch.save(checkpoint, os.path.join(output_dir, pargs.model_prefix + "_step_" + str(step) + ".cpt"))
                logger.log_end(key="save_stop", metadata={'epoch_num': epoch + 1, 'step_num': step}, sync=True)

            # Stop training?
            if stop_training:
                break

        # log the epoch
        logger.log_end(key="epoch_stop", metadata={'epoch_num': epoch + 1, 'step_num': step}, sync=True)
        epoch += 1

        # are we done?
        if epoch >= pargs.max_epochs or stop_training:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})

def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, loss_empthasis, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard, logdirname, datedlogdir,
          warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END: ADDED FOR DISTRIBUTED======

    global WaveGlow
    global WaveGlowLoss

    ax = True  # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config["yoyo"]:  # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss

    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END: ADDED FOR DISTRIBUTED======

    STFTs = [STFT.TacotronSTFT(filter_length=window,
                               hop_length=data_config['hop_length'],
                               win_length=window,
                               sampling_rate=data_config['sampling_rate'],
                               n_mel_channels=160,
                               mel_fmin=0, mel_fmax=16000)
             for window in data_config['validation_windows']]

    loader_STFT = STFT.TacotronSTFT(filter_length=data_config['filter_length'],
                                    hop_length=data_config['hop_length'],
                                    win_length=data_config['win_length'],
                                    sampling_rate=data_config['sampling_rate'],
                                    n_mel_channels=data_config['n_mel_channels'] if 'n_mel_channels' in data_config.keys() else 160,
                                    mel_fmin=data_config['mel_fmin'],
                                    mel_fmax=data_config['mel_fmax'])

    #optimizer = "Adam"
    optimizer = optimizer.lower()
    optimizer_fused = bool(0)  # use Apex fused optimizer, should be identical to normal but slightly faster and only works on RTX cards
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=200)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.

    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None

    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5)  # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20,
                                      cooldown=2, min_lr=min_lr, verbose=True,
                                      threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else:
        scheduler = False

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                                 optimizer, scheduler, fp16_run,
                                                                 warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0

    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END: ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler, batch_size=batch_size,
                              pin_memory=False, drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)

    moving_average = int(min(len(train_loader), 100))  # average loss over entire Epoch
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()

    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9

    epoch_offset = max(0, int(iteration / len(train_loader)))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")

    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset,
                                       total=epochs, smoothing=0.01, desc="Epoch",
                                       position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)

                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0,
                                          total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i == 0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50:  # check actual learning rate every 50 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                        learning_rate = optimizer.param_groups[0]['lr']

                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print("scheduler._last_lr =", scheduler._last_lr, "scheduler.best =", scheduler.best, " |", end='')

                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    outputs = model(mel, audio, speaker_ids)

                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                    else:
                        reduced_loss = loss.item()

                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    if (reduced_loss > LossExplosionThreshold) or (math.isnan(reduced_loss)):
                        model.zero_grad()
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")

                    if use_grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
                        if type(grad_norm) == torch.Tensor:
                            grad_norm = grad_norm.item()
                        is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                    else:
                        is_overflow = False
                        grad_norm = 0.00001

                    optimizer.step()

                    if not is_overflow and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            logger.add_scalar('training_loss', reduced_loss, iteration)
                            logger.add_scalar('training_loss_samples', reduced_loss, iteration*batch_size)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)

                        average_loss = rolling_sum.process(reduced_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}: {:.3f} {:.3f} {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {:.2f}s/iter {:.4f}s/item".format(
                                time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE,
                                round(grad_norm, 3), learning_rate,
                                min((grad_clip_thresh/grad_norm)*learning_rate, learning_rate),
                                (time.time() - start_time_dekaiter)/10,
                                ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}: {:.3f} {:.3f} {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(
                                time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE,
                                round(grad_norm, 3), learning_rate,
                                min((grad_clip_thresh/grad_norm)*learning_rate, learning_rate), loss_scale))
                        start_time_iter = time.time()

                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp,
                                                scheduler, speaker_lookup, checkpoint_path)
                            except KeyboardInterrupt:  # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp,
                                                scheduler, speaker_lookup, checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss  # Only save the model if X better than the current loss.

                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate, iteration, amp,
                                        scheduler, speaker_lookup, checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)

                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(model, loader_STFT, STFTs, logger, iteration,
                                                data_config['validation_files'], speaker_lookup,
                                                sigma, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp,
                                                        scheduler, speaker_lookup, checkpoint_path)
                                    except KeyboardInterrupt:  # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp,
                                                        scheduler, speaker_lookup, checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item()  # Only save the model if X better than the current loss.
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())

                    learning_rate = optimizer.param_groups[0]['lr']  # check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False  # exit the While loop

        except LossExplosion as ex:  # print Exception and continue from checkpoint. (turns out it takes < 4 seconds to restart like this, f*****g awesome)
            print(ex)  # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_model must exist for automatic restarts"

            # clearing VRAM for load checkpoint
            audio = mel = speaker_ids = loss = None
            torch.cuda.empty_cache()

            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                                     optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass  # and continue training.

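# train() above relies on a StreamingMovingAverage helper (constructed with a window size,
# exposing a .process() method and a .values list) that is defined elsewhere in this codedump.
# The sketch below is an assumption matching that interface, for illustration only; it is not
# necessarily the repository's actual implementation.
class StreamingMovingAverageSketch:
    def __init__(self, window_size):
        self.window_size = window_size
        self.values = []   # most recent `window_size` losses
        self.sum = 0.0

    def process(self, value):
        # push the new value, drop the oldest once the window is full,
        # and return the current moving average
        self.values.append(value)
        self.sum += value
        if len(self.values) > self.window_size:
            self.sum -= self.values.pop(0)
        return self.sum / len(self.values)
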
def main(pargs):

    # init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        # necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    # set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16, pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    # select loss
    loss_pow = pargs.loss_weight_pow
    # some magic numbers
    class_weights = [0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow, 0.01327431072255291**loss_pow]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    # select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(pargs.optimizer))

    if have_apex:
        # wrap model and opt into amp
        net, optimizer = amp.initialize(net, optimizer, opt_level=pargs.amp_opt_level)

    # make model distributed
    net = DDP(net)

    # select scheduler
    if pargs.lr_schedule:
        scheduler = ph.get_lr_schedule(pargs.start_lr, pargs.lr_schedule, optimizer, last_step=0)

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(train_set, pargs.local_batch_size,
                              num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
                              drop_last=True)

    printr('{:14.4f} REPORT: starting warmup'.format(dt.datetime.now().timestamp()), 0)
    step = 0
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    current_lr = pargs.start_lr
    net.train()
    while True:
        #for inputs_raw, labels, source in train_loader:
        for inputs, label, filename in train_loader:

            # Print status
            if step == pargs.num_warmup_steps:
                printr('{:14.4f} REPORT: starting profiling'.format(dt.datetime.now().timestamp()), 0)

            # Forward pass
            with Profile(pargs, "Forward", step):
                # send data to device
                inputs = inputs.to(device)
                label = label.to(device)

                # Compute output
                outputs = net.forward(inputs)

                # Compute loss
                loss = criterion(outputs, label, weight=class_weights, fpw_1=fpw_1, fpw_2=fpw_2)

                # allreduce for loss
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

                # Compute score
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions, label, device_id=device, num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            with Profile(pargs, "Backward", step):
                # reset grads
                optimizer.zero_grad()

                # compute grads
                if have_apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            # weight update
            with Profile(pargs, "Optimizer", step):
                # update weights
                optimizer.step()

            # advance the scheduler
            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            # step counter
            step += 1

            # are we done?
            if step >= (pargs.num_warmup_steps + pargs.num_profile_steps):
                break

        # need to check here too
        if step >= (pargs.num_warmup_steps + pargs.num_profile_steps):
            break

    printr('{:14.4f} REPORT: finishing profiling'.format(dt.datetime.now().timestamp()), 0)

def main(pargs):

    # init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        # necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    # visualize?
    visualize = (pargs.training_visualization_frequency > 0) or (pargs.validation_visualization_frequency > 0)

    # set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        # get wandb api token
        with open(os.path.join(pargs.wandb_certdir, ".wandbirc")) as f:
            token = f.readlines()[0].replace("\n", "").split()
            wblogin = token[0]
            wbtoken = token[1]

        # log in: that call can be blocking, it should be quick
        sp.call(["wandb", "login", wbtoken])

        # init db and get config
        resume_flag = pargs.run_tag if pargs.resume_logging else False
        wandb.init(entity=wblogin, project='deepcam',
                   name=pargs.run_tag, id=pargs.run_tag,
                   resume=resume_flag)
        config = wandb.config

        # set general parameters
        config.root_dir = root_dir
        config.output_dir = pargs.output_dir
        config.max_epochs = pargs.max_epochs
        config.local_batch_size = pargs.local_batch_size
        config.num_workers = comm_size
        config.channels = pargs.channels
        config.optimizer = pargs.optimizer
        config.start_lr = pargs.start_lr
        config.adam_eps = pargs.adam_eps
        config.weight_decay = pargs.weight_decay
        config.model_prefix = pargs.model_prefix
        config.amp_opt_level = pargs.amp_opt_level
        config.loss_weight_pow = pargs.loss_weight_pow
        config.lr_warmup_steps = pargs.lr_warmup_steps
        config.lr_warmup_factor = pargs.lr_warmup_factor

        # lr schedule if applicable
        if pargs.lr_schedule:
            for key in pargs.lr_schedule:
                config.update({"lr_schedule_" + key: pargs.lr_schedule[key]}, allow_val_change=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16, pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    # select loss
    loss_pow = pargs.loss_weight_pow
    # some magic numbers
    class_weights = [0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow, 0.01327431072255291**loss_pow]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    # select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(), lr=pargs.start_lr, eps=pargs.adam_eps, weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(pargs.optimizer))

    if have_apex:
        # wrap model and opt into amp
        net, optimizer = amp.initialize(net, optimizer, opt_level=pargs.amp_opt_level)

    # make model distributed
    net = DDP(net)

    # restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    # load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    # select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr, pargs.lr_schedule, optimizer, last_step=start_step)
        if pargs.lr_warmup_steps > 0:
            scheduler = GradualWarmupScheduler(optimizer,
                                               multiplier=pargs.lr_warmup_factor,
                                               total_epoch=pargs.lr_warmup_steps,
                                               after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    # broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]), requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)
    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    # unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(train_set, pargs.local_batch_size,
                              num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
                              drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    shuffle=(pargs.max_validation_steps is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    validation_loader = DataLoader(validation_set, pargs.local_batch_size,
                                   num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
                                   drop_last=True)

    # for visualization
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        wandb.watch(net)

    printr('{:14.4f} REPORT: starting training'.format(dt.datetime.now().timestamp()), 0)
    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr()[0]
    net.train()
    while True:
        printr('{:14.4f} REPORT: starting epoch {}'.format(dt.datetime.now().timestamp(), epoch), 0)

        #for inputs_raw, labels, source in train_loader:
        for inputs, label, filename in train_loader:

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs, label, weight=class_weights, fpw_1=fpw_1, fpw_2=fpw_2)

            # allreduce for loss
            loss_avg = loss.detach()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # Compute score
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions, label, device_id=device, num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            # print some metrics
            printr('{:14.4f} REPORT training: step {} loss {} iou {} LR {}'.format(
                dt.datetime.now().timestamp(), step,
                loss_avg.item() / float(comm_size),
                iou_avg.item() / float(comm_size), current_lr), 0)

            # visualize if requested
            if (step % pargs.training_visualization_frequency == 0) and (comm_rank == 0):
                # extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx, ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                # create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace("data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                # plot
                viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)

                # log if requested
                if pargs.logging_frequency > 0:
                    img = Image.open(outputfile)
                    wandb.log({"Training Examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step=step)

            # log if requested
            if (pargs.logging_frequency > 0) and (step % pargs.logging_frequency == 0) and (comm_rank == 0):
                wandb.log({"Training Loss": loss_avg.item() / float(comm_size)}, step=step)
                wandb.log({"Training IoU": iou_avg.item() / float(comm_size)}, step=step)
                wandb.log({"Current Learning Rate": current_lr}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                # eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        # send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss and average across nodes
                        loss_val = criterion(outputs_val, label_val, weight=class_weights)
                        loss_sum_val += loss_val

                        # increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val, label_val, device_id=device, num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        if (step_val % pargs.validation_visualization_frequency == 0) and (not visualized) and (comm_rank == 0):
                            # extract sample id and data tensors
                            sample_idx = np.random.randint(low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx, ...].cpu().numpy()

                            # create filenames
                            outputfile = os.path.basename(filename[sample_idx]).replace("data-", "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            # plot
                            viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
                            visualized = True

                            # log if requested
                            if pargs.logging_frequency > 0:
                                img = Image.open(outputfile)
                                wandb.log({"Validation Examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step=step)

                        # increase eval step counter
                        step_val += 1
                        if (pargs.max_validation_steps is not None) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                printr('{:14.4f} REPORT validation: step {} loss {} iou {}'.format(
                    dt.datetime.now().timestamp(), step, loss_avg_val, iou_avg_val), 0)

                # log in wandb
                if (pargs.logging_frequency > 0) and (comm_rank == 0):
                    wandb.log({"Validation Loss": loss_avg_val}, step=step)
                    wandb.log({"Validation IoU": iou_avg_val}, step=step)

                # set to train
                net.train()

            # save model if desired
            if (step % pargs.save_frequency == 0) and (comm_rank == 0):
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                if have_apex:
                    checkpoint['amp'] = amp.state_dict()
                torch.save(checkpoint, os.path.join(output_dir, pargs.model_prefix + "_step_" + str(step) + ".cpt"))

        # do some after-epoch prep, just for the books
        epoch += 1
        if comm_rank == 0:
            # Save the model
            checkpoint = {
                'step': step,
                'epoch': epoch,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if have_apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(checkpoint, os.path.join(output_dir, pargs.model_prefix + "_epoch_" + str(epoch) + ".cpt"))

        # are we done?
        if epoch >= pargs.max_epochs:
            break

    printr('{:14.4f} REPORT: finishing training'.format(dt.datetime.now().timestamp()), 0)