import torch
from torch.cuda.amp import GradScaler, autocast

def start_train():
    """Training with AMP and gradient accumulation.

    `optimizer`, `loss_function`, `dataloader_train`, `resume_train` and
    `checkpoint` are assumed to be defined at module level (see the sketch
    after this snippet).
    """
    use_amp = True
    # Run forward/backward N times, then update the parameters once.
    # Goal: enlarge the effective batch (theoretical batch = batch_size * N).
    iter_size = 8
    myNet = MyNet(use_amp).to("cuda:0")
    myNet = torch.nn.DataParallel(myNet, device_ids=[0, 1])  # data parallelism
    myNet.train()
    # Initialize the gradient scaler once, before training starts.
    scaler = GradScaler() if use_amp else None
    # Load pretrained weights when resuming.
    if resume_train:
        scaler.load_state_dict(checkpoint['scaler'])  # needed for AMP
        optimizer.load_state_dict(checkpoint['optimizer'])
        myNet.load_state_dict(checkpoint['model'])
    for epoch in range(1, 100):
        for batch_idx, (input, target) in enumerate(dataloader_train):
            # Move the data to the primary device of the data-parallel model.
            input = input.to("cuda:0")
            target = target.to("cuda:0")
            if use_amp:
                # Automatic mixed precision: autocast runs eligible ops in FP16.
                with autocast():
                    feature = myNet(input)  # extract features
                    losses = loss_function(target, feature)
                    loss = losses / iter_size
                scaler.scale(loss).backward()
            else:
                feature = myNet(input)
                losses = loss_function(target, feature)
                loss = losses / iter_size
                loss.backward()
            # Accumulate gradients over iter_size batches, then update.
            if (batch_idx + 1) % iter_size == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
    # The scaler is stateful: its state must be saved and reloaded when resuming.
    state = {
        'model': myNet.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
    }
    torch.save(state, "filename.pth")
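# A minimal sketch of the module-level pieces start_train() assumes but does not
# define. MyNet is the user's model from the snippet; the toy data and the
# criterion below are placeholder assumptions, not from the source.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataloader_train = DataLoader(
    TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 2, (64,))),
    batch_size=8,
)

def loss_function(target, feature):
    # Placeholder criterion matching the snippet's (target, feature) call order.
    return torch.nn.functional.cross_entropy(feature, target)

resume_train = False  # set True to resume; "filename.pth" must then exist
checkpoint = torch.load("filename.pth") if resume_train else None
# The snippet builds the model inside start_train() but steps a module-level
# optimizer; in practice, construct both in one place, e.g.:
#     optimizer = torch.optim.SGD(myNet.parameters(), lr=0.01)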
# Note: `cls` is not defined at module level; this class is presumably created
# inside a class factory where `cls` is the concrete base trainer (sketch below).
class AMPTrainer(cls):
    """PyTorch's automatic mixed precision requires pytorch >= 1.6.
    See: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        print("init pytorch's amp")
        self.scaler = GradScaler()

    def forward_pass(self, data, vars: StageVars):
        with autocast():
            return super().forward_pass(data=data, vars=vars)

    def backward_pass(self, vars: StageVars):
        loss = vars.forward['loss']
        if loss is not None:
            assert loss.dim() == 0, "loss must be reduced"
            with time_elapsed_to_profiler('backward'):
                self.opt.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale here so later code may modify the gradients directly.
                self.scaler.unscale_(self.opt)

    def optimize(self, vars: StageVars):
        loss = vars.forward['loss']
        if loss is not None:
            with time_elapsed_to_profiler('optimize'):
                self.scaler.step(self.opt)
                self.scaler.update()

    def get_state(self):
        # Include the AMP scaler state in the checkpoint.
        state = super().get_state()
        state['scaler'] = self.scaler.state_dict()
        return state

    def load_state(self, state):
        # Restore the AMP scaler state as well.
        super().load_state(state)
        print("loading pytorch's amp state ...")
        if 'scaler' in state:
            self.scaler.load_state_dict(state['scaler'])
        else:
            print('warning: scaler state is not available')

    def __repr__(self):
        return f'<AMPTrainer {super().__repr__()}>'
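# A hedged sketch of the class-factory pattern that makes `class AMPTrainer(cls)`
# valid Python. The factory name is an assumption, not from the source.
def make_amp_trainer(cls):
    """Wrap any trainer base class with PyTorch AMP support."""
    class AMPTrainer(cls):
        ...  # body as above
    return AMPTrainer

# usage: Trainer = make_amp_trainer(BaseTrainer); trainer = Trainer(...)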
def load_model(model_name, model, device, mp=False):
    filepath = os.path.join('models', model_name + '.pt')
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    optimizer = Adam(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Restore the dynamic loss-scale state saved with the checkpoint.
    scaler = GradScaler(enabled=mp)
    scaler.load_state_dict(checkpoint['scaler'])
    # scaler.set_growth_interval(500)
    # scaler.set_growth_factor(1)
    # scaler.set_backoff_factor(1)
    # print('Scale', scaler.get_scale())
    epoch = checkpoint['epoch'] or 0
    return model, optimizer, scaler, epoch
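# For reference, a save-side counterpart that produces the checkpoint dict this
# loader expects. A sketch; the key names mirror the loader above.
def save_model(model_name, model, optimizer, scaler, epoch):
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),  # GradScaler state is needed to resume AMP cleanly
        'epoch': epoch,
    }
    torch.save(checkpoint, os.path.join('models', model_name + '.pt'))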
class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = GradScaler()

    def __repr__(self) -> str:
        return repr(self.__class__.__name__)

    def __call__(
        self,
        loss,
        optimizer,
        step,
        accum_grad,
        clip_grad=None,
        parameters=None,
        create_graph=False,
    ):
        self._scaler.scale(loss / accum_grad).backward(create_graph=create_graph)
        if step % accum_grad == 0:
            if clip_grad is not None:
                assert parameters is not None
                # unscale the gradients of optimizer's assigned params in-place
                self._scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
            optimizer.zero_grad()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
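# A hedged usage sketch of NativeScaler. model, criterion, optimizer, and loader
# are placeholders; accum_grad=4 and clip_grad=1.0 are illustrative values.
scaler = NativeScaler()
for step, (x, y) in enumerate(loader, start=1):
    with torch.cuda.amp.autocast():
        loss = criterion(model(x.cuda()), y.cuda())
    # Backward runs every step; the optimizer steps once per accum_grad steps,
    # with gradients unscaled and clipped just before the step.
    scaler(loss, optimizer, step=step, accum_grad=4,
           clip_grad=1.0, parameters=model.parameters())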
class Trainer(object):
    def __init__(self, cfgs):
        save_dict = OrderedDict()
        save_dict["fold"] = cfgs["fold"]
        if cfgs["memo"] is not None:
            save_dict["memo"] = cfgs["memo"]  # 1,2,3
        specific_dir = ["{}-{}".format(key, save_dict[key]) for key in save_dict.keys()]
        cfgs["save_dir"] = os.path.join(
            cfgs["save_dir"],
            # cfgs["model"]["meta"],
            # cfgs["model"]["inputs"]["label"],
            "_".join(specific_dir),
        )
        os.makedirs(cfgs["save_dir"], exist_ok=True)

        ####### CONFIGS
        self.cfgs = cfgs

        ####### Logging
        self.tb_writer = utils.get_writer(self.cfgs)
        self.txt_logger = utils.get_logger(self.cfgs)
        self.do_logging = True
        if len(self.cfgs["gpu"]) > 1:
            if dist.get_rank() != 0:
                self.do_logging = False
        if self.do_logging:
            self.txt_logger.write("\n\n----train.py----")
            self.txt_logger.write("\n{}".format(datetime.datetime.now()))
            self.txt_logger.write(
                "\n\nSave Directory: \n{}".format(self.cfgs["save_dir"])
            )
            self.txt_logger.write("\n\nConfigs: \n{}\n".format(self.cfgs))

        ####### MODEL
        model = models.get_model(self.cfgs)
        if len(self.cfgs["gpu"]) > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            self.device = torch.device("cuda:{}".format(self.cfgs["local_rank"]))
            self.model = model.to(self.device)
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[self.cfgs["local_rank"]],
                output_device=self.cfgs["local_rank"],
            )
        else:
            self.device = torch.device("cuda:{}".format(self.cfgs["local_rank"]))
            self.model = model.to(self.device)

        ####### Data
        train_dataset = inputs.get_dataset(self.cfgs, mode="train")
        if len(self.cfgs["gpu"]) > 1:
            train_sampler = DistributedSampler(
                train_dataset,
                num_replicas=len(self.cfgs["gpu"]),
                rank=self.cfgs["local_rank"],
            )
        else:
            train_sampler = None
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.cfgs["batch_size"],
            num_workers=self.cfgs["num_workers"],
            pin_memory=True,
            drop_last=False,
            collate_fn=inputs.get_collater(),
            sampler=train_sampler,
        )
        # if self.do_logging:
        #     self.txt_logger.write("\nDataset: ")
        #     self.txt_logger.write(
        #         "\nTRAIN Abnormal/Normal: {}/{}".format(
        #             len(train_dataset.abnormal_meta_df),
        #             len(train_dataset.normal_meta_df),
        #         )
        #     )

        ####### Opts
        self.optimizer = opts.get_optimizer(self.cfgs, self.model.parameters())
        self.scheduler = opts.get_scheduler(self.cfgs, self.optimizer)
        self.grad_scaler = GradScaler(enabled=self.cfgs["use_amp"])

        ####### Validator
        self.validator = Validator(self.cfgs, self.device)
        # if self.do_logging:
        #     self.txt_logger.write(
        #         "\nVAL Abnormal/Normal: {}/{}".format(
        #             len(self.validator.val_loader.dataset.abnormal_meta_df),
        #             len(self.validator.val_loader.dataset.normal_meta_df),
        #         )
        #     )
        # if self.cfgs["model"]["val"]["ignore_normal"]:
        #     self.txt_logger.write("\nVAL Ignore Normal")
        #     self.validator.val_loader.dataset.meta_df = (
        #         self.validator.val_loader.dataset.abnormal_meta_df
        #     )

    def do_train(self):
        ####### Setup Train
        self.epoch, self.iter, self.resume_epoch = 0, 0, 0
        self.tot_val_record = {
            "best": {"det_recl": -1, "det_prec": -1, "det_f1": -1, "loss": np.inf}
        }
        if self.cfgs["model"]["train"]["resume_train"]:
            with open(
                os.path.join(self.cfgs["save_dir"], "tot_val_record.pkl"), "rb"
            ) as f:
                self.tot_val_record = pickle.load(f)
            self.iter, self.resume_epoch = (
                self.tot_val_record["best"]["iteration"],
                self.tot_val_record["best"]["epoch"],
            )
            resume_model_dir = os.path.join(
                self.cfgs["save_dir"], "epoch_{}.pt".format(self.resume_epoch)
            )
            checkpoint = torch.load(resume_model_dir)
            self.model.load_state_dict(checkpoint["model"], strict=True)
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.grad_scaler.load_state_dict(checkpoint["scaler"])
            self.txt_logger.write("\n\nResume Training Here! \n\n")

        if self.do_logging:
            self.txt_logger.write("\n\nStart Training! \n\n")
            header_columns = ["epoch", "iter", "time", "train_loss", "val_loss"]
            header_columns += ["det_recl", "det_prec", "det_fppi", "det_f1"]
            header_columns += ["cls_auc", "cls_sens", "cls_spec"]
            header_columns += ["best_epoch"]
            self.txt_logger.log_header(header_columns)

        ####### Train
        self.start_time = time.time()
        self.endurance = 0
        for epoch in range(self.resume_epoch, self.cfgs["model"]["train"]["max_epoch"]):
            # self.train_loader.dataset.shuffle()
            # self.train_loader.dataset.meta_df = (
            #     self.train_loader.dataset.abnormal_meta_df
            # )
            self.one_epoch_steps = len(self.train_loader)
            self.display_step = (
                self.one_epoch_steps // self.cfgs["model"]["train"]["display_interval"]
            )
            self.epoch = epoch
            if self.endurance > self.cfgs["model"]["train"]["endurance"]:
                if self.do_logging:
                    self.txt_logger.write(
                        "\nStop training! No more performance gain expected!"
                    )
                    best_epoch = self.tot_val_record["best"]["epoch"]
                    self.txt_logger.write(
                        "\n\nBest saved at: {}, {} epoch\n\n".format(
                            self.cfgs["save_dir"], best_epoch
                        )
                    )
                break
            self.train_val_one_epoch()

    def train_val_one_epoch(self):
        self.optimizer.zero_grad()
        self.model.train()
        t0 = time.time()
        for i, data in enumerate(self.train_loader):
            t1 = time.time()
            img = data["img"].permute(0, 3, 1, 2).to(self.device)
            # Forward pass and loss run under autocast when AMP is enabled;
            # backward/step go through the GradScaler either way.
            with autocast(enabled=self.cfgs["use_amp"]):
                logit = self.model(img)
                t2 = time.time()
                # FIXME: GPU utilization does not show up here
                loss = opts.calc_loss(self.cfgs, self.device, data, logit)
            t3 = time.time()
            self.grad_scaler.scale(loss).backward()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
            self.optimizer.zero_grad()
            t4 = time.time()
            # NOTE: Try to avoid excessive CPU-GPU synchronization
            # (.item() calls, or printing values from CUDA tensors).
            if self.do_logging:
                loss = loss.detach().item()
                take_time = tools.convert_time(time.time() - self.start_time)
                train_logs = [loss, "-"]
                self.txt_logger.log_result(
                    [self.epoch, "{}/{}".format(i, self.one_epoch_steps), take_time]
                    + train_logs
                )
                self.tb_writer.write_scalars(
                    {"loss": {"train loss": loss}},
                    self.iter,
                )
                if self.iter % self.display_step == 0:
                    # Visualize: find an abnormal sample in the batch
                    for viz_bi in range(len(data["fp"])):
                        if data["bbox"][viz_bi, 0, -1] != -1:
                            break
                    with torch.no_grad():
                        self.model.eval()
                        det_preds_viz = (
                            self.model(img, mode="viz")["preds"][viz_bi]
                            .detach()
                            .cpu()
                            .numpy()
                        )
                        if len(det_preds_viz) != 0:
                            # sigmoid
                            det_preds_viz[:, -1] = 1 / (
                                1 + np.exp(-1 * det_preds_viz[:, -1])
                            )
                        else:
                            det_preds_viz = np.ones((1, 6)) * -1
                        det_anns_viz = data["bbox"][viz_bi].numpy()
                        self.tb_writer.write_images(
                            data["fp"][viz_bi],
                            data["img"][viz_bi].numpy(),
                            det_preds_viz,
                            det_anns_viz,
                            self.iter,
                            "train",
                        )
                        self.model.train()
            self.iter += 1

            lr0 = self.cfgs["model"]["opts"]["learning_rate"]
            wep = self.cfgs["model"]["opts"]["warmup_epoch"]
            if self.epoch < wep:
                for pg in self.optimizer.param_groups:
                    pg["lr"] = lr0 / wep * (self.epoch + i / self.one_epoch_steps)
            else:
                if self.scheduler is not None:
                    self.scheduler.step(self.epoch - wep + i / self.one_epoch_steps)

            t5 = time.time()
            if self.cfgs["do_profiling"]:
                print("\ndata", t1 - t0)
                print("forward", t2 - t1)
                print("calc loss", t3 - t2)
                print("backward", t4 - t3)
                print("logging", t5 - t4)
            t0 = t5

        if self.epoch > self.cfgs["model"]["val"]["ignore_epoch"]:
            # Do Validation
            val_record, val_viz = self.validator.do_validate(self.model)
            self.tot_val_record[str(self.epoch + 1)] = val_record
            val_best = val_record[self.cfgs["model"]["val"]["best"]]

            # Save Model
            select_metric = self.cfgs["model"]["val"]["best"]
            val_improved = False
            if select_metric == "loss":
                if val_best < self.tot_val_record["best"][select_metric]:
                    val_improved = True
            elif select_metric == "det_f1":
                if val_best > self.tot_val_record["best"][select_metric]:
                    val_improved = True
            if val_improved:
                checkpoint = {
                    "epoch": self.epoch,
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "scaler": self.grad_scaler.state_dict(),
                }
                model_name = os.path.join(
                    self.cfgs["save_dir"], "epoch_" + str(self.epoch + 1) + ".pt"
                )
                torch.save(checkpoint, model_name)
                self.tot_val_record["best"] = val_record
                self.tot_val_record["best"]["epoch"] = self.epoch + 1
                self.tot_val_record["best"]["iteration"] = self.iter
                self.endurance = 0
            else:
                self.endurance += 1

            if self.do_logging:
                take_time = utils.tools.convert_time(time.time() - self.start_time)
                vloss = val_record["loss"]
                vbest_epoch = self.tot_val_record["best"]["epoch"]
                metric_keys = ["det_recl", "det_prec", "det_fppi", "det_f1"]
                metric_keys += ["cls_auc", "cls_sens", "cls_spec"]
                val_logs = [vloss] + [val_record[k] for k in metric_keys]
                self.txt_logger.log_result(
                    [self.epoch + 1, self.iter, take_time, loss]
                    + val_logs
                    + [vbest_epoch],
                    txt_write=True,
                )
                self.txt_logger.write("\n", txt_write=True)
                self.tb_writer.write_images(
                    val_viz["fp"],
                    val_viz["img"],
                    val_viz["pred"],
                    val_viz["ann"],
                    self.iter,
                    "val",
                )
                self.tb_writer.write_scalars(
                    {
                        "metrics": {
                            "{}".format(key): val_record[key] for key in metric_keys
                        }
                    },
                    self.iter,
                )
                self.tb_writer.write_scalars({"loss": {"val loss": vloss}}, self.iter)

            with open(
                os.path.join(self.cfgs["save_dir"], "tot_val_record.pkl"), "wb"
            ) as f:
                pickle.dump(self.tot_val_record, f)
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    logger = get_logger(args.logging_file)
    logger.info("Use GPU: {} for training".format(args.gpu))
    args.rank = args.rank * ngpus_per_node + gpu
    torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)

    epochs = args.epochs
    input_size = args.input_size
    resume_epoch = args.resume_epoch

    initializer = KaimingInitializer()
    zero_gamma = ZeroLastGamma()
    mix_precision_training = args.mix_precision_training
    is_first_rank = True if args.rank % ngpus_per_node == 0 else False

    batches_per_epoch = args.num_training_samples // (args.batch_size * ngpus_per_node)
    lr = 0.1 * (args.batch_size * ngpus_per_node // 32) if args.lr == 0 else args.lr

    model = get_model(models, args.model)
    model.apply(initializer)
    if args.last_gamma:
        model.apply(zero_gamma)
        logger.info('Apply zero last gamma init.')

    if is_first_rank and args.model_info:
        summary(model, torch.rand((1, 3, input_size, input_size)))

    parameters = model.parameters() if not args.no_wd else no_decay_bias(model)
    if args.sgd_gc:
        logger.info('Use SGD_GC optimizer.')
        optimizer = SGD_GC(parameters, lr=lr, momentum=args.momentum,
                           weight_decay=args.wd, nesterov=True)
    else:
        optimizer = optim.SGD(parameters, lr=lr, momentum=args.momentum,
                              weight_decay=args.wd, nesterov=True)

    lr_scheduler = CosineWarmupLr(optimizer, batches_per_epoch, epochs,
                                  base_lr=args.lr, warmup_epochs=args.warmup_epochs)
    # dropblock_scheduler = DropBlockScheduler(model, batches_per_epoch, epochs)
    if args.lookahead:
        optimizer = Lookahead(optimizer)
        logger.info('Use lookahead optimizer.')

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
    if args.mix_precision_training and is_first_rank:
        logger.info('Train with FP16.')
    scaler = GradScaler(enabled=args.mix_precision_training)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    Loss = nn.CrossEntropyLoss().cuda(args.gpu) if not args.label_smoothing else \
        LabelSmoothingLoss(args.classes, smoothing=0.1).cuda(args.gpu)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.autoaugment:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            ImageNetPolicy(),  # the policy must be instantiated
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            # Cutout(),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.ToTensor(),
            normalize,
        ])

    val_transform = transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = ImageNet(args.data_path, split='train', transform=train_transform)
    val_set = ImageNet(args.data_path, split='val', transform=val_transform)
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, args.batch_size, False, pin_memory=True,
                              num_workers=args.num_workers, drop_last=True,
                              sampler=train_sampler)
    val_loader = DataLoader(val_set, args.batch_size, False, pin_memory=True,
                            num_workers=args.num_workers, drop_last=False)

    if resume_epoch > 0:
        loc = 'cuda:{}'.format(args.gpu)
        checkpoint = torch.load(args.resume_param, map_location=loc)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scaler.load_state_dict(checkpoint['scaler'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        print("Finish loading resume param.")

    torch.backends.cudnn.benchmark = True

    top1_acc = metric.Accuracy(name='Top1 Accuracy')
    top5_acc = metric.TopKAccuracy(top=5, name='Top5 Accuracy')
    loss_record = metric.NumericalCost(name='Loss')

    for epoch in range(resume_epoch, epochs):
        tic = time.time()
        train_sampler.set_epoch(epoch)
        if not args.mixup:
            train_one_epoch(model, train_loader, Loss, optimizer, epoch,
                            lr_scheduler, logger, top1_acc, loss_record, scaler, args)
        else:
            train_one_epoch_mixup(model, train_loader, Loss, optimizer, epoch,
                                  lr_scheduler, logger, loss_record, scaler, args)
        train_speed = int(args.num_training_samples // (time.time() - tic))
        if is_first_rank:
            logger.info('Finish one epoch speed: {} samples/s'.format(train_speed))
        test(model, val_loader, Loss, epoch, logger, top1_acc, top5_acc,
             loss_record, args)

        if args.rank % ngpus_per_node == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }
            torch.save(checkpoint, '{}/{}_{}_{:.5}.pt'.format(
                args.save_dir, args.model, epoch, top1_acc.get()))
if args.chkpoint:
    chk = torch.load(args.chkpoint, map_location=device)
elif args.finetune:
    if args.chkpointft:
        chk = torch.load(args.chkpointft, map_location=device)
    else:
        sys.exit("Finetune can't be performed if chkpointft is not supplied")
else:
    chk = None

start_epoch = 0
best_loss = float('-inf') if IsNegLoss else float('inf')
if chk is not None:
    model.load_state_dict(chk['state_dict'])
    optimizer.load_state_dict(chk['optimizer'])
    scaler.load_state_dict(chk['AMPScaler'])
    best_loss = chk['best_loss']
    start_epoch = chk['epoch'] + 1
    iterations = chk['iterations']
    # Only used for finetune.
    main_train_epoch = (chk['main_train_epoch'] + 1) if 'main_train_epoch' in chk else start_epoch

if args.finetune:
    if args.fteprt:
        # Extend training by a ratio of the main-training epochs.
        args.epochs = int(main_train_epoch * (1 + args.fteprt))
    else:
        # Extend training by a ratio of the main-training iterations.
        args.iterations = int(iterations * args.ftitrt)
        n_ft_ep = int(args.iterations // len(train_loader))
        args.epochs = main_train_epoch + n_ft_ep

if args.epochs is None:
    args.epochs = int(args.iterations // len(train_loader) + 1)
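# A worked example of the finetune epoch arithmetic above; the numbers are
# illustrative only.
#
#   Ratio on epochs: main_train_epoch = 90, args.fteprt = 0.1
#     -> args.epochs = int(90 * 1.1) = 99, i.e. 9 extra finetune epochs.
#   Ratio on iterations: iterations = 100_000, args.ftitrt = 0.05, len(train_loader) = 500
#     -> args.iterations = 5_000, n_ft_ep = 10, args.epochs = 90 + 10 = 100.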
class SelfSupervisionTask(ClassificationTask):
    """
    A task prepares and holds all the components of a training like optimizer,
    datasets, dataloaders, losses, meters etc. Task also contains the variables
    like training iteration, epoch number etc. that are updated during the
    training.

    We prepare every single component according to the parameter settings user
    wants and specified in the yaml config file.

    Task also supports 2 additional things:
    1) converts the model BatchNorm layers to the synchronized batchnorm
    2) sets mixed precision (apex and pytorch both supported)
    """

    def __init__(self, config: AttrDict):
        super().__init__()
        self.config = config
        self.checkpoint_path = None

        # Register the task to the proper device (cpu, gpu, ...)
        self.set_device()

        self.checkpoint_folder = None
        self.checkpoint = None
        self.available_splits = []
        self.base_loss = None
        self.meters = None
        self.datasets = None
        self.phases = []
        self.hooks = []
        self.base_model = None
        self.optimizer = None
        self.amp_args = None
        self.amp_type = None
        self.amp_grad_scaler = None
        self.data_and_label_keys = []
        self.set_amp_args()
        self._enable_manual_gradient_reduction = None
        # total number of parameter updates applied to the model by optimizer
        self.num_updates = 0
        # measure time of several training components (data, forward, backward etc..)
        self.perf_stats = None
        # total number of phases including test + train
        self.num_phases = -1  # set by the trainer
        # number of train only phases
        self.num_train_phases = -1  # set by the prepare method
        # number of train epochs; set to num_train_phases
        self.num_epochs = -1  # set by the trainer
        # total number of "training" iterations. Inferred from dataloader length
        # and num_train_phases
        self.max_iteration = -1  # set by trainer
        # Current phase id (includes train/test). Starts from 0
        self.phase_idx = -1
        # id of the current training phase training is at. Starts from 0
        self.train_phase_idx = -1  # set by trainer
        # metrics stored during the training.
        self.metrics = {}  # set by the trainer
        self.start_time = -1  # set by trainer
        # time of each batch in training and testing. This can be used to get
        # average batch time etc. batch_time is appended after every parameter
        # update.
        self.batch_time = []  # set by trainer
        # we maintain and store the iteration in the state itself. It counts
        # total number of iterations we do in training phases. Updated
        # after every forward pass of training step. Starts from 1
        self.iteration = 0
        # collect how many total iterations we make irrespective of train/test
        # phase. Useful for debugging purposes. Starts from 1.
        self.local_iteration_num = -1  # set by trainer
        # for every phase, record the start time. Reset at the beginning of each
        # phase by SetDataSamplerEpochHook hook.
        self.phase_start_time = -1  # set by the hook at start of each epoch or phase
        # for every phase, record the number of batches seen. Incremented after
        # every forward pass. Reset at the start of each phase by
        # SetDataSamplerEpochHook hook. Useful for debugging.
        self.batches = -1  # set by the hook at start of each epoch or phase
        # loss curve. Reset at start of each phase/epoch by SetDataSamplerEpochHook hook.
        self.losses = []  # set by the hook at start of each epoch or phase
        # set the bucket_cap_mb for gradient reduction. This can be tuned to
        # overlap communication as much as possible
        self.set_ddp_bucket_cap_mb()
        self.use_gpu = self.device.type == "cuda"
        # optionally save the exponential moving average (ema) of the base_model
        # and/or run the meters on the ema of the base_model.
        self.ema_model = None
        self.ema_meters = []

    def set_device(self):
        """
        Set the training device: whether gpu or cpu. We use the self.device in the
        rest of the workflow to determine if we should do cpu only training or use
        gpu. Set MACHINE.DEVICE = "gpu" or "cpu".
        """
        try:
            self.device = torch.device(
                "cuda" if self.config.MACHINE.DEVICE == "gpu" else "cpu"
            )
        except AttributeError:
            self.device = torch.device("cuda")

    def set_ddp_bucket_cap_mb(self):
        """
        PyTorch DDP supports setting the bucket_cap_mb for all reduce. Tuning this
        parameter can help with the speed of the model. We use the default pytorch
        value of 25MB.
        """
        self.ddp_bucket_cap_mb = self.config.DATA.DDP_BUCKET_CAP_MB
        assert self.ddp_bucket_cap_mb > 0, "bucket_cap_mb must be positive"

    def set_available_splits(self):
        """
        Given the data settings, we determine if we are using both train and test
        datasets. If TEST_MODEL=true, we will add the test to the available_splits.
        If TEST_ONLY=false, we add train to the split as well.
        """
        if self.config.TEST_MODEL:
            self.available_splits.append("TEST")
        if not self.config.TEST_ONLY:
            self.available_splits.append("TRAIN")
        return self

    def set_amp_args(self):
        """
        Two automatic mixed precision implementations are available: Apex's and
        PyTorch's.

        - If Apex's AMP is enabled, amp_args is a dictionary containing arguments
          to be passed to amp.initialize. Set to None to disable amp. To enable
          mixed precision training, pass amp_args={"opt_level": "O1"} here.
          See https://nvidia.github.io/apex/amp.html for more info.

        - If PyTorch's AMP is enabled, no arguments are needed.
        """
        if self.config.MODEL.AMP_PARAMS.USE_AMP:
            assert (
                self.device.type == "cuda"
            ), "Mixed precision is only available on CUDA devices for now"

            # This will rightly fail if the setting is not correct
            self.amp_type = AmpType[self.config.MODEL.AMP_PARAMS.AMP_TYPE.upper()]
            if self.amp_type == AmpType.APEX:
                self._init_apex_grad_scaler()
            elif self.amp_type == AmpType.PYTORCH:
                self._init_pytorch_grad_scaler()
            logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")
        else:
            self.amp_args, self.amp_type = None, None
            logging.info("Not using Automatic Mixed Precision")

    def _init_apex_grad_scaler(self):
        # Check Apex availability
        if not is_apex_available():
            raise RuntimeError("Apex is not available. Can't use mixed precision")

        # "amp_args" are actually Apex Amp args
        self.amp_args = self.config.MODEL.AMP_PARAMS.AMP_ARGS
        logging.info(f"Setting AMP: using apex, args {self.amp_args}")

    def _init_pytorch_grad_scaler(self):
        if self.config["OPTIMIZER"]["name"] == "zero":
            assert is_fairscale_sharded_available(), (
                "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                "from fairscale is needed. Please upgrade fairscale"
            )
            from fairscale.optim.grad_scaler import ShardedGradScaler

            self.amp_grad_scaler = ShardedGradScaler()
            logging.info("Setting AMP: using sharded grad scaler")
        else:
            self.amp_grad_scaler = TorchGradScaler()
            logging.info("Setting AMP: using pytorch grad scaler")

    def set_checkpoint_path(self, checkpoint_path: str):
        """
        Set the checkpoint path for the training
        """
        self.checkpoint_path = checkpoint_path

    def set_checkpoint_folder(self, checkpoint_folder: str):
        """
        Set the checkpoint folder for the training
        """
        self.checkpoint_folder = checkpoint_folder

    def set_iteration(self, iteration):
        """
        Set the iteration number. We maintain and store the iteration in the state
        itself. It counts the total number of iterations we do in training phases.
        Updated after every forward pass of a training step. Starts from 1.
        """
        assert iteration >= 0, "Iteration number must be non-negative"
        self.iteration = iteration

    @property
    def enable_manual_gradient_reduction(self) -> bool:
        """
        Lazily initialize the enable flag once when model is not None.
        """
        if self._enable_manual_gradient_reduction is None and self.model is not None:
            self.set_manual_gradient_reduction()
        if self._enable_manual_gradient_reduction:
            return True
        return False

    def set_manual_gradient_reduction(self) -> None:
        """
        Called during __init__ to set a flag if manual gradient reduction is
        enabled.
        """
        assert self.model is not None
        self._enable_manual_gradient_reduction = manual_gradient_reduction(
            self.model, self.config["DISTRIBUTED"]["MANUAL_GRADIENT_REDUCTION"]
        )
        if self._enable_manual_gradient_reduction:
            logging.info("Enabling manual gradient reduction")

    @classmethod
    def from_config(cls, config):
        """
        Create the task from the yaml config input.
        """
        test_only = config.TEST_ONLY

        return (
            cls(config)
            .set_available_splits()
            .set_test_only(test_only)
            .set_epoch_phase_info()
        )

    def set_epoch_phase_info(self):
        # In case optimizer doesn't exist. E.g. for feature extraction.
        optimizer = getattr(self.config, "OPTIMIZER", {})
        self.num_epochs = getattr(optimizer, "num_epochs", 1)
        self.num_train_phases_per_epoch = getattr(
            self.config["DATA"]["TRAIN"], "TRAIN_PHASES_PER_EPOCH", 1
        )
        self.num_train_phases = (
            self.config["OPTIMIZER"]["num_epochs"] * self.num_train_phases_per_epoch
        )
        return self

    # We keep the function because this is used by hooks like checkpoint etc.
    def get_config(self):
        """
        Utility function to store and use the config that was used for the given
        training.
        """
        return {"config": self.config}

    def _build_phases(self):
        """
        Returns list of phases from config. These phases will look like:
        {
          train: is this a train or test phase (bool)?
        }
        If this is a test-only run, then only test phases will be generated;
        if this is a training run, then #phases = #train-phases + #test-phases,
        interleaved. We only add a test phase every TEST_EVERY_NUM_EPOCH epochs
        if we don't want the test to run after every train phase.
        """
        if not self.config["TEST_ONLY"]:
            phases = [{"train": True} for _ in range(self.num_train_phases)]
            # whether the model is train or test only. If the model is not test
            # only, then whether we do test as well or not, is decided from the
            # config file.
            test_every = (
                self.config.get("TEST_EVERY_NUM_EPOCH", 1)
                * self.num_train_phases_per_epoch
            )
            output_phases = []
            for idx, phase in enumerate(phases):
                output_phases.append(phase)
                if idx % test_every == 0 or idx == (len(phases) - 1):
                    output_phases.append({"train": False})
            # we do a little surgery here. Either the phases are test only or
            # [train + test] both interleaved. If we don't want the model to be
            # tested at all (which is sometimes the case in self-supervised
            # learning), we remove the test phases.
            if not self.config["TEST_MODEL"]:
                output_phases = [phase for phase in output_phases if phase["train"]]
        else:
            output_phases = [{"train": False} for _ in range(self.num_train_phases)]
        return output_phases

    def build_datasets(self, current_train_phase_idx=0):
        """
        Get the datasets for the data splits we will use in the training. The
        set_available_splits variable determines the splits used in the training.
        """
        datasets, data_and_label_keys = {}, {}
        for split in self.available_splits:
            datasets[split.lower()] = build_dataset(
                cfg=self.config,
                split=split,
                current_train_phase_idx=current_train_phase_idx,
            )
            data_and_label_keys["input"] = self.config.DATA[split].INPUT_KEY_NAMES
            data_and_label_keys["target"] = self.config.DATA[split].TARGET_KEY_NAMES

        return datasets, data_and_label_keys

    def build_dataloaders(
        self, pin_memory: bool, current_train_phase_idx=0
    ) -> torch.utils.data.DataLoader:
        """
        Build PyTorch dataloaders for all the available_splits. By default, we
        construct the standard PyTorch Dataloader and allow setting all dataloader
        options.
        """
        # Give the sampler the same seed for the entire distributed group, as per
        # the pytorch documentation.
        sampler_seed = self.config["SEED_VALUE"]

        loaders = {
            split.lower(): build_dataloader(
                dataset=self.datasets[split.lower()],
                dataset_config=self.config["DATA"][split],
                num_dataloader_workers=self.config.DATA.NUM_DATALOADER_WORKERS,
                pin_memory=pin_memory,
                multi_processing_method=self.config.MULTI_PROCESSING_METHOD,
                device=self.device,
                sampler_seed=sampler_seed,
                split=split.lower(),
            )
            for split in self.available_splits
        }

        return loaders

    def get_global_batchsize(self):
        """
        Return the global batchsize used in the training across all the trainers.
        We check what phase we are in (train or test) and get the dataset used in
        that phase. We call get_global_batchsize() of the dataset.
        """
        for phase_type in self.datasets:
            if phase_type.lower() == self.phase_type.lower():
                return self.datasets[phase_type].get_global_batchsize()
        raise ValueError(f"{self.phase_type} not found in self.datasets")

    def _build_optimizer(self):
        """
        Build optimizers using the optimizer settings specified by user. For SGD,
        we support LARC as well. In order to use LARC, Apex must be installed.
        """
        optimizer_config = self.config["OPTIMIZER"]
        if optimizer_config.use_larc and optimizer_config.name != "sgd_fsdp":
            assert is_apex_available(), "Apex must be available to use LARC"
        optim = build_optimizer(optimizer_config)
        return optim

    def _build_optimizer_schedulers(self):
        """
        Build the param schedulers to be used in training.
        """
        return build_optimizer_schedulers(self.config["OPTIMIZER"])

    def _build_loss(self):
        """
        Build the loss used in training. Supports all PyTorch losses and custom
        defined losses. For some losses that require memory banks (for example in
        info_nce loss), we need to store the size of data as we use it to allocate
        memory. Since the dataset size is not known at the time of config parsing,
        we set the data size parameter here.
        """
        # in some cases like memory bank, we need to store the size of data
        # as we use it to allocate memory. Hence we set that parameter here.
        logging.info("Building loss...")
        loss_name = self.config.LOSS["name"]
        assert loss_name in list(self.config.LOSS.keys()), (
            f"Loss {loss_name} params unknown. The loss name and the param dict "
            f"key name should match. Known: {list(self.config.LOSS.keys())}"
        )
        loss_config = self.config.LOSS[loss_name]
        if "num_train_samples" in loss_config.keys():
            for split in self.available_splits:
                if split == "TRAIN":
                    loss_config["num_train_samples"] = len(self.datasets["train"])
                if split == "TEST":
                    loss_config["num_train_samples"] = len(self.datasets["test"])
        loss_config["name"] = loss_name
        loss = build_loss(loss_config)
        return loss

    def _build_meters(self):
        """
        Returns meters for task.
        """
        meter_names = self.config["METERS"].get("names", [])
        if not meter_names:
            return []
        meters = []
        for meter_name in meter_names:
            meter_params = self.config["METERS"][meter_name]
            meter_config = {"name": meter_name, **meter_params}
            meters.append(build_meter(meter_config))
        return meters

    def _restore_model_weights(self, model, strict: bool = False):
        """
        If using a weights file to initialize the model, we load the weights and
        initialize the model. Since the weights file specified by user might not
        be VISSL trained weights, we expose several config options like
        APPEND_PREFIX, etc to allow successful loading of the weights. See
        MODEL.WEIGHTS_INIT description in vissl/config/defaults.yaml for details.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weight_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if g_pathmgr.exists(init_weights_path):
            checkpoint = CheckpointLoader.load_and_broadcast_init_weights(
                checkpoint_path=init_weights_path, device=torch.device("cpu")
            )
            logging.info(f"Checkpoint loaded: {init_weights_path}...")
            model.init_model_from_weights_params_file(
                self.config, checkpoint, strict=strict
            )
        return model

    def _build_model(self, strict_load: bool = False):
        """
        - Builds and returns the model used for the task. The returned model is
          not copied to gpu yet (if using gpu) and not wrapped with DDP yet. This
          is done later by self.prepare()

        - We also convert the model BatchNorm layers to SyncBatchNorm if the user
          has set the config option. We support both PyTorch and Apex
          SyncBatchNorms.

        - If the model is set to be in evaluation mode and the full model must be
          frozen, we freeze the model.

        - If the model must be initialized from a checkpoint or a user passed
          weights file, we initialize the model from the checkpoint or the
          weights.
        """
        logging.info("Building model....")

        # Instantiate the raw model as specified
        model = build_model(self.config["MODEL"], self.config["OPTIMIZER"])

        # Convert the BatchNorm layers to SyncBatchNorm if needed
        # Both Apex and Pytorch SyncBatchNorms are GPU only
        if (
            self.config["MODEL"]["SYNC_BN_CONFIG"]["CONVERT_BN_TO_SYNC_BN"]
            and self.config["MACHINE"]["DEVICE"] == "gpu"
        ):
            model = convert_sync_bn(self.config, model)

        # Enforce eval mode, no matter what the prior transforms have done.
        # For instance apex converts batch-norms and sets `requires_grad` to True
        if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["EVAL_MODE_ON"]:
            if self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_ONLY"]:
                logging.info(
                    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True, "
                    "will freeze trunk..."
                )
                model.freeze_trunk()
            elif self.config["MODEL"]["FEATURE_EVAL_SETTINGS"]["FREEZE_TRUNK_AND_HEAD"]:
                logging.info(
                    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=True, "
                    "will freeze trunk and head..."
                )
                model.freeze_head_and_trunk()

        # assert that if the user set the PARAMS_FILE, it must exist and be valid.
        if (
            self.checkpoint_path is None
            and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
        ):
            assert g_pathmgr.exists(
                self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
            ), "Specified PARAMS_FILE does NOT exist"

        # If we want to initialize the model in case of finetuning or evaluation,
        # we do it here. But we check that there is no existing checkpoint first.
        # This is important in cases when the model training dies.
        if (
            self.checkpoint_path is None
            and self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"]
            and g_pathmgr.exists(self.config["MODEL"]["WEIGHTS_INIT"]["PARAMS_FILE"])
        ):
            model = self._restore_model_weights(model, strict=strict_load)

        return model

    def init_distributed_data_parallel_model(self):
        """
        This method overloads the ClassificationTask class's method from
        ClassyVision.
        """
        if not is_distributed_training_run():
            return

        for module in self.base_model.modules():
            if isinstance(module, FullyShardedDataParallel):
                raise ValueError(
                    "DistributedDataParallel should not be used "
                    "with a FullyShardedDataParallel model.\n"
                    "Please set config.TRAINER.TASK_NAME='self_supervision_fsdp_task'"
                )

        super().init_distributed_data_parallel_model()

    def set_epoch(
        self, phase_type: str, epoch: int, start_iter: int, train_phase_idx: int
    ):
        if hasattr(self.dataloaders[phase_type], "sampler"):
            sampler = self.dataloaders[phase_type].sampler
            # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler
            # Resume from the iteration if valid
            self.set_train_epoch_start_iter(sampler, epoch, start_iter, train_phase_idx)
            print_sampler_config(sampler)

        # call set_epoch and set_start_iter for AirstoreDataset since it handles
        # shuffle and sample skipping behavior internally
        dataset = self.datasets[phase_type]
        if hasattr(dataset, "data_objs"):
            for data_obj in dataset.data_objs:
                self.set_train_epoch_start_iter(
                    data_obj, epoch, start_iter, train_phase_idx
                )

    def set_train_epoch_start_iter(
        self, dataset_or_sampler, epoch: int, start_iter: int, train_phase_idx: int
    ):
        # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler
        if hasattr(dataset_or_sampler, "set_epoch"):
            dataset_or_sampler.set_epoch(epoch)
        # Resume from the iteration if valid
        if hasattr(dataset_or_sampler, "set_start_iter"):
            dataset_or_sampler.set_start_iter(start_iter)
        if hasattr(dataset_or_sampler, "set_train_phase_idx"):
            dataset_or_sampler.set_train_phase_idx(train_phase_idx)

    def num_phase_samples(self, phase_type: str) -> int:
        """
        Number of samples in a phase.
        """
        dataset = self.datasets[phase_type.lower()]
        return dataset.num_samples()

    def _compute_start_iter_from_checkpoint(self, phase_type) -> int:
        # used for calculating the start iteration (counted from the current
        # epoch) when resuming from checkpoint
        if self.checkpoint is None or self.checkpoint["iteration"] <= 0:
            return 0

        num_iters_in_epochs = len(self.dataloaders[phase_type])
        num_epochs = self.checkpoint["train_phase_idx"] + 1
        num_train_iters_done = num_epochs * num_iters_in_epochs
        return self.checkpoint["iteration"] - num_train_iters_done

    def recreate_data_iterator(
        self,
        phase_type: str,
        epoch: int,
        compute_start_iter: bool,
        train_phase_idx: int,
    ):
        """
        Recreate the data iterator (including multiprocessing workers) and destroy
        the previous iterators. This is called when we load a new checkpoint or
        when the phase changes during the training (one epoch to the next).
        The DataSampler may need to be informed of those events to update the
        epoch and start_iteration so that the data is deterministically shuffled,
        so we call them here.
        """
        start_iter = 0
        if compute_start_iter:
            start_iter = self._compute_start_iter_from_checkpoint(phase_type)

        self.set_epoch(phase_type, epoch, start_iter, train_phase_idx)

        # Give the sampler the same seed for the entire distributed group, as per
        # the pytorch documentation.
        sampler_seed = self.config["SEED_VALUE"]
        dataset = self.datasets[phase_type]

        # For OSS, this will always return false.
        # Otherwise, we will rebuild the dataloader after every phase.
        if dataset.rebuild_dataloader():
            dataloader = build_dataloader(
                dataset=dataset,
                dataset_config=self.config.DATA[phase_type.upper()],
                num_dataloader_workers=self.config.DATA.NUM_DATALOADER_WORKERS,
                pin_memory=self.config.DATA.PIN_MEMORY,
                multi_processing_method=self.config.MULTI_PROCESSING_METHOD,
                device=self.device,
                sampler_seed=sampler_seed,
                split=phase_type,
            )

            # delete old dataloader and reset it.
            del self.dataloaders[phase_type]
            gc.collect()
            self.dataloaders[phase_type] = dataloader

        # delete old dataiterator and reset it.
        del self.data_iterator
        gc.collect()
        self.data_iterator = iter(self.dataloaders[phase_type])

    def _set_classy_state(self, state):
        """
        We load/set the model state here to resume correctly from the specified
        state. Usually called when resuming training from a previous model
        checkpoint. We set the model phase (train or eval), model weights, copy
        the model to the correct device, initialize meters, initialize optimizers,
        initialize amp state, set the loss state, set the train phase number,
        iteration, recreate data iterators, etc.
        """
        logging.info("=======Updating classy state_dict from checkpoint=======")
        # here we load the state specific things only. The other extra variables
        # are init from the checkpoint in the trainer step.
        self.train = state["train"]
        self.base_model.set_classy_state(state["base_model"])
        # We need to set the model on the correct device here, unlike in the case
        # of training from scratch. The optimizer looks at the model parameters
        # like momentum etc. for getting the device info. Since in case of scratch
        # training, we don't have those and the optimizer just gets the inputs
        # as cuda inputs from the model, it can work. However, when we load from
        # a checkpoint, we already have these parameters and the type is CPU
        # (since the model isn't copied to gpu yet). The copy_model_to_gpu()
        # doesn't modify optimizer params device. The optimizer is constructed
        # with the CPU inputs. When the model runs, it rather sends CUDA.
        self.base_model.to(self.device)
        self._set_ema_model_state(state)

        for meter, meter_state in zip(self.meters, state["meters"]):
            meter.set_classy_state(meter_state)
        self.optimizer.set_classy_state(state["optimizer"])

        # restore amp state. It's called after amp.initialize is done.
        if "amp" in state:
            if self.amp_type == AmpType.APEX:
                if is_apex_available():
                    apex.amp.load_state_dict(state["amp"])
                else:
                    logging.warning(
                        "Loading a checkpoint which has amp state but apex isn't "
                        "available now"
                    )
            else:
                self.amp_grad_scaler.load_state_dict(state["amp"])
        self.phase_idx = state["phase_idx"]
        self.train_phase_idx = state["train_phase_idx"]
        self.num_updates = state["num_updates"]
        self.losses = state["losses"]

        phase_type = "train" if self.train else "test"
        phase = self.phases[self.phase_idx]

        # Re-create the data iterator.
        # We are restoring from a checkpoint, which means we need to
        #   (1) set the right epoch
        #   (2) set the right start_iter
        # epoch number is `phase_idx + 1` since the checkpoint's value is the
        # epoch finished. start_iter is computed in recreate_data_iterator based
        # on the iteration number from the checkpoint state.
        self.recreate_data_iterator(
            phase_type,
            epoch=self.phase_idx + 1,
            compute_start_iter=True,
            train_phase_idx=self.train_phase_idx + 1,
        )

        # set the model to train or eval depending on what phase we are in
        self.base_model.train(phase["train"])

        if self.train and self.train_phase_idx >= 0:
            self.optimizer.on_epoch(self.where)

    def _set_ema_model_state(self, state):
        """
        Only used if EmaMetersHook is enabled.
        """
        if self.ema_model is not None:
            logging.info("Loading ema model")
            self.ema_model.module.set_classy_state(state["ema_model"])
            for meter, meter_state in zip(self.ema_meters, state["ema_meters"]):
                meter.set_classy_state(meter_state)

    def _update_classy_state(self, state_dict=None):
        """
        Updates classy state with the provided state dict from a checkpoint.
        state_dict = checkpoint loaded state
        """
        if state_dict is not None:
            try:
                self._set_classy_state(state_dict)
                success = True
            except Exception as e:
                logging.exception(f"Could not load the checkpoint: {e}")
                success = False
            assert success, "Update classy state from checkpoint failed."
        return self

    def _set_ddp_options(self):
        """
        set DDP options if the user has supplied them
        """
        broadcast_buffers = self.config["DISTRIBUTED"]["BROADCAST_BUFFERS"]
        if broadcast_buffers:
            logging.info(
                "Broadcast model BN buffers from primary on every forward pass"
            )
            broadcast_buffers_enum_mode = BroadcastBuffersMode.FORWARD_PASS
            self.set_distributed_options(
                broadcast_buffers_mode=broadcast_buffers_enum_mode
            )  # NOQA

    def run_hooks(self, hook_function_name: str):
        """
        Override the ClassyTask run_hook function and run the hooks whenever
        called.
        """
        for hook in self.hooks:
            getattr(hook, hook_function_name, ClassyHook._noop)(self)

    def prepare_optimizer(self):
        """
        Constructs the optimizer using the user defined settings in the yaml
        config. The model must be on the correct device (cuda or cpu) by this
        point.
        """
        param_groups = get_optimizer_param_groups(
            model=self.base_model,
            model_config=self.config["MODEL"],
            optimizer_config=self.config["OPTIMIZER"],
            optimizer_schedulers=self.optimizer_schedulers,
        )
        self.optimizer.set_param_groups(param_groups)

    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.phases = self._build_phases()
        self.num_phases = len(self.phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()

        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

        # Create EMA average of the model if hook is specified.
        ema_config = self.config["HOOKS"]["EMA_MODEL"]
        if ema_config["ENABLE_EMA_METERS"] or ema_config["SAVE_EMA_MODEL"]:
            self._create_ema_model()

        # Restore a checkpoint, if one was specified
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )
            if self.checkpoint is not None:
                self.iteration = self.checkpoint["iteration"]
                self.local_iteration_num = self.checkpoint["iteration_num"]
                vissl_state_dict = self.checkpoint.get("classy_state_dict")
            else:
                raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")

        current_train_phase_idx = (
            vissl_state_dict["train_phase_idx"] + 1 if vissl_state_dict else 0
        )

        self.datasets, self.data_and_label_keys = self.build_datasets(
            current_train_phase_idx
        )

        # set dataset state before building dataloader, in order to capture
        # checkpoint info.
        if vissl_state_dict and "train" in self.datasets:
            self.datasets["train"].set_classy_state(
                vissl_state_dict.get("train_dataset_iterator")
            )

        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory, current_train_phase_idx=current_train_phase_idx
        )

        # Build base loss, move to device, and load from checkpoint if applicable
        self.base_loss = self._build_loss()
        self.base_loss = self.base_loss.to(self.device)
        if self.checkpoint and "loss" in self.checkpoint:
            self.base_loss.load_state_dict(self.checkpoint["loss"])
            logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)

    def prepare_extraction(self, pin_memory: bool = False):
        """
        Prepares a light-weight task for feature extraction on multi-gpu. The
        model runs in eval mode only.
        """
        self.datasets, self.data_and_label_keys = self.build_datasets()
        self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
        # build the meters in case the extraction is for predictions.
        self.meters = self._build_meters()
        self.base_model = self._build_model(strict_load=True)
        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)
        return self

    def add_dummy_layer(self):
        """
        In case of feature evaluation mode, if we are freezing both trunk and
        head, DDP won't work as there are no parameters in the model. Adding a
        dummy head would make the extracted features incorrect, so we instead add
        a dummy linear layer to the model and use DDP. We copy the model to gpu
        (if using gpus) after the new dummy layer addition.
        """
        fully_frozen_model = self.base_model.is_fully_frozen_model()
        if fully_frozen_model:
            self.base_model.dummy_layer = torch.nn.Linear(4, 4)
            if self.device.type == "cuda":
                self.base_model = copy_model_to_gpu(self.base_model)

    def _create_ema_model(self):
        logging.info("Building the EMA model.")
        ema_model = build_model(self.config["MODEL"], self.config["OPTIMIZER"])
        self.ema_model = ModelEmaV2(
            ema_model,
            decay=self.config["HOOKS"]["EMA_MODEL"]["DECAY"],
            device=self.config["HOOKS"]["EMA_MODEL"]["EMA_DEVICE"],
        )
        self.ema_model.set(self.base_model)
class DeepvacTrain(Deepvac): def __init__(self, deepvac_config): super(DeepvacTrain, self).__init__(deepvac_config) self.initTrainParameters() self.initTrainContext() def setTrainContext(self): self.is_train = True self.is_val = False self.phase = 'TRAIN' self.dataset = self.train_dataset self.loader = self.train_loader self.batch_size = self.conf.train.batch_size self.net.train() if self.qat_net_prepared: self.qat_net_prepared.train() def setValContext(self): self.is_train = False self.is_val = True self.phase = 'VAL' self.dataset = self.val_dataset self.loader = self.val_loader self.batch_size = self.conf.val.batch_size self.net.eval() if self.qat_net_prepared: self.qat_net_prepared.eval() def initTrainContext(self): self.scheduler = None self.initOutputDir() self.initSummaryWriter() self.initCriterion() self.initOptimizer() self.initScheduler() self.initCheckpoint() self.initTrainLoader() self.initValLoader() def initTrainParameters(self): self.dataset = None self.loader = None self.target = None self.epoch = 0 self.step = 0 self.iter = 0 # Creates a GradScaler once at the beginning of training. self.scaler = GradScaler() self.train_time = AverageMeter() self.load_data_time = AverageMeter() self.data_cpu2gpu_time = AverageMeter() self._mandatory_member_name = [ 'train_dataset', 'val_dataset', 'train_loader', 'val_loader', 'net', 'criterion', 'optimizer' ] def initOutputDir(self): if self.conf.output_dir != 'output' or self.conf.output_dir != './output': LOG.logW( "According deepvac standard, you should save model files to [output] directory." ) self.output_dir = '{}/{}'.format(self.conf.output_dir, self.branch) LOG.logI('model save dir: {}'.format(self.output_dir)) #for DDP race condition os.makedirs(self.output_dir, exist_ok=True) def initSummaryWriter(self): event_dir = "{}/{}".format(self.conf.log_dir, self.branch) self.writer = SummaryWriter(event_dir) if not self.conf.tensorboard_port: return from tensorboard import program tensorboard = program.TensorBoard() self.conf.tensorboard_ip = '0.0.0.0' if self.conf.tensorboard_ip is None else self.conf.tensorboard_ip tensorboard.configure(argv=[ None, '--host', str(self.conf.tensorboard_ip), '--logdir', event_dir, "--port", str(self.conf.tensorboard_port) ]) try: url = tensorboard.launch() LOG.logI('Tensorboard at {} '.format(url)) except Exception as e: LOG.logE(e.msg) def initCriterion(self): self.criterion = torch.nn.CrossEntropyLoss() LOG.logW( "You should reimplement initCriterion() to initialize self.criterion, unless CrossEntropyLoss() is exactly what you need" ) def initCheckpoint(self): if not self.conf.checkpoint_suffix or self.conf.checkpoint_suffix == "": LOG.logI('Omit the checkpoint file since not specified...') return LOG.logI('Load checkpoint from {} folder'.format(self.output_dir)) self.net.load_state_dict( torch.load(self.output_dir + '/model__{}'.format(self.conf.checkpoint_suffix), map_location=self.device)) state_dict = torch.load( self.output_dir + '/checkpoint__{}'.format(self.conf.checkpoint_suffix), map_location=self.device) self.optimizer.load_state_dict(state_dict['optimizer']) if self.scheduler: self.scheduler.load_state_dict(state_dict['scheduler']) if self.conf.amp: LOG.logI( "Will load scaler from checkpoint since you enabled amp, make sure the checkpoint was saved with amp enabled." ) try: self.scaler.load_state_dict(state_dict["scaler"]) except: LOG.logI( "checkpoint was saved without amp enabled, so use fresh GradScaler instead." 
)
        self.scaler = GradScaler()
        self.epoch = state_dict['epoch']

    def initScheduler(self):
        if isinstance(self.conf.lr_step, list):
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        elif isinstance(self.conf.lr_step, FunctionType):
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=self.conf.lr_step)
        else:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        LOG.logW(
            "You should reimplement initScheduler() to initialize self.scheduler, unless lr_scheduler.StepLR() or lr_scheduler.MultiStepLR() is exactly what you need"
        )

    def initTrainLoader(self):
        self.train_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.train_loader",
            exit=True)

    def initValLoader(self):
        self.val_loader = None
        LOG.logE(
            "You must reimplement initValLoader() to initialize self.val_loader",
            exit=True)

    def initOptimizer(self):
        self.initSgdOptimizer()
        LOG.logW(
            "You should reimplement initOptimizer() to initialize self.optimizer, unless SGD is exactly what you need"
        )

    def initSgdOptimizer(self):
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.conf.lr,
                                   momentum=self.conf.momentum,
                                   weight_decay=self.conf.weight_decay,
                                   nesterov=self.conf.nesterov)

    def initAdamOptimizer(self):
        self.optimizer = optim.Adam(
            self.net.parameters(),
            lr=self.conf.lr,
        )
        for group in self.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])

    def initRmspropOptimizer(self):
        self.optimizer = optim.RMSprop(
            self.net.parameters(),
            lr=self.conf.lr,
            momentum=self.conf.momentum,
            weight_decay=self.conf.weight_decay,
            # alpha=self.conf.rmsprop_alpha,
            # centered=self.conf.rmsprop_centered
        )

    def addScalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

    def addImage(self, tag, image, step):
        self.writer.add_image(tag, image, step)

    @syszux_once
    def addGraph(self, input):
        self.writer.add_graph(self.net, input)

    @syszux_once
    def smokeTestForExport3rd(self):
        # exportONNX must come before exportNCNN
        self.exportONNX()
        self.exportNCNN()
        self.exportCoreML()
        # whether to export TorchScript via trace; only here can we get self.sample
        self.exportTorchViaTrace()
        # compile pytorch state dict to TorchScript
        self.exportTorchViaScript()
        self.exportDynamicQuant()
        self.exportStaticQuant(prepare=True)

    def earlyIter(self):
        start = time.time()
        self.sample = self.sample.to(self.device)
        self.target = self.target.to(self.device)
        if not self.is_train:
            return
        self.data_cpu2gpu_time.update(time.time() - start)
        try:
            self.addGraph(self.sample)
        except:
            LOG.logW(
                "Tensorboard addGraph failed. Does your network forward take more than one parameter?"
            )
            LOG.logW("It seems you need to reimplement the preIter function.")

    def preIter(self):
        pass

    def postIter(self):
        pass

    def preEpoch(self):
        pass

    def postEpoch(self):
        pass

    def doForward(self):
        self.output = self.net(self.sample)

    def doCalibrate(self):
        if self.static_quantized_net_prepared is None:
            return
        self.static_quantized_net_prepared(self.sample)

    def doLoss(self):
        self.loss = self.criterion(self.output, self.target)

    def doBackward(self):
        if self.conf.amp:
            self.scaler.scale(self.loss).backward()
        else:
            self.loss.backward()

    def doOptimize(self):
        if self.iter % self.conf.nominal_batch_factor != 0:
            return
        if self.conf.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()

    def doLog(self):
        if self.step % self.conf.log_every != 0:
            return
        self.addScalar('{}/Loss'.format(self.phase), self.loss.item(),
                       self.iter)
        self.addScalar('{}/LoadDataTime(secs/batch)'.format(self.phase),
                       self.load_data_time.val, self.iter)
        self.addScalar('{}/DataCpu2GpuTime(secs/batch)'.format(self.phase),
                       self.data_cpu2gpu_time.val, self.iter)
        self.addScalar('{}/TrainTime(secs/batch)'.format(self.phase),
                       self.train_time.val, self.iter)
        LOG.logI('{}: [{}][{}/{}] [Loss:{} Lr:{}]'.format(
            self.phase, self.epoch, self.step, self.loader_len,
            self.loss.item(), self.optimizer.param_groups[0]['lr']))

    def saveState(self, current_time):
        file_partial_name = '{}__acc_{}__epoch_{}__step_{}__lr_{}'.format(
            current_time, self.accuracy, self.epoch, self.step,
            self.optimizer.param_groups[0]['lr'])
        state_file = '{}/model__{}.pth'.format(self.output_dir,
                                               file_partial_name)
        checkpoint_file = '{}/checkpoint__{}.pth'.format(
            self.output_dir, file_partial_name)
        output_trace_file = '{}/trace__{}.pt'.format(self.output_dir,
                                                     file_partial_name)
        output_script_file = '{}/script__{}.pt'.format(self.output_dir,
                                                       file_partial_name)
        output_onnx_file = '{}/onnx__{}.onnx'.format(self.output_dir,
                                                     file_partial_name)
        output_ncnn_file = '{}/ncnn__{}.bin'.format(self.output_dir,
                                                    file_partial_name)
        output_coreml_file = '{}/coreml__{}.mlmodel'.format(
            self.output_dir, file_partial_name)
        output_dynamic_quant_file = '{}/dquant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_static_quant_file = '{}/squant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_qat_file = '{}/qat__{}.pt'.format(self.output_dir,
                                                 file_partial_name)
        # save state_dict
        torch.save(self.net.state_dict(), state_file)
        # save checkpoint
        torch.save(
            {
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'scheduler':
                self.scheduler.state_dict() if self.scheduler else None,
                'scaler':
                self.scaler.state_dict() if self.conf.amp else None
            }, checkpoint_file)
        # convert for quantization; must happen before trace and script!
        self.exportDynamicQuant(output_dynamic_quant_file)
        self.exportStaticQuant(output_quant_file=output_static_quant_file)
        self.exportQAT(output_quant_file=output_qat_file)
        # save pt via trace
        self.exportTorchViaTrace(self.sample, output_trace_file)
        # save pt via script
        self.exportTorchViaScript(output_script_file)
        # save onnx
        self.exportONNX(output_onnx_file)
        # save ncnn
        self.exportNCNN(output_ncnn_file)
        # save coreml
        self.exportCoreML(output_coreml_file)
        # tensorboard
        self.addScalar('{}/Accuracy'.format(self.phase), self.accuracy,
                       self.iter)

    def processTrain(self):
        self.setTrainContext()
        self.step = 0
        LOG.logI('Phase {} started...'.format(self.phase))
        self.loader_len = len(self.loader)
        save_every = self.loader_len // self.conf.save_num
        save_list = list(range(0, self.loader_len + 1, save_every))
        self.save_list = save_list[1:-1]
        LOG.logI('Model will be saved on steps {} and at the epoch end.'.format(
            self.save_list))
        self.addScalar('{}/LR'.format(self.phase),
                       self.optimizer.param_groups[0]['lr'], self.epoch)
        self.preEpoch()
        self.train_time.reset()
        self.load_data_time.reset()
        self.data_cpu2gpu_time.reset()
        start = time.time()
        for i, (sample, target) in enumerate(self.loader):
            self.load_data_time.update(time.time() - start)
            self.step = i
            self.target = target
            self.sample = sample
            self.preIter()
            self.earlyIter()
            with autocast(enabled=bool(self.conf.amp)):
                self.doForward()
                self.doLoss()
            self.doBackward()
            self.doOptimize()
            self.doLog()
            self.postIter()
            self.iter += 1
            self.train_time.update(time.time() - start)
            if self.step in self.save_list:
                self.processVal()
                self.setTrainContext()
            start = time.time()
        self.addScalar('{}/TrainTime(hours/epoch)'.format(self.phase),
                       round(self.train_time.sum / 3600, 2), self.epoch)
        self.addScalar(
            '{}/AverageBatchTrainTime(secs/epoch)'.format(self.phase),
            self.train_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchLoadDataTime(secs/epoch)'.format(self.phase),
            self.load_data_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchDataCpu2GpuTime(secs/epoch)'.format(self.phase),
            self.data_cpu2gpu_time.avg, self.epoch)
        self.postEpoch()
        if self.scheduler:
            self.scheduler.step()

    def processVal(self, smoke=False):
        self.setValContext()
        LOG.logI('Phase {} started...'.format(self.phase))
        with torch.no_grad():
            self.preEpoch()
            for i, (sample, target) in enumerate(self.loader):
                self.target = target
                self.sample = sample
                self.preIter()
                self.earlyIter()
                self.doForward()
                # calibrate only for quantization
                self.doCalibrate()
                self.doLoss()
                self.smokeTestForExport3rd()
                LOG.logI('{}: [{}][{}/{}]'.format(self.phase, self.epoch, i,
                                                  len(self.loader)))
                self.postIter()
                if smoke:
                    break
            self.postEpoch()
        self.saveState(self.getTime())

    def processAccept(self):
        self.setValContext()

    def process(self):
        self.auditConfig()
        self.iter = 0
        epoch_start = self.epoch
        self.processVal(smoke=True)
        self.optimizer.zero_grad()
        for epoch in range(epoch_start, self.conf.epoch_num):
            self.epoch = epoch
            LOG.logI('Epoch {} started...'.format(self.epoch))
            self.processTrain()
            self.processVal()
            self.processAccept()

    def __call__(self):
        self.process()
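# A minimal standalone sketch of the AMP + gradient-accumulation pattern that
# doBackward()/doOptimize() above implement; `net`, `criterion`, `optimizer`,
# `loader`, and `accum_steps` are illustrative placeholders, not part of the
# trainer's API.
from torch.cuda.amp import GradScaler, autocast


def amp_accumulation_sketch(net, criterion, optimizer, loader,
                            accum_steps=4, device="cuda"):
    scaler = GradScaler()
    net.train()
    for step, (sample, target) in enumerate(loader, start=1):
        sample, target = sample.to(device), target.to(device)
        # forward and loss under autocast; divide the loss so the accumulated
        # gradients match a single large-batch step
        with autocast():
            loss = criterion(net(sample), target) / accum_steps
        # scale the loss so FP16 gradients do not underflow
        scaler.scale(loss).backward()
        if step % accum_steps == 0:
            scaler.step(optimizer)  # skipped internally on inf/NaN gradients
            scaler.update()         # adapt the scale factor for the next step
            optimizer.zero_grad()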
class CustomMTSAC(MTSAC): def __init__( self, policy, qf1, qf2, replay_buffer, env_spec, sampler, train_task_sampler, *, num_tasks, gradient_steps_per_itr, task_update_frequency=1, max_episode_length_eval=None, fixed_alpha=None, target_entropy=None, initial_log_entropy=0., discount=0.99, buffer_batch_size=64, min_buffer_size=10000, target_update_tau=5e-3, policy_lr=3e-4, qf_lr=3e-4, reward_scale=1.0, optimizer=torch.optim.Adam, num_evaluation_episodes=5, # added fp16=False, log_per_task=False, share_train_eval_env=False ): super().__init__( policy=policy, qf1=qf1, qf2=qf2, replay_buffer=replay_buffer, env_spec=env_spec, sampler=sampler, test_sampler=sampler, # not used, for compatibility train_task_sampler=train_task_sampler, num_tasks=num_tasks, gradient_steps_per_itr=gradient_steps_per_itr, max_episode_length_eval=max_episode_length_eval, fixed_alpha=fixed_alpha, target_entropy=target_entropy, initial_log_entropy=initial_log_entropy, discount=discount, buffer_batch_size=buffer_batch_size, min_buffer_size=min_buffer_size, target_update_tau=target_update_tau, policy_lr=policy_lr, qf_lr=qf_lr, reward_scale=reward_scale, optimizer=optimizer, steps_per_epoch=1, num_evaluation_episodes=num_evaluation_episodes, ) self._train_task_sampler = train_task_sampler self._task_update_frequency = task_update_frequency self._fp16 = fp16 self._log_per_task = log_per_task self._total_envsteps = 0 # scalers for fp16 # TODO: don't initialize gradscalers if not using fp16 # Also don't save and/or restore self._gs_qf1 = GradScaler() self._gs_qf2 = GradScaler() self._gs_policy = GradScaler() self._gs_alpha = GradScaler() # get updates for evaluation self.eval_env_updates = self.resample_environment(force_update=True) self.share_train_eval_env = share_train_eval_env if self.share_train_eval_env: logging.warn("WARNING: Sharing train and eval environments") # Fix bug with alpha with optimizer self._use_automatic_entropy_tuning = fixed_alpha is None if self._use_automatic_entropy_tuning: self._alpha_optimizer = optimizer([self._log_alpha], lr=self._policy_lr) def state_dict(self): return { # parameters "policy": self.policy.state_dict(), "qf1": self._qf1.state_dict(), "qf2": self._qf2.state_dict(), "target_qf1": self._target_qf1.state_dict(), "target_qf2": self._target_qf2.state_dict(), "log_alpha": self._log_alpha, # scalers "gs_qf1": self._gs_qf1.state_dict(), "gs_qf2": self._gs_qf2.state_dict(), "gs_policy": self._gs_policy.state_dict(), "gs_alpha": self._gs_alpha.state_dict(), # optimizers "policy_optimizer": self._policy_optimizer.state_dict(), "qf1_optimizer": self._qf1_optimizer.state_dict(), "qf2_optimizer": self._qf2_optimizer.state_dict(), "alpha_optimizer": self._alpha_optimizer.state_dict(), # other variables "replay_buffer": self.replay_buffer, "total_envsteps": self._total_envsteps, } def load_env_state(self, env_state): self.eval_env_updates = env_state def load_state(self, state): # parameters self.policy.load_state_dict(state["policy"]) self._qf1.load_state_dict(state["qf1"]) self._qf2.load_state_dict(state["qf2"]) self._target_qf1.load_state_dict(state["target_qf1"]) self._target_qf2.load_state_dict(state["target_qf2"]) self._log_alpha.data = state["log_alpha"] # scalers self._gs_qf1.load_state_dict(state["gs_qf1"]) self._gs_qf2.load_state_dict(state["gs_qf2"]) self._gs_policy.load_state_dict(state["gs_policy"]) self._gs_alpha.load_state_dict(state["gs_alpha"]) # optimizers self._policy_optimizer.load_state_dict(state["policy_optimizer"]) self._qf1_optimizer.load_state_dict(state["qf1_optimizer"]) 
        self._qf2_optimizer.load_state_dict(state["qf2_optimizer"])
        self._alpha_optimizer.load_state_dict(state["alpha_optimizer"])
        # other variables
        self.replay_buffer = state["replay_buffer"]
        self._total_envsteps = state["total_envsteps"]

    def get_updated_policy(self, policy_hook=None):
        with torch.no_grad():
            updated_policy = copy.deepcopy(self.policy)
            updated_policy.eval()
        # attach hooks
        if policy_hook:
            policy_hook(updated_policy)
        return updated_policy

    def update_buffer(self, trajectories):
        """Update the replay buffer."""
        self._total_envsteps += sum(trajectories.lengths)
        path_returns = []
        for path in trajectories.to_list():
            self.replay_buffer.add_path(dict(
                observation=path["observations"],
                action=path["actions"],
                reward=path["rewards"].reshape(-1, 1),
                next_observation=path["next_observations"],
                terminal=np.array([
                    step_type == StepType.TERMINAL
                    for step_type in path["step_types"]
                ]).reshape(-1, 1)
            ))
            path_returns.append(sum(path["rewards"]))
        self.episode_rewards.append(np.mean(path_returns))

    def resample_environment(self, epoch=0, force_update=False):
        """
        TODO: fix env update in sampler
        Intended behavior:
            if epoch % self._task_update_frequency == 0 or force_update:
                return self._train_task_sampler.sample(self._num_tasks)
        """
        # TODO: remove first line to allow force update
        if epoch % self._task_update_frequency == 0 or force_update:
            return self._train_task_sampler.sample(self._num_tasks)

    def run_epoch(self, epoch, env_steps_per_epoch):
        """
        Run one epoch, which is composed of N sample collections and N
        training steps. Each training step is in turn composed of M gradient
        steps of batch size B.

        The total number of samples used by the algorithm in an epoch is
        N * M * B (steps * gradient_steps * batch size). Collected samples
        are only used to update the buffer; they have no direct influence on
        the number of gradient steps or the batch size.

        Returns:
            float: The average return in the last epoch cycle.
        """
        t0 = time()
        env_updates = (
            self.eval_env_updates if self.share_train_eval_env
            else self.resample_environment(epoch)
        )
        new_trajectories = self._sampler.obtain_samples(
            num_samples=env_steps_per_epoch,
            agent_update=self.get_updated_policy(),
            env_updates=env_updates,
        )
        self.update_buffer(new_trajectories)
        t1 = time()
        total_losses = self.run_step()
        time_to_collect_samples = t1 - t0
        time_to_update_gradient = time() - t1
        log_dict = self._log_statistics(*total_losses)
        # TODO: switch to logger.debug once logger is fixed
        logging.warn(f"Time to collect samples: {time_to_collect_samples:.2f}")
        logging.warn(f"Time to update gradient: {time_to_update_gradient:.2f}")
        return log_dict

    def run_step(self):
        """
        Run one training step, which is composed of M gradient steps.

        For each of the M gradient steps:
            - sample a batch from the buffer
            - perform one gradient step on all three networks
              (policy, qf1 and qf2)
        """
        total_losses = [0, 0, 0]
        for _ in range(self._gradient_steps):
            if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
                samples = as_torch_dict(self.replay_buffer.sample_transitions(
                    self._buffer_batch_size
                ))
                policy_loss, qf1_loss, qf2_loss = self.optimize_policy(samples)
                total_losses[0] += policy_loss
                total_losses[1] += qf1_loss
                total_losses[2] += qf2_loss
                self._update_targets()
        # Normalize losses by the total number of gradient updates
        total_losses = [loss / self._gradient_steps for loss in total_losses]
        return total_losses

    def _evaluate_policy(self, epoch, policy_hook=None):
        """Evaluate the performance of the policy via deterministic sampling.

        Statistics such as (average) discounted return and success rate are
        recorded.
Args: epoch (int): The current training epoch. Returns: float: The average return across self._num_evaluation_episodes episodes """ t0 = time() # Collect episodes for evaluation eval_trajectories, policy_hook_data = self._sampler.obtain_exact_episodes( n_eps_per_worker=self._num_evaluation_episodes, agent_update=self.get_updated_policy(policy_hook=policy_hook), env_updates=self.eval_env_updates, ) # Log performance undiscounted_returns, log_dict = log_multitask_performance( epoch, batch=eval_trajectories, discount=self._discount, log_per_task=self._log_per_task ) log_dict["average_return"] = np.mean(undiscounted_returns) logging.warn(f"Time to evaluate policy: {time()-t0:.2f}") return undiscounted_returns, log_dict, policy_hook_data def _log_statistics(self, policy_loss, qf1_loss, qf2_loss): """Record training statistics to dowel such as losses and returns. Args: policy_loss (torch.Tensor): loss from actor/policy network. qf1_loss (torch.Tensor): loss from 1st qf/critic network. qf2_loss (torch.Tensor): loss from 2nd qf/critic network. """ log_dict = {} with torch.no_grad(): log_dict["AlphaTemperature/mean"] = self._log_alpha.exp().mean().item() log_dict["Policy/Loss"] = policy_loss.item() log_dict["QF/{}".format("Qf1Loss")] = float(qf1_loss) log_dict["QF/{}".format("Qf2Loss")] = float(qf2_loss) log_dict["ReplayBuffer/buffer_size"] = self.replay_buffer.n_transitions_stored log_dict["Average/TrainAverageReturn"] = np.mean(self.episode_rewards) log_dict["TotalEnvSteps"] = self._total_envsteps return log_dict def _get_log_alpha(self, samples_data): """Return the value of log_alpha. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Raises: ValueError: If the number of tasks, num_tasks passed to this algorithm doesn't match the length of the task one-hot id in the observation vector. Returns: torch.Tensor: log_alpha. shape is (1, self.buffer_batch_size) """ obs = samples_data["observation"] log_alpha = self._log_alpha one_hots = obs[:, -self._num_tasks:] if (log_alpha.shape[0] != one_hots.shape[1] or one_hots.shape[1] != self._num_tasks or log_alpha.shape[0] != self._num_tasks): raise ValueError( "The number of tasks in the environment does " "not match self._num_tasks. Are you sure that you passed " "The correct number of tasks?") with autocast(enabled=self._fp16): return torch.mm(one_hots, log_alpha.unsqueeze(0).t()).squeeze() def _temperature_objective(self, log_pi, samples_data): """Compute the temperature/alpha coefficient loss. Args: log_pi(torch.Tensor): log probability of actions that are sampled from the replay buffer. Shape is (1, buffer_batch_size). samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: the temperature/alpha coefficient loss. 
""" alpha_loss = 0 with autocast(enabled=self._fp16): if self._use_automatic_entropy_tuning: alpha_loss = (-(self._get_log_alpha(samples_data)) * (log_pi.detach() + self._target_entropy)).mean() return alpha_loss def _actor_objective(self, samples_data, new_actions, log_pi_new_actions): """Compute the Policy/Actor loss. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. new_actions (torch.Tensor): Actions resampled from the policy based based on the Observations, obs, which were sampled from the replay buffer. Shape is (action_dim, buffer_batch_size). log_pi_new_actions (torch.Tensor): Log probability of the new actions on the TanhNormal distributions that they were sampled from. Shape is (1, buffer_batch_size). Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from the Policy/Actor. """ obs = samples_data["observation"] with torch.no_grad(): alpha = self._get_log_alpha(samples_data).exp() with autocast(enabled=self._fp16): min_q_new_actions = torch.min(self._qf1(obs, new_actions), self._qf2(obs, new_actions)) policy_objective = ((alpha * log_pi_new_actions) - min_q_new_actions.flatten()).mean() return policy_objective def _critic_objective(self, samples_data): """Compute the Q-function/critic loss. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. """ obs = samples_data["observation"] actions = samples_data["action"] rewards = samples_data["reward"].flatten() terminals = samples_data["terminal"].flatten() next_obs = samples_data["next_observation"] with torch.no_grad(): alpha = self._get_log_alpha(samples_data).exp() with autocast(enabled=self._fp16): q1_pred = self._qf1(obs, actions) q2_pred = self._qf2(obs, actions) new_next_actions_dist = self.policy(next_obs)[0] new_next_actions_pre_tanh, new_next_actions = ( new_next_actions_dist.rsample_with_pre_tanh_value()) new_log_pi = new_next_actions_dist.log_prob( value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh ) target_q_values = torch.min( self._target_qf1(next_obs, new_next_actions), self._target_qf2(next_obs, new_next_actions) ).flatten() - (alpha * new_log_pi) with torch.no_grad(): q_target = rewards * self._reward_scale + ( 1. - terminals) * self._discount * target_q_values qf1_loss = F.mse_loss(q1_pred.flatten(), q_target) qf2_loss = F.mse_loss(q2_pred.flatten(), q_target) return qf1_loss, qf2_loss def optimize_policy(self, samples_data): """Optimize the policy q_functions, and temperature coefficient. Rezero model weights (if applicable) after each optimizer step. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. 
Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from actor/policy network after optimization. torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. """ if self._fp16: return self.optimize_policy_with_autocast(samples_data) obs = samples_data["observation"] qf1_loss, qf2_loss = self._critic_objective(samples_data) self._qf1_optimizer.zero_grad() qf1_loss.backward() self._qf1_optimizer.step() self._qf1.apply(rezero_weights) self._qf2_optimizer.zero_grad() qf2_loss.backward() self._qf2_optimizer.step() self._qf2.apply(rezero_weights) action_dists = self.policy(obs)[0] new_actions_pre_tanh, new_actions = ( action_dists.rsample_with_pre_tanh_value()) log_pi_new_actions = action_dists.log_prob( value=new_actions, pre_tanh_value=new_actions_pre_tanh) policy_loss = self._actor_objective(samples_data, new_actions, log_pi_new_actions) self._policy_optimizer.zero_grad() policy_loss.backward() self._policy_optimizer.step() self.policy.apply(rezero_weights) if self._use_automatic_entropy_tuning: alpha_loss = self._temperature_objective(log_pi_new_actions, samples_data) self._alpha_optimizer.zero_grad() alpha_loss.backward() self._alpha_optimizer.step() return policy_loss, qf1_loss, qf2_loss def optimize_policy_with_autocast(self, samples_data): """Optimize the policy q_functions, and temperature coefficient. Rezero model weights (if applicable) after each optimizer step. Args: samples_data (dict): Transitions(S,A,R,S') that are sampled from the replay buffer. It should have the keys 'observation', 'action', 'reward', 'terminal', and 'next_observations'. Note: samples_data's entries should be torch.Tensor's with the following shapes: observation: :math:`(N, O^*)` action: :math:`(N, A^*)` reward: :math:`(N, 1)` terminal: :math:`(N, 1)` next_observation: :math:`(N, O^*)` Returns: torch.Tensor: loss from actor/policy network after optimization. torch.Tensor: loss from 1st q-function after optimization. torch.Tensor: loss from 2nd q-function after optimization. 
""" obs = samples_data["observation"] qf1_loss, qf2_loss = self._critic_objective(samples_data) self._qf1_optimizer.zero_grad() self._gs_qf1.scale(qf1_loss).backward() self._gs_qf1.step(self._qf1_optimizer) self._gs_qf1.update() self._qf1.apply(rezero_weights) self._qf2_optimizer.zero_grad() self._gs_qf2.scale(qf2_loss).backward() self._gs_qf2.step(self._qf2_optimizer) self._gs_qf2.update() self._qf2.apply(rezero_weights) with autocast(): action_dists = self.policy(obs)[0] new_actions_pre_tanh, new_actions = ( action_dists.rsample_with_pre_tanh_value() ) log_pi_new_actions = action_dists.log_prob( value=new_actions, pre_tanh_value=new_actions_pre_tanh) policy_loss = self._actor_objective(samples_data, new_actions, log_pi_new_actions) self._policy_optimizer.zero_grad() self._gs_policy.scale(policy_loss).backward() self._gs_policy.step(self._policy_optimizer) self._gs_policy.update() self.policy.apply(rezero_weights) if self._use_automatic_entropy_tuning: alpha_loss = self._temperature_objective(log_pi_new_actions, samples_data) self._alpha_optimizer.zero_grad() self._gs_alpha.scale(alpha_loss).backward() self._gs_alpha.step(self._alpha_optimizer) self._gs_alpha.update() return policy_loss, qf1_loss, qf2_loss def shutdown_worker(self): """Shutdown Plotter and Sampler workers.""" self._sampler.shutdown_worker()
def prepare_optimizers(args, model, checkpoint, global_steps): param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] optimizer_grouped_parameters = [{ 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01 }, { 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] if args.lr_decay == 'poly': Scheduler = PolyWarmUpScheduler elif args.lr_decay == 'linear': Scheduler = LinearWarmUpScheduler else: raise ValueError('Unknown lr decay "{}"'.format(args.lr_decay)) optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate) if checkpoint is not None: if args.resume_step >= args.previous_phase_end_step: keys = list(checkpoint['optimizer']['state'].keys()) # Override hyperparameters from previous checkpoint for key in keys: checkpoint['optimizer']['state'][key]['step'] = global_steps for i, item in enumerate(checkpoint['optimizer']['param_groups']): checkpoint['optimizer']['param_groups'][i][ 'step'] = global_steps checkpoint['optimizer']['param_groups'][i][ 't_total'] = args.max_steps checkpoint['optimizer']['param_groups'][i][ 'warmup'] = args.warmup_proportion checkpoint['optimizer']['param_groups'][i][ 'lr'] = args.learning_rate optimizer.load_state_dict(checkpoint['optimizer']) lr_schedulers = [ Scheduler(optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps) ] scaler = None if args.fp16: scaler = GradScaler() if checkpoint is not None and 'scaler' in checkpoint: scaler.load_state_dict(checkpoint['scaler']) preconditioner = None if args.kfac: preconditioner = kfac.KFAC( model, lr=args.learning_rate, factor_decay=args.kfac_stat_decay, damping=args.kfac_damping, kl_clip=args.kfac_kl_clip, factor_update_freq=args.kfac_factor_interval, inv_update_freq=args.kfac_inv_interval, # Skip TrainingHeads which contains the decoder, a Linear module # with shape (seq_len, vocab_size), such that it is too large to invert skip_layers=args.kfac_skip_layers, # BERT calls KFAC very infrequently so no need to optimize for # communication. Optimize for memory instead. comm_method=kfac.CommMethod.HYBRID_OPT, grad_worker_fraction=0.5, inv_dtype=torch.float16, # Compute the factors and update the running averages during the # forward backward pass b/c we are using grad accumulation but # not accumulating the input/output data accumulate_data=False, compute_factor_in_hook=True, distribute_layer_factors=False, grad_scaler=scaler, ) lrs = Scheduler(preconditioner, warmup=args.warmup_proportion, total_steps=args.max_steps) lr_schedulers.append(lrs) if checkpoint is not None and 'preconditioner' in checkpoint: preconditioner.load_state_dict(checkpoint['preconditioner']) if is_main_process(): logger.info(preconditioner) return optimizer, preconditioner, lr_schedulers, scaler
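# prepare_optimizers() above restores 'optimizer', 'scaler' and 'preconditioner'
# entries from `checkpoint`, so a checkpoint written for resumption needs the
# matching keys. A minimal hedged sketch of the saving side (key names are
# inferred from the load path; the original script's saving code is not shown):
import torch


def save_resumable_checkpoint(path, model, optimizer, scaler=None,
                              preconditioner=None):
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if scaler is not None:           # present only when args.fp16 was set
        state['scaler'] = scaler.state_dict()
    if preconditioner is not None:   # present only when args.kfac was set
        state['preconditioner'] = preconditioner.state_dict()
    torch.save(state, path)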
batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) valloader = DataLoader(valdataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) scaler = GradScaler() if args.resume: ckpt = torch.load(os.path.join(args.save_dir, 'recorder_2.pt')) model.load_state_dict(ckpt['model']) optimizer.load_state_dict(ckpt['optimizer']) scheduler.load_state_dict(ckpt['scheduler']) scaler.load_state_dict(ckpt['scaler']) if args.resume: best_loss = scheduler.best else: best_loss = np.inf save_recorder = 5 for epoch in range(args.epochs): print(f'Epoch {epoch+1}/{args.epochs}') train_loss, train_acc = train_one_epoch(trainloader, model, criterion, optimizer, scaler, device, args, epoch)
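# `train_one_epoch` is called above but not shown. A minimal sketch of a
# compatible implementation, assuming a standard classification setup with
# GradScaler-based AMP (the actual implementation may differ):
from torch.cuda.amp import autocast


def train_one_epoch(loader, model, criterion, optimizer, scaler, device,
                    args, epoch):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * targets.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    return running_loss / total, correct / total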
class Trainer: """Model trainer Args: model: model to train loss_fn: loss function optimizer: model optimizer generator: pretrained generator projector: pretrained projector device: device to train the model on batch_size: number of batch elements iterations: number of iterations scheduler: learning rate scheduler grad_clip_max_norm: gradient clipping max norm (disabled if None) writer: writer which logs metrics to TensorBoard (disabled if None) save_path: folder in which to save models (disabled if None) checkpoint_path: path to model checkpoint, to resume training mixed_precision: enable mixed precision training """ def __init__( self, model: torch.nn.Module, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, generator: Generator, projector: torch.nn.Module, batch_size: int, iterations: int, device: torch.device, eval_freq: int = 1000, eval_iters: int = 100, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, grad_clip_max_norm: Optional[float] = None, writer: Optional[SummaryWriter] = None, save_path: Optional[str] = None, checkpoint_path: Optional[str] = None, mixed_precision: bool = False, train_projector: bool = True, feed_layers: Optional[List[int]] = None, ) -> None: # Logging self.logger = logging.getLogger() self.writer = writer # Saving self.save_path = save_path # Device self.device = device # Model self.model = model self.loss_fn = loss_fn self.optimizer = optimizer self.generator = generator self.projector = projector self.train_projector = train_projector self.feed_layers = feed_layers # Eval self.eval_freq = eval_freq self.eval_iters = eval_iters # Scheduler self.scheduler = scheduler self.grad_clip_max_norm = grad_clip_max_norm # Batch & Iteration self.batch_size = batch_size self.iterations = iterations self.start_iteration = 0 # Floating-point precision self.mixed_precision = (True if self.device.type == "cuda" and mixed_precision else False) self.scaler = GradScaler() if self.mixed_precision else None if checkpoint_path: self._load_from_checkpoint(checkpoint_path) # Metrics self.train_acc_metric = LossMetric() self.train_loss_metric = LossMetric() self.val_acc_metric = LossMetric() self.val_loss_metric = LossMetric() # Best self.best_loss = -1 def train(self) -> None: """Trains the model""" self.logger.info("Beginning training") start_time = time.time() epoch = 0 iteration = self.start_iteration while iteration < self.iterations: if iteration + self.eval_freq < self.iterations: num_iters = self.eval_freq else: num_iters = self.iterations - iteration start_epoch_time = time.time() if self.mixed_precision: self._train_loop_amp(epoch, num_iters) else: self._train_loop(epoch, num_iters) self._val_loop(epoch, self.eval_iters) epoch_time = time.time() - start_epoch_time self._end_loop(epoch, epoch_time, iteration) iteration += num_iters epoch += 1 train_time_h = (time.time() - start_time) / 3600 self.logger.info(f"Finished training! 
Total time: {train_time_h:.2f}h")

        self._save_model("final_model.pt", self.iterations)

    def _train_loop(self, epoch: int, iterations: int) -> None:
        """
        Regular train loop

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Set to eval
        self.generator.eval()

        if self.train_projector:
            self.projector.train()
        else:
            self.projector.eval()

        for i in range(iterations):
            # To device
            z = self.generator.sample_latent(self.batch_size)
            z = z.to(self.device)
            z_orig = z

            # Original features
            with torch.no_grad():
                orig_feats = self.generator.get_features(z)
                orig_feats = self.projector(orig_feats)

            # Apply Directions
            self.optimizer.zero_grad()
            z = self.model(z)

            # Forward
            features = []
            for j in range(z.shape[0] // self.batch_size):
                # Prepare batch
                start, end = j * self.batch_size, (j + 1) * self.batch_size
                z_batch = z[start:end, ...]

                # Manipulate only the requested layers
                if self.feed_layers is not None:
                    n_latent = self.generator.n_latent()

                    z_batch_layers = []
                    for layer_idx in range(n_latent):
                        if layer_idx in self.feed_layers:
                            z_batch_layers.append(z_batch)
                        else:
                            z_batch_layers.append(z_orig)
                    z_batch = z_batch_layers

                # Get features
                feats = self.generator.get_features(z_batch)
                feats = self.projector(feats)

                # Take feature divergence
                feats = feats - orig_feats
                feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                              (-1, 1))

                features.append(feats)
            features = torch.cat(features, dim=0)

            # Loss
            acc, loss = self.loss_fn(features)
            loss.backward()

            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.grad_clip_max_norm)

            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # Update metrics
            self.train_acc_metric.update(acc.item(), z.shape[0])
            self.train_loss_metric.update(loss.item(), z.shape[0])

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(
                f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                refresh=False)

        pbar.close()

    def _train_loop_amp(self, epoch: int, iterations: int) -> None:
        """
        Train loop with Automatic Mixed Precision

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Loop
        for i in range(iterations):
            # To device
            z = self.generator.sample_latent(self.batch_size)
            z = z.to(self.device)

            # Forward + backward
            self.optimizer.zero_grad()

            # Use amp in forward pass
            with autocast():
                # Original features
                with torch.no_grad():
                    orig_feats = self.generator.get_features(z)
                    orig_feats = self.projector(orig_feats)

                # Apply Directions
                z = self.model(z)

                # Forward
                features = []
                for j in range(z.shape[0] // self.batch_size):
                    # Prepare batch
                    start, end = j * self.batch_size, (j + 1) * self.batch_size

                    # Get features
                    feats = self.generator.get_features(z[start:end, ...])
                    feats = self.projector(feats)

                    # Take feature divergence
                    feats = feats - orig_feats
                    feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                                  (-1, 1))

                    features.append(feats)
                features = torch.cat(features, dim=0)

                # Loss
                acc, loss = self.loss_fn(features)

            # Backward pass with scaler
            self.scaler.scale(loss).backward()

            # Unscale before gradient clipping
            self.scaler.unscale_(self.optimizer)
            if self.grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.grad_clip_max_norm)

            # Update optimizer and scaler
            self.scaler.step(self.optimizer)
            self.scaler.update()
            if self.scheduler is not None:
                self.scheduler.step()

            # Update metrics
            self.train_acc_metric.update(acc.item(), z.shape[0])
            self.train_loss_metric.update(loss.item(), z.shape[0])

            # Update progress bar
            pbar.update()
            pbar.set_postfix_str(
                f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                refresh=False)

        pbar.close()

    def _val_loop(self, epoch: int, iterations: int) -> None:
        """
        Standard validation loop

        Args:
            epoch: current epoch
            iterations: iterations to run model
        """
        # Progress bar
        pbar = tqdm.tqdm(total=iterations, leave=False)
        pbar.set_description(f"Epoch {epoch} | Validation")

        # Set to eval
        self.model.eval()
        self.generator.eval()
        self.projector.eval()

        # Loop
        for i in range(iterations):
            with torch.no_grad():
                # To device
                z = self.generator.sample_latent(self.batch_size)
                z = z.to(self.device)

                # Original features
                orig_feats = self.generator.get_features(z)
                orig_feats = self.projector(orig_feats)

                # Apply Directions
                z = self.model(z)

                # Forward
                features = []
                for j in range(z.shape[0] // self.batch_size):
                    # Prepare batch
                    start, end = j * self.batch_size, (j + 1) * self.batch_size

                    # Get features
                    feats = self.generator.get_features(z[start:end, ...])
                    feats = self.projector(feats)

                    # Take feature divergence
                    feats = feats - orig_feats
                    feats = feats / torch.reshape(torch.norm(feats, dim=1),
                                                  (-1, 1))

                    features.append(feats)
                features = torch.cat(features, dim=0)

                # Loss
                acc, loss = self.loss_fn(features)
                self.val_acc_metric.update(acc.item(), z.shape[0])
                self.val_loss_metric.update(loss.item(), z.shape[0])

                # Update progress bar
                pbar.update()
                pbar.set_postfix_str(
                    f"Acc: {acc.item():.3f} Loss: {loss.item():.3f}",
                    refresh=False)

        pbar.close()

    def _end_loop(self, epoch: int, epoch_time: float, iteration: int):
        # Print epoch results
        self.logger.info(self._epoch_str(epoch, epoch_time))

        # Write to tensorboard
        if self.writer is not None:
            self._write_to_tb(epoch)

        # Save model
        if self.save_path is not None:
            self._save_model("most_recent.pt", iteration)

        eval_loss = self.val_loss_metric.compute()
        if self.best_loss == -1 or eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self._save_model("best_model.pt", iteration)

        # Clear metrics
        self.train_loss_metric.reset()
        self.train_acc_metric.reset()
        self.val_loss_metric.reset()
        self.val_acc_metric.reset()

    def _epoch_str(self, epoch: int, epoch_time: float):
        s = f"Epoch {epoch} "
        s += f"| Train acc: {self.train_acc_metric.compute():.3f} "
        s += f"| Train loss: {self.train_loss_metric.compute():.3f} "
        s += f"| Val acc: {self.val_acc_metric.compute():.3f} "
        s += f"| Val loss: {self.val_loss_metric.compute():.3f} "
        s += f"| Epoch time: {epoch_time:.1f}s"
        return s

    def _write_to_tb(self, iteration):
        self.writer.add_scalar("Loss/train", self.train_loss_metric.compute(),
                               iteration)
        self.writer.add_scalar("Acc/train", self.train_acc_metric.compute(),
                               iteration)
        self.writer.add_scalar("Loss/val", self.val_loss_metric.compute(),
                               iteration)
        self.writer.add_scalar("Acc/val", self.val_acc_metric.compute(),
                               iteration)

    def _save_model(self, path, iteration):
        obj = {
            "iteration": iteration + 1,
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "projector": self.projector.state_dict(),
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None else None,
            "scaler": self.scaler.state_dict()
            if self.mixed_precision else None,
        }
        torch.save(obj, os.path.join(self.save_path, path))

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.projector.load_state_dict(checkpoint["projector"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.start_iteration = checkpoint["iteration"]

        if self.scheduler:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        if self.mixed_precision and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        if self.start_iteration > self.iterations:
            raise ValueError(
                "Starting iteration is larger than total iterations")

        self.logger.info(
            f"Checkpoint loaded, resuming from iteration {self.start_iteration}"
        )
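# Hedged round-trip sketch for the Trainer above: _save_model() writes the keys
# that _load_from_checkpoint() reads back, so resuming mid-run only requires
# passing checkpoint_path. The build_* helpers are hypothetical stand-ins for
# however the surrounding project constructs these objects.
#
#   trainer = Trainer(
#       model=build_model(), loss_fn=build_loss(), optimizer=build_optimizer(),
#       generator=build_generator(), projector=build_projector(),
#       batch_size=32, iterations=100_000, device=torch.device("cuda"),
#       mixed_precision=True,
#       checkpoint_path="checkpoints/most_recent.pt",  # resume from here
#   )
#   trainer.train()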
class ClassificationTask(ClassyTask):
    """Basic classification training task.

    This task encapsulates all of the components and steps needed to
    train a classifier using a :class:`classy_vision.trainer.ClassyTrainer`.

    Assumes a train / test phase per each epoch and that the datasets
    have the same API as the map-style Dataset class in
    `torch.utils.data.dataset <https://pytorch.org/docs/stable/data.html
    #torch.utils.data.Dataset>`_ (in particular, this task makes use of
    the len). If you are using an `IterableDataset <https://pytorch.org/docs/
    stable/data.html#torch.utils.data.IterableDataset>`_ then a custom task
    may be appropriate.

    :var loss: Loss (see :class:`classy_vision.losses.ClassyLoss`) function
        used for computing the loss in each forward pass
    :var datasets: Mapping from a ``phase_type`` in ["train", "test"] to
        dataset used for training (or testing)
    :var meters: List of meters (see :class:`classy_vision.meters.ClassyMeter`)
        to calculate during training
    :var num_epochs: Number of epochs (passes over dataset) to train
    :var test_only: Used to only run the test phase
    :var base_model: Model to be trained, unwrapped in DDP or DP wrappers
    :var optimizer: Optimizer used in train step
    :var optimizer_schedulers: Dictionary. Key is the name of the optimizer
        option (e.g. lr), value is a ClassyParamScheduler
    :var checkpoint: Serializable dict which represents state in training
    :var phases: List of phase specific information, e.g. if phase is
        train / test.
    :var hooks: List of hooks to apply during training
    :var train: Phase type, if true it means we are training,
        false means testing
    :var distributed_model: Base model, but wrapped in DDP
        (DistributedDataParallel)
    :var phase_idx: Current phase id, first phase is 0, if task has not
        started training then returns -1
    :var train_phase_idx: Only counts train phases
    :var num_updates: Number of total parameter updates applied to model
        by the optimizer
    :var data_iterator: Iterator which can be used to obtain batches
    :var losses: Loss curve
    :var perf_log: list of training speed measurements, to be logged
    :var clip_grad_norm: maximum gradient norm (default None)
    :var simulated_global_batchsize: batch size simulated via gradient
        accumulation
    :var optimizer_period: apply optimizer after this many steps; derived
        from simulated_global_batchsize, default 1.
""" def __init__(self): """Constructs a ClassificationTask""" super().__init__() self.base_loss = None self.datasets = {} self.meters = [] self.num_epochs = 1 self.test_phase_period = 1 self.train_phases_per_epoch = 0 self.test_only = False self.base_model = None self.optimizer = None self.optimizer_schedulers = {} self.checkpoint_dict = None self.checkpoint_path = None self.checkpoint_load_strict = True self.phases = [] self.hooks = [] self.train = True self.distributed_model = None self.distributed_loss = None self.phase_idx = -1 self.train_phase_idx = -1 self.num_updates = 0 self.dataloader = None self.data_iterator = None self.losses = [] self.broadcast_buffers_mode: BroadcastBuffersMode = ( BroadcastBuffersMode.BEFORE_EVAL) self.amp_args = None self.amp_type = None self.amp_grad_scaler = None self.mixup_transform = None self.perf_log = [] self.last_batch = None self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED self.find_unused_parameters = False self.use_gpu = torch.cuda.is_available() self.dataloader_mp_context = "spawn" self.bn_weight_decay = False self._train_only = True self.clip_grad_norm = None self.simulated_global_batchsize = None self.optimizer_period = 1 self.ddp_bucket_cap_mb = 25 self.use_sharded_ddp = False self.fp16_grad_compress = False def set_use_sharded_ddp(self, use_sharded_ddp: bool): self.use_sharded_ddp = use_sharded_ddp if self.use_sharded_ddp: logging.info("Using Sharded DDP") return self def set_use_gpu(self, use_gpu: bool): self.use_gpu = use_gpu assert (not self.use_gpu or torch.cuda.is_available()), "CUDA required to train on GPUs" return self def set_clip_grad_norm(self, clip_grad_norm: Optional[float]): """Sets maximum gradient norm. None means gradient clipping is disabled. Defaults to None.""" self.clip_grad_norm = clip_grad_norm if clip_grad_norm is None: logging.info("Disabled gradient norm clipping.") else: logging.info( f"Enabled gradient norm clipping with threshold: {clip_grad_norm}" ) return self def set_simulated_global_batchsize( self, simulated_global_batchsize: Optional[int]): """Sets a simulated batch size by gradient accumulation. Gradient accumulation adds up gradients from multiple minibatches and steps the optimizer every N train_steps, where N is optimizer_period. When enabled, the very last train_steps might end up not updating the model, depending on the number of total steps. None means gradient accumulation is disabled. Defaults to None.""" self.simulated_global_batchsize = simulated_global_batchsize return self def set_checkpoint(self, checkpoint_path: str): """Sets checkpoint on task. Args: checkpoint_path: The path to load the checkpoint from. Can be a file or a directory. See :func:`load_checkpoint` for more information. """ self.checkpoint_path = checkpoint_path return self def set_checkpoint_load_strict(self, checkpoint_load_strict: bool): """Sets checkpoint on task. Args: checkpoint_load_strict: Whether to use load_strict when copying model weights """ self.checkpoint_load_strict = checkpoint_load_strict return self def _set_checkpoint_dict(self, checkpoint_dict: Dict[str, Any]): """Sets the checkpoint dict in the task. Only used for testing. Args: checkpoint_dict: A serializable dict representing current task state """ self.checkpoint_dict = checkpoint_dict return self def set_num_epochs(self, num_epochs: Union[int, float]): """Set number of epochs to be run. 
Args: num_epochs: Number of epochs to run task """ self.num_epochs = num_epochs return self def set_test_phase_period(self, test_phase_period: int): """Set the period of test phase. Args: test_phase_period: The period of test phase """ self.test_phase_period = test_phase_period return self def set_dataset(self, dataset: ClassyDataset, phase_type: str): """Set dataset for phase type on task Args: dataset: ClassyDataset for returning samples. phase_type: str must be one of "train" or "test" """ assert phase_type in [ "train", "test", ], "phase_type must be in ['train', 'test']" self.datasets[phase_type] = dataset if phase_type == "train": self.train_phases_per_epoch = getattr(dataset, "phases_per_epoch", 1) else: self._train_only = False return self def set_dataloader_mp_context(self, dataloader_mp_context: Optional[str]): """Set the multiprocessing context used by the dataloader. The context can be either 'spawn', 'fork', 'forkserver' or None (uses the default context). See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context for more details.""" self.dataloader_mp_context = dataloader_mp_context return self def set_optimizer(self, optimizer: ClassyOptimizer): """Set optimizer for task Args: optimizer: optimizer for task """ self.optimizer = optimizer return self def set_loss(self, loss: ClassyLoss): """Set loss function for task Args: loss: loss for task """ self.base_loss = loss return self def set_meters(self, meters: List["ClassyMeter"]): """Set meters for task Args: meters: list of meters to compute during training """ self.meters = meters return self def set_distributed_options( self, broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode. BEFORE_EVAL, batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED, batch_norm_sync_group_size: int = 0, find_unused_parameters: bool = False, bucket_cap_mb: int = 25, fp16_grad_compress: bool = False, ): """Set distributed options. Args: broadcast_buffers_mode: Broadcast buffers mode. See :class:`BroadcastBuffersMode` for options. batch_norm_sync_mode: Batch normalization synchronization mode. See :class:`BatchNormSyncMode` for options. batch_norm_sync_group_size: Group size to use for synchronized batch norm. 0 means that the stats are synchronized across all replicas. For efficient synchronization, set it to the number of GPUs in a node ( usually 8). find_unused_parameters: See :class:`torch.nn.parallel.DistributedDataParallel` for information. bucket_cap_mb: See :class:`torch.nn.parallel.DistributedDataParallel` for information. Raises: RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex is not installed. """ self.broadcast_buffers_mode = broadcast_buffers_mode if batch_norm_sync_group_size > 0: if not batch_norm_sync_mode == BatchNormSyncMode.APEX: # this should ideally work with PyTorch Sync BN as well, but it # fails while initializing DDP for some reason. 
raise ValueError( "batch_norm_sync_group_size can be > 0 only when " "Apex Synchronized Batch Normalization is being used.") self.batch_norm_sync_group_size = batch_norm_sync_group_size if batch_norm_sync_mode == BatchNormSyncMode.DISABLED: logging.info("Synchronized Batch Normalization is disabled") else: if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available: raise RuntimeError("apex is not installed") msg = f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}" if self.batch_norm_sync_group_size > 0: msg += f" and group size {batch_norm_sync_group_size}" logging.info(msg) self.batch_norm_sync_mode = batch_norm_sync_mode if find_unused_parameters: logging.info("Enabling find_unused_parameters in DDP") self.find_unused_parameters = find_unused_parameters self.ddp_bucket_cap_mb = bucket_cap_mb if fp16_grad_compress: if get_torch_version() < [1, 8]: raise RuntimeError( "FP16 grad compression is only supported since PyTorch 1.8" ) logging.info("Enabling FP16 grad compression") self.fp16_grad_compress = fp16_grad_compress return self def set_hooks(self, hooks: List["ClassyHook"]): """Set hooks for task Args: hooks: List of hooks to apply during training """ from classy_vision.hooks import ClassyHook assert isinstance(hooks, list) assert all(isinstance(hook, ClassyHook) for hook in hooks) assert len({ hook.name() for hook in hooks }) == len(hooks), "Cannot have repeated hooks of the same class" # TODO (zyan3): we move checkpoint hook to the end of the list because some hooks # may change the state of the model, and we want to save changed state in the checkpoint. # This is temporary fix. non_checkpoint_hooks = [ hook for hook in hooks if not isinstance(hook, CheckpointHook) ] checkpoint_hooks = [ hook for hook in hooks if isinstance(hook, CheckpointHook) ] hooks = non_checkpoint_hooks + checkpoint_hooks self.hooks = hooks return self def set_model(self, model: ClassyModel): """Set model for task Args: model: Model to be trained """ self.base_model = model return self def set_test_only(self, test_only: bool): """Set test only flag Args: test_only: If true, only test phases will be run """ self.test_only = test_only return self def set_bn_weight_decay(self, bn_weight_decay: bool): assert type(bn_weight_decay) == bool self.bn_weight_decay = bn_weight_decay return self def set_amp_args(self, amp_args: Optional[Dict[str, Any]]): """Disable / enable apex.amp and set the automatic mixed precision parameters. apex.amp can be utilized for mixed / half precision training. Args: amp_args: Dictionary containing arguments to be passed to amp.initialize. Set to None to disable amp. To enable mixed precision training, pass amp_args={"opt_level": "O1"} here. See https://nvidia.github.io/apex/amp.html for more info. Raises: RuntimeError: If opt_level is not None and apex is not installed. Warning: apex needs to be installed to utilize this feature. 
""" self.amp_args = amp_args if amp_args is None: logging.info("AMP disabled") else: # Check that the requested AMP type is known try: self.amp_type = AmpType[self.amp_args["amp_type"].upper()] except KeyError: logging.info("AMP type not specified, defaulting to Apex") self.amp_type = AmpType.APEX # Check for CUDA availability, required for both Apex and Pytorch AMP if not torch.cuda.is_available(): raise RuntimeError( "AMP is required but CUDA is not supported, cannot enable AMP" ) # Check for Apex availability if self.amp_type == AmpType.APEX and not apex_available: raise RuntimeError( "Apex AMP is required but Apex is not installed, cannot enable AMP" ) if self.use_sharded_ddp: if self.amp_type == AmpType.APEX: raise RuntimeError( "ShardedDDP has been requested, which is incompatible with Apex AMP" ) if not fairscale_available: raise RuntimeError( "ShardedDDP has been requested, but fairscale is not installed in the current environment" ) # Set Torch AMP grad scaler, used to prevent gradient underflow elif self.amp_type == AmpType.PYTORCH: if self.use_sharded_ddp: logging.info( "Using ShardedGradScaler to manage Pytorch AMP") self.amp_grad_scaler = ShardedGradScaler() else: self.amp_grad_scaler = TorchGradScaler() logging.info(f"AMP enabled with args {amp_args}") return self def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]): """Disable / enable mixup transform for data augmentation Args:: mixup_transform: a callable object which performs mixup data augmentation """ self.mixup_transform = mixup_transform if mixup_transform is None: logging.info("mixup disabled") else: logging.info("mixup enabled") return self def set_optimizer_schedulers(self, schedulers): self.optimizer_schedulers = schedulers return self @classmethod def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": """Instantiates a ClassificationTask from a configuration. Args: config: A configuration for a ClassificationTask. See :func:`__init__` for parameters expected in the config. Returns: A ClassificationTask instance. 
""" test_only = config.get("test_only", False) if not test_only: # TODO Make distinction between epochs and phases in optimizer clear train_phases_per_epoch = config["dataset"]["train"].get( "phases_per_epoch", 1) optimizer_config = config["optimizer"] optimizer_config["num_epochs"] = (config["num_epochs"] * train_phases_per_epoch) optimizer = build_optimizer(optimizer_config) param_schedulers = build_optimizer_schedulers(optimizer_config) datasets = {} phase_types = ["train", "test"] for phase_type in phase_types: if phase_type in config["dataset"]: datasets[phase_type] = build_dataset( config["dataset"][phase_type]) loss = build_loss(config["loss"]) amp_args = config.get("amp_args") meters = build_meters(config.get("meters", {})) model = build_model(config["model"]) mixup_transform = None if config.get("mixup") is not None: assert "alpha" in config[ "mixup"], "key alpha is missing in mixup dict" mixup_transform = MixupTransform( config["mixup"]["alpha"], num_classes=config["mixup"].get("num_classes"), cutmix_alpha=config["mixup"].get("cutmix_alpha", 0), cutmix_minmax=config["mixup"].get("cutmix_minmax"), mix_prob=config["mixup"].get("mix_prob", 1.0), switch_prob=config["mixup"].get("switch_prob", 0.5), mode=config["mixup"].get("mode", "batch"), label_smoothing=config["mixup"].get("label_smoothing", 0.0), ) # hooks config is optional hooks_config = config.get("hooks") hooks = [] if hooks_config is not None: hooks = build_hooks(hooks_config) distributed_config = config.get("distributed", {}) distributed_options = { "broadcast_buffers_mode": BroadcastBuffersMode[distributed_config.get( "broadcast_buffers", "before_eval").upper()], "batch_norm_sync_mode": BatchNormSyncMode[distributed_config.get("batch_norm_sync_mode", "disabled").upper()], "batch_norm_sync_group_size": distributed_config.get("batch_norm_sync_group_size", 0), "find_unused_parameters": distributed_config.get("find_unused_parameters", False), "bucket_cap_mb": distributed_config.get("bucket_cap_mb", 25), "fp16_grad_compress": distributed_config.get("fp16_grad_compress", False), } task = ( cls().set_num_epochs(config["num_epochs"]).set_test_phase_period( config.get( "test_phase_period", 1)).set_loss(loss).set_test_only(test_only).set_model( model).set_meters(meters).set_amp_args(amp_args). set_mixup_transform(mixup_transform).set_distributed_options( **distributed_options).set_hooks(hooks).set_bn_weight_decay( config.get("bn_weight_decay", False)).set_clip_grad_norm( config.get("clip_grad_norm")). set_simulated_global_batchsize( config.get("simulated_global_batchsize")).set_use_sharded_ddp( config.get("use_sharded_ddp", False))) if not test_only: task.set_optimizer(optimizer) task.set_optimizer_schedulers(param_schedulers) use_gpu = config.get("use_gpu") if use_gpu is not None: task.set_use_gpu(use_gpu) for phase_type in datasets: task.set_dataset(datasets[phase_type], phase_type) # NOTE: this is a private member and only meant to be used for # logging/debugging purposes. 
See __repr__ implementation task._config = config return task @property def num_batches_per_phase(self): """Returns number of batches in current phase iterator""" return len(self.data_iterator) @property def model(self): """Returns model used in training (can be wrapped with DDP)""" return (self.distributed_model if is_distributed_training_run() else self.base_model) @property def loss(self): """Returns loss used in training (can be wrapped with DDP)""" return self.distributed_loss if self.distributed_loss else self.base_loss @property def phase_type(self): """Returns current phase type. String with value "train" or "test" """ return "train" if self.train else "test" @property def eval_phase_idx(self): """Returns current evaluation phase""" return self.phase_idx - self.train_phase_idx - 1 def get_total_training_phases(self): """ Returns the total number of "train" phases in the task """ num_training_phases = 0 for phase in self.phases: if phase["train"] is True: num_training_phases += 1 return num_training_phases def get_total_test_phases(self): """ Returns the total number of "test" phases in the task """ num_test_phases = 0 for phase in self.phases: if phase["train"] is False: num_test_phases += 1 return num_test_phases def _build_phases(self): """Returns list of phases from config. These phases will look like: { train: is this a train or test phase? optimizer: optimizer settings } - If this is a test only run, then only test phases will be generated - If this is a training run with both train and test datasets, then x phases = x train phases + x test phases, interleaved. If test_phase_period > 1, test phases are only added after test_phase_period train phases. The last phase is always a test phase. - If this is a training run with only a train dataset, then x phases = x train phases. """ if not self.test_only: phases = [{ "train": True } for _ in range( math.ceil(self.train_phases_per_epoch * self.num_epochs))] if self._train_only: return phases final_phases = [] for i, phase in enumerate(phases): final_phases.append(phase) if (i + 1) % self.test_phase_period == 0: final_phases.append({"train": False}) if final_phases[-1]["train"]: final_phases.append({"train": False}) return final_phases return [{"train": False} for _ in range(self.num_epochs)] def build_dataloader_from_dataset(self, dataset, **kwargs): """Builds a dataloader from the provided dataset Args: dataset: A ClassyDataset kwargs: Additional kwargs to pass during dataloader construction for derived classes """ return dataset.iterator( phase_type=self.phase_type, current_phase_id=self.train_phase_idx if self.train else 0, pin_memory=self.use_gpu and torch.cuda.device_count() > 1, multiprocessing_context=mp.get_context(self.dataloader_mp_context), **kwargs, ) def build_dataloaders_for_current_phase(self): """Builds dataloader(s) for the current phase. Deriving classes can override this method to support custom behavior, like supporting multiple dataloaders in parallel. 
""" self.dataloader = self.build_dataloader_from_dataset( self.datasets[self.phase_type]) def prepare_optimizer(self, optimizer, model, loss=None): bn_params, other_params = split_batchnorm_params(model) if loss is not None: bn_params_loss, params_loss = split_batchnorm_params(loss) bn_params = bn_params + bn_params_loss other_params = other_params + params_loss bn_schedulers = self.optimizer_schedulers.copy() if not self.bn_weight_decay: bn_schedulers["weight_decay"] = 0 param_groups = [{"params": other_params, **self.optimizer_schedulers}] if len(bn_params) > 0: param_groups.append({"params": bn_params, **bn_schedulers}) self.optimizer.set_param_groups(param_groups) def prepare(self): """Prepares task for training, populates all derived attributes""" self.phases = self._build_phases() self.train = False if self.test_only else self.train if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH: self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm( self.base_model) elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX: sync_bn_process_group = apex.parallel.create_syncbn_process_group( self.batch_norm_sync_group_size) self.base_model = apex.parallel.convert_syncbn_model( self.base_model, process_group=sync_bn_process_group) # move the model and loss to the right device if self.use_gpu: self.base_model, self.base_loss = copy_model_to_gpu( self.base_model, self.base_loss) else: self.base_loss.cpu() self.base_model.cpu() if self.optimizer is not None: self.prepare_optimizer(optimizer=self.optimizer, model=self.base_model, loss=self.base_loss) if self.amp_args is not None: if self.amp_type == AmpType.APEX: # Initialize apex.amp. This updates the model and the PyTorch optimizer ( # if training, which is wrapped by the ClassyOptimizer in self.optimizer). # Please note this must happen before loading the checkpoint, cause # there's amp state to be restored. if self.optimizer is None: self.base_model = apex.amp.initialize(self.base_model, optimizers=None, **self.amp_args) else: self.base_model, self.optimizer.optimizer = apex.amp.initialize( self.base_model, self.optimizer.optimizer, **self.amp_args) if self.simulated_global_batchsize is not None: if self.simulated_global_batchsize % self.get_global_batchsize( ) != 0: raise ValueError( f"Global batch size ({self.get_global_batchsize()}) must divide " f"simulated_global_batchsize ({self.simulated_global_batchsize})" ) else: self.simulated_global_batchsize = self.get_global_batchsize() self.optimizer_period = (self.simulated_global_batchsize // self.get_global_batchsize()) if self.optimizer_period > 1: logging.info( f"Using gradient accumulation with a period of {self.optimizer_period}" ) if self.checkpoint_path: self.checkpoint_dict = load_and_broadcast_checkpoint( self.checkpoint_path) classy_state_dict = (None if self.checkpoint_dict is None else self.checkpoint_dict["classy_state_dict"]) if classy_state_dict is not None: state_load_success = update_classy_state(self, classy_state_dict) assert (state_load_success ), "Update classy state from checkpoint was unsuccessful." self.init_distributed_data_parallel_model() def init_distributed_data_parallel_model(self): """ Initialize `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/ docs/stable/nn.html#distributeddataparallel>`_. Needed for distributed training. This is where a model should be wrapped by DDP. 
""" if not is_distributed_training_run(): return assert (self.distributed_model is None), "init_ddp_non_elastic must only be called once" broadcast_buffers = ( self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS) if self.use_sharded_ddp: if not isinstance(self.optimizer, ZeRO): raise ValueError( "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer" ) from fairscale.nn.data_parallel import ShardedDataParallel # Replace the original DDP wrap by the shard-aware ShardedDDP self.distributed_model = ShardedDataParallel( module=self.base_model, sharded_optimizer=self.optimizer.optimizer, broadcast_buffers=broadcast_buffers, ) else: self.distributed_model = init_distributed_data_parallel_model( self.base_model, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, ) if self.fp16_grad_compress: from torch.distributed.algorithms import ddp_comm_hooks # FP16 hook is stateless and only takes a process group as the state. # We use the default process group so we set the state to None. process_group = None self.distributed_model.register_comm_hook( process_group, ddp_comm_hooks.default_hooks.fp16_compress_hook) if (isinstance(self.base_loss, ClassyLoss) and self.base_loss.has_learned_parameters()): logging.info("Initializing distributed loss") self.distributed_loss = init_distributed_data_parallel_model( self.base_loss, broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, bucket_cap_mb=self.ddp_bucket_cap_mb, ) @property def where(self): """Returns the proportion of training that has completed. If in test only mode, returns proportion of testing completed Returned value is a float in the range [0, 1) """ current_step = self.num_updates / self.get_global_batchsize() num_phases = (self.get_total_test_phases() if self.test_only else self.get_total_training_phases()) if self.num_batches_per_phase <= 0: raise RuntimeError("No batches to read. Is the dataset empty?") num_steps = num_phases * self.num_batches_per_phase where = current_step / num_steps return where def get_classy_state(self, deep_copy: bool = False): """Returns serialiable state of task Args: deep_copy: If true, does a deep copy of state before returning. 
""" optimizer_state = {} if self.optimizer is not None: optimizer_state = self.optimizer.get_classy_state() classy_state_dict = { "train": self.train, "base_model": self.base_model.get_classy_state(), "meters": [meter.get_classy_state() for meter in self.meters], "optimizer": optimizer_state, "phase_idx": self.phase_idx, "train_phase_idx": self.train_phase_idx, "num_updates": self.num_updates, "losses": self.losses, "hooks": {hook.name(): hook.get_classy_state() for hook in self.hooks}, "loss": {}, } if "train" in self.datasets and self._is_checkpointable_dataset( self.datasets["train"]): classy_state_dict["train_dataset_iterator"] = self.datasets[ "train"].get_classy_state() if isinstance(self.base_loss, ClassyLoss): classy_state_dict["loss"] = self.base_loss.get_classy_state() if self.amp_args is not None: if self.amp_type == AmpType.APEX: classy_state_dict["amp"] = apex.amp.state_dict() elif self.amp_grad_scaler is not None: classy_state_dict["amp"] = self.amp_grad_scaler.state_dict() if deep_copy: classy_state_dict = copy.deepcopy(classy_state_dict) return classy_state_dict def set_classy_state(self, state): """Set task state Args: state: Dict containing state of a task """ self.train = False if self.test_only else state["train"] self.base_model.set_classy_state(state["base_model"]) if self.test_only: # if we're only testing, just need the state of the model to be updated return self.phase_idx = state["phase_idx"] self.num_updates = state["num_updates"] self.train_phase_idx = state["train_phase_idx"] self.losses = state["losses"] for meter, meter_state in zip(self.meters, state["meters"]): meter.set_classy_state(meter_state) if self.optimizer is not None: self.optimizer.set_classy_state(state["optimizer"]) if state.get("loss") and isinstance(self.base_loss, ClassyLoss): self.base_loss.set_classy_state(state["loss"]) if "amp" in state: if self.amp_type == AmpType.APEX: apex.amp.load_state_dict(state["amp"]) else: self.amp_grad_scaler.load_state_dict(state["amp"]) for hook in self.hooks: # we still want to be able to run when new hooks are added or old # hooks are removed if hook.name() in state["hooks"]: hook.set_classy_state(state["hooks"][hook.name()]) else: logging.warning(f"No state found for hook: {hook.name()}") if "train" in self.datasets and self._is_checkpointable_dataset( self.datasets["train"]): self.datasets["train"].set_classy_state( state.get("train_dataset_iterator")) @staticmethod def _is_checkpointable_dataset(dataset): return hasattr(dataset, "get_classy_state") and hasattr( dataset, "set_classy_state") def eval_step(self): self.last_batch = None # Process next sample with Timer() as timer: sample = next(self.data_iterator) assert isinstance( sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys") target = sample["target"] if self.use_gpu: sample = recursive_copy_to_gpu(sample, non_blocking=True) # Optional Pytorch AMP context torch_amp_context = (torch.cuda.amp.autocast() if self.amp_type == AmpType.PYTORCH else contextlib.suppress()) with torch.no_grad(), torch_amp_context: output = self.model(sample["input"]) local_loss = self.compute_loss(output, sample) loss = local_loss.detach().clone() self.losses.append(loss.data.cpu().item()) self.update_meters(output, sample) # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( loss=loss, output=output, target=target, sample=sample, step_data={"sample_fetch_time": timer.elapsed_time}, ) def 
check_inf_nan(self, loss): if loss == float("inf") or loss == float("-inf") or loss != loss: raise FloatingPointError(f"Loss is infinity or NaN: {loss}") def _should_do_step(self): """Tells if we will be performing an optimizer step. Returns True always if there is no gradient accumulation. With gradient accumulation returns True only when the gradients will be synchronized and we will be performing an optimizer step. """ update_idx = self.num_updates // self.get_global_batchsize() return (update_idx % self.optimizer_period) == self.optimizer_period - 1 def train_step(self): """Train step to be executed in train loop.""" self.last_batch = None # Process next sample with Timer() as timer: sample = next(self.data_iterator) assert isinstance( sample, dict) and "input" in sample and "target" in sample, ( f"Returned sample [{sample}] is not a map with 'input' and" + "'target' keys") # Copy sample to GPU target = sample["target"] if self.use_gpu: sample = recursive_copy_to_gpu(sample, non_blocking=True) if self.mixup_transform is not None: sample = self.mixup_transform(sample) # Optional Pytorch AMP context torch_amp_context = (torch.cuda.amp.autocast() if self.amp_type == AmpType.PYTORCH else contextlib.suppress()) # only sync with DDP when we need to perform an optimizer step # an optimizer step can be skipped if gradient accumulation is enabled do_step = self._should_do_step() ctx_mgr_model = (self.distributed_model.no_sync() if self.distributed_model is not None and not do_step else contextlib.suppress()) ctx_mgr_loss = (self.distributed_loss.no_sync() if self.distributed_loss is not None and not do_step else contextlib.suppress()) with ctx_mgr_model, ctx_mgr_loss: # Forward pass with torch.enable_grad(), torch_amp_context: output = self.compute_model(sample) local_loss = self.compute_loss(output, sample) loss = local_loss.detach().clone() self.losses.append(loss.data.cpu().item()) self.update_meters(output, sample) # Backwards pass + optimizer step self.run_optimizer(local_loss) self.num_updates += self.get_global_batchsize() # Move some data to the task so hooks get a chance to access it self.last_batch = LastBatchInfo( loss=loss, output=output, target=target, sample=sample, step_data={"sample_fetch_time": timer.elapsed_time}, ) def compute_model(self, sample): return self.model(sample["input"]) def compute_loss(self, model_output, sample): return self.loss(model_output, sample["target"]) def run_optimizer(self, loss): """Runs backwards pass and update the optimizer""" self.check_inf_nan(loss) # Gradient accumulation logic. We always set optimizer_period, even # if gradient accumulation is disabled. 
Assumes all batches have the # same size update_idx = self.num_updates // self.get_global_batchsize() do_zero_grad = (update_idx % self.optimizer_period) == 0 do_step = self._should_do_step() if do_zero_grad: self.optimizer.zero_grad() if self.amp_type == AmpType.APEX: with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss: scaled_loss.backward() elif self.amp_type == AmpType.PYTORCH: self.amp_grad_scaler.scale(loss).backward() else: loss.backward() if do_step: # Handle gradient accumulation related gradient rescaling if self.optimizer_period != 1: self._rescale_gradients(1 / self.optimizer_period) # Clipping must happen after grad accumulation if self.clip_grad_norm is not None: self._clip_gradients(self.clip_grad_norm) if self.amp_type == AmpType.PYTORCH: # If using mixed precision, handle underflow-related scaling # See https://pytorch.org/docs/stable/amp.html#gradient-scaling # for context self.amp_grad_scaler.step(self.optimizer, where=self.where) self.amp_grad_scaler.update() else: self.optimizer.step(where=self.where) def _rescale_gradients(self, scale): for param in master_params(self.optimizer): if param.grad is not None: param.grad.data.mul_(scale) def _clip_gradients(self, max_norm): nn.utils.clip_grad_norm_(master_params(self.optimizer), max_norm) def update_meters(self, model_output, sample): target = sample["target"].detach().cpu() model_output = model_output.detach().cpu() # Update meters for meter in self.meters: meter.update(model_output, target, is_train=self.train) def synchronize_losses(self): """Average the losses across the different replicas""" # Average losses across nodes losses_tensor = torch.tensor(self.losses) synchronized_losses_tensor = all_reduce_mean(losses_tensor) self.losses = synchronized_losses_tensor.tolist() def advance_phase(self): """Performs bookkeeping / task updates between phases Increments phase idx, resets meters, resets loss history, resets counters, shuffles dataset, rebuilds iterators, and sets the train / test state for phase. """ logging.debug("Advancing phase") # Reset meters for next phase / epoch for meter in self.meters: meter.reset() # Reset loss history for next epoch self.losses = [] # Setup new phase self.phase_idx += 1 phase = self.phases[self.phase_idx] self.train = True if phase["train"] else False if self.train: self.train_phase_idx += 1 # Re-build dataloader & re-create iterator anytime membership changes. self.build_dataloaders_for_current_phase() self.create_data_iterators() # Set up pytorch module in train vs eval mode, update optimizer. self._set_model_train_mode() def done_training(self): """Stop condition for training""" return self.phase_idx + 1 >= len(self.phases) def create_data_iterators(self): """Creates data iterator(s) for the current phase.""" # Delete iterator explicitly so that all dataloader processes # are cleaned up. 
del self.data_iterator self.data_iterator = iter(self.dataloader) def _set_model_train_mode(self): """Set train mode for model""" phase = self.phases[self.phase_idx] self.base_model.train(phase["train"]) self.base_loss.train(phase["train"]) if (self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL and not self.train): self._broadcast_buffers() def _broadcast_buffers(self): """Explicitly synchronize buffers across all devices.""" if self.distributed_model is None: return buffers = list(self.base_model.buffers()) if len(buffers) > 0: logging.info("Synchronizing buffers before evaluation.") for buffer in buffers: broadcast(buffer, 0, group=self.distributed_model.process_group) # TODO: Functions below should be better abstracted into the dataloader # abstraction def get_batchsize_per_replica(self): """Return local replica's batchsize for dataset (e.g. batchsize per GPU)""" return self.datasets[self.phase_type].get_batchsize_per_replica() def get_global_batchsize(self): """Return global batchsize across all trainers""" return self.datasets[self.phase_type].get_global_batchsize() def on_start(self): for hook in self.hooks: hook.on_start(self) def on_phase_start(self): self.phase_start_time_total = time.perf_counter() self.advance_phase() for hook in self.hooks: hook.on_phase_start(self) self.phase_start_time_train = time.perf_counter() def on_phase_end(self): self.log_phase_end(self.phase_type) if self.train: self.optimizer.on_epoch(where=self.where) logging.debug("Syncing losses on phase end...") self.synchronize_losses() logging.debug("...losses synced") logging.debug("Syncing meters on phase end...") for meter in self.meters: meter.sync_state() logging.debug("...meters synced") barrier() for hook in self.hooks: hook.on_phase_end(self) self.perf_log = [] self.log_phase_end(f"{self.phase_type}_total") if hasattr(self.datasets[self.phase_type], "on_phase_end"): self.datasets[self.phase_type].on_phase_end() def on_end(self): for hook in self.hooks: hook.on_end(self) def log_phase_end(self, tag): start_time = (self.phase_start_time_train if tag == self.phase_type else self.phase_start_time_total) phase_duration = time.perf_counter() - start_time im_per_sec = (self.get_global_batchsize() * self.num_batches_per_phase) / phase_duration self.perf_log.append({ "tag": tag, "phase_idx": self.train_phase_idx, "im_per_sec": im_per_sec }) def __repr__(self): if hasattr(self, "_config"): config = json.dumps(self._config, indent=4) return f"{super().__repr__()} initialized with config:\n{config}" return super().__repr__()
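# The train_step/run_optimizer pair above interleaves four concerns: autocast
# forward, GradScaler backward, gradient accumulation over optimizer_period
# batches, and skipping the DDP gradient sync on non-step batches. The sketch
# below distills that pattern outside the task abstraction; `model`,
# `optimizer`, `loss_fn`, `loader`, and `accum_steps` are hypothetical
# placeholders, not names from the code above.
import contextlib

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
accum_steps = 4  # plays the role of optimizer_period above

for batch_idx, (inputs, targets) in enumerate(loader):
    do_step = (batch_idx + 1) % accum_steps == 0
    # Skip the gradient all-reduce on batches that do not step the optimizer
    sync_ctx = (model.no_sync()
                if isinstance(model, torch.nn.parallel.DistributedDataParallel)
                and not do_step else contextlib.suppress())
    with sync_ctx:
        with autocast():
            loss = loss_fn(model(inputs.cuda()), targets.cuda())
        scaler.scale(loss).backward()
    if do_step:
        # Average the accumulated gradients, mirroring _rescale_gradients
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(1.0 / accum_steps)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()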
class Trainer: """Model trainer Args: model: model to train loss_fn: loss function optimizer: model optimizer epochs: number of epochs device: device to train the model on train_loader: training dataloader val_loader: validation dataloader scheduler: learning rate scheduler update_sched_on_iter: whether to call the scheduler every iter or every epoch grad_clip_max_norm: gradient clipping max norm (disabled if None) writer: writer which logs metrics to TensorBoard (disabled if None) save_path: folder in which to save models (disabled if None) checkpoint_path: path to model checkpoint, to resume training """ def __init__( self, model: torch.nn.Module, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, epochs: int, device: torch.device, train_loader: DataLoader, val_loader: Optional[DataLoader] = None, scheduler: Optional = None, # Type: torch.optim.lr_scheduler._LRScheduler update_sched_on_iter: bool = False, grad_clip_max_norm: Optional[float] = None, writer: Optional[SummaryWriter] = None, save_path: Optional[str] = None, checkpoint_path: Optional[str] = None, mixed_precision: bool = False, ) -> None: # Logging self.logger = logging.getLogger() self.writer = writer # Saving self.save_path = save_path # Device self.device = device # Data self.train_loader = train_loader self.val_loader = val_loader # Model self.model = model self.loss_fn = loss_fn self.optimizer = optimizer self.scheduler = scheduler self.update_sched_on_iter = update_sched_on_iter self.grad_clip_max_norm = grad_clip_max_norm self.epochs = epochs self.start_epoch = 0 # Floating-point precision self.mixed_precision = ( True if self.device.type == "cuda" and mixed_precision else False ) self.scaler = GradScaler() if self.mixed_precision else None if checkpoint_path: self._load_from_checkpoint(checkpoint_path) # Metrics self.train_loss_metric = LossMetric() self.train_acc_metric = AccuracyMetric(k=1) self.val_loss_metric = LossMetric() self.val_acc_metric = AccuracyMetric(k=1) def train(self) -> None: """Trains the model""" self.logger.info("Beginning training") start_time = time.time() for epoch in range(self.start_epoch, self.epochs): start_epoch_time = time.time() if self.mixed_precision: self._train_loop_amp(epoch) else: self._train_loop(epoch) if self.val_loader is not None: self._val_loop(epoch) epoch_time = time.time() - start_epoch_time self._end_loop(epoch, epoch_time) train_time_h = (time.time() - start_time) / 3600 self.logger.info(f"Finished training! 
Total time: {train_time_h:.2f}h") self._save_model(os.path.join(self.save_path, "final_model.pt"), self.epochs) def _train_loop(self, epoch: int) -> None: """ Regular train loop Args: epoch: current epoch """ # Progress bar pbar = tqdm.tqdm(total=len(self.train_loader), leave=False) pbar.set_description(f"Epoch {epoch} | Train") # Set to train self.model.train() # Loop for data, target in self.train_loader: # To device data, target = data.to(self.device), target.to(self.device) # Forward + backward self.optimizer.zero_grad() out = self.model(data) loss = self.loss_fn(out, target) loss.backward() if self.grad_clip_max_norm is not None: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.grad_clip_max_norm ) self.optimizer.step() # Update scheduler if it is iter-based if self.scheduler is not None and self.update_sched_on_iter: self.scheduler.step() # Update metrics self.train_loss_metric.update(loss.item(), data.shape[0]) self.train_acc_metric.update(out, target) # Update progress bar pbar.update() pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False) # Update scheduler if it is epoch-based if self.scheduler is not None and not self.update_sched_on_iter: self.scheduler.step() pbar.close() def _train_loop_amp(self, epoch: int) -> None: """ Train loop with Automatic Mixed Precision Args: epoch: current epoch """ # Progress bar pbar = tqdm.tqdm(total=len(self.train_loader), leave=False) pbar.set_description(f"Epoch {epoch} | Train") # Set to train self.model.train() # Loop for data, target in self.train_loader: # To device data, target = data.to(self.device), target.to(self.device) # Forward + backward self.optimizer.zero_grad() # Use amp in forward pass with autocast(): out = self.model(data) loss = self.loss_fn(out, target) # Backward pass with scaler self.scaler.scale(loss).backward() # Unscale before gradient clipping self.scaler.unscale_(self.optimizer) if self.grad_clip_max_norm is not None: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.grad_clip_max_norm ) # Update optimizer and scaler self.scaler.step(self.optimizer) self.scaler.update() # Update scheduler if it is iter-based if self.scheduler is not None and self.update_sched_on_iter: self.scheduler.step() # Update metrics self.train_loss_metric.update(loss.item(), data.shape[0]) self.train_acc_metric.update(out, target) # Update progress bar pbar.update() pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False) # Update scheduler if it is epoch-based if self.scheduler is not None and not self.update_sched_on_iter: self.scheduler.step() pbar.close() def _val_loop(self, epoch: int) -> None: """ Standard validation loop Args: epoch: current epoch """ # Progress bar pbar = tqdm.tqdm(total=len(self.val_loader), leave=False) pbar.set_description(f"Epoch {epoch} | Validation") # Set to eval self.model.eval() # Loop for data, target in self.val_loader: with torch.no_grad(): # To device data, target = data.to(self.device), target.to(self.device) # Forward out = self.model(data) loss = self.loss_fn(out, target) # Update metrics self.val_loss_metric.update(loss.item(), data.shape[0]) self.val_acc_metric.update(out, target) # Update progress bar pbar.update() pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False) pbar.close() def _end_loop(self, epoch: int, epoch_time: float): # Print epoch results self.logger.info(self._epoch_str(epoch, epoch_time)) # Write to tensorboard if self.writer is not None: self._write_to_tb(epoch) # Save model if self.save_path is not None: 
            self._save_model(os.path.join(self.save_path, "most_recent.pt"), epoch)

        # Clear metrics
        self.train_loss_metric.reset()
        self.train_acc_metric.reset()
        if self.val_loader is not None:
            self.val_loss_metric.reset()
            self.val_acc_metric.reset()

    def _epoch_str(self, epoch: int, epoch_time: float):
        s = f"Epoch {epoch} "
        s += f"| Train loss: {self.train_loss_metric.compute():.3f} "
        s += f"| Train acc: {self.train_acc_metric.compute():.3f} "
        if self.val_loader is not None:
            s += f"| Val loss: {self.val_loss_metric.compute():.3f} "
            s += f"| Val acc: {self.val_acc_metric.compute():.3f} "
        s += f"| Epoch time: {epoch_time:.1f}s"
        return s

    def _write_to_tb(self, epoch):
        self.writer.add_scalar("Loss/train", self.train_loss_metric.compute(), epoch)
        self.writer.add_scalar("Accuracy/train", self.train_acc_metric.compute(), epoch)
        if self.val_loader is not None:
            self.writer.add_scalar("Loss/val", self.val_loss_metric.compute(), epoch)
            self.writer.add_scalar("Accuracy/val", self.val_acc_metric.compute(), epoch)

    def _save_model(self, path, epoch):
        obj = {
            "epoch": epoch + 1,
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None else None,
            "scaler": self.scaler.state_dict() if self.mixed_precision else None,
        }
        # Callers already join save_path into `path`; joining it again here
        # would duplicate the directory prefix.
        torch.save(obj, path)

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.start_epoch = checkpoint["epoch"]

        if self.scheduler:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        # The scaler state must come from the "scaler" entry, not "scheduler"
        if self.mixed_precision and checkpoint.get("scaler") is not None:
            self.scaler.load_state_dict(checkpoint["scaler"])

        if self.start_epoch > self.epochs:
            raise ValueError("Starting epoch is larger than total epochs")

        self.logger.info(f"Checkpoint loaded, resuming from epoch {self.start_epoch}")
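# A hypothetical wiring of the Trainer above, to show the intended call
# pattern. The model, the two dataloaders, and the hyperparameter values are
# placeholders; only the keyword arguments come from the class definition.
import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(num_classes=10).to(device)

trainer = Trainer(
    model=model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9),
    epochs=90,
    device=device,
    train_loader=train_loader,  # assumed to be built elsewhere
    val_loader=val_loader,      # assumed to be built elsewhere
    grad_clip_max_norm=1.0,
    save_path="checkpoints",
    mixed_precision=True,  # silently falls back to fp32 when device is CPU
)
trainer.train()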
optimizer = optim.Adam([v for v in model.parameters() if v.requires_grad], lr=args.lr, betas=(.5, .9), eps=1e-6) scaler = GradScaler() # reload checkpoint parameters epoch = 0 num_samples_treated = 0 num_batches_treated = 0 if args.base_model is not None: if os.path.isfile(args.base_model): checkpoint = torch.load(args.base_model) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) scaler.load_state_dict(checkpoint["scaler"]) epoch = checkpoint['epoch'] + 1 # we start a new epoch num_samples_treated = checkpoint['num_samples_treated'] num_batches_treated = checkpoint['num_batches_treated'] else: # tf model from load_tf_models import load_ssrn_from_tf, load_t2m_from_tf # imported here so that installing tf is # not mandatory load_t2m_from_tf(model, args.base_model) if args.net == "Text2Mel" else \ load_ssrn_from_tf(model, args.base_model) max_num_samples_to_train_on = num_samples_treated + args.max_num_samples_to_train_on \ if args.max_num_samples_to_train_on is not None else 1e10 # 1e10 in case we # want to loop "indefinitely" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device)
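# Save-side counterpart assumed by the reload above: the key names must match
# exactly ('model_state_dict', 'optimizer_state_dict', 'scaler', ...). This is
# a sketch of what the training loop presumably writes, not code from the
# original snippet.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scaler': scaler.state_dict(),  # GradScaler state, needed to resume AMP
    'epoch': epoch,
    'num_samples_treated': num_samples_treated,
    'num_batches_treated': num_batches_treated,
}, args.base_model)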
class GenericTrainingManager: def __init__(self, params): self.type = None self.is_master = False self.params = params self.models = {} self.begin_time = None self.dataset = None self.paths = None self.latest_epoch = -1 self.latest_batch = 0 self.total_batch = 0 self.latest_train_metrics = dict() self.latest_valid_metrics = dict() self.phase = None self.max_mem_usage_by_epoch = list() self.scaler = None self.optimizer = None self.lr_scheduler = None self.best = None self.writer = None reset_optimizer = "reset_optimizer" in self.params["training_params"] and self.params["training_params"]["reset_optimizer"] self.init_hardware_config() self.init_paths() self.load_dataset() self.load_model(reset_optimizer) def init_paths(self): ## Create output folders output_path = os.path.join("outputs", self.params["training_params"]["output_folder"]) os.makedirs(output_path, exist_ok=True) checkpoints_path = os.path.join(output_path, "checkpoints") os.makedirs(checkpoints_path, exist_ok=True) results_path = os.path.join(output_path, "results") os.makedirs(results_path, exist_ok=True) self.paths = { "results": results_path, "checkpoints": checkpoints_path, "output_folder": output_path } def load_dataset(self): self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"] self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"] self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"] self.dataset = DatasetManager(self.params["dataset_params"]) if self.dataset.charset: self.params["model_params"]["vocab_size"] = len(self.dataset.charset) def init_hardware_config(self): # Debug mode if self.params["training_params"]["force_cpu"]: self.params["training_params"]["use_ddp"] = False self.params["training_params"]["use_amp"] = False # Manage Distributed Data Parallel & GPU usage self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \ self.params["training_params"]["manual_seed"] self.ddp_config = { "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0, "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"], "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"], "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"], "rank": self.params["training_params"]["ddp_rank"], } self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"] if self.params["training_params"]["force_cpu"]: self.device = "cpu" else: if self.params["training_params"]["use_ddp"]: self.device = torch.device(self.ddp_config["rank"]) self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"] self.launch_ddp() else: self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Print GPU info # global if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]: print("##################") print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"])) for i in range(self.params["training_params"]["nb_gpu"]): print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))) print("##################") # local print("Local GPU:") if self.device != "cpu": print("Rank {}: {} 
{}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device))) else: print("WORKING ON CPU !\n") print("##################") def load_model(self, reset_optimizer=False): self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"] # Instanciate Model for model_name in self.params["model_params"]["models"].keys(): self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"]) self.models[model_name].to(self.device) # To GPU or CPU # Instanciate optimizer self.reset_optimizer() if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"]: self.lr_scheduler = self.params["training_params"]["lr_scheduler"]["type"](self.optimizer, gamma=self.params["training_params"]["lr_scheduler"]["gamma"]) self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"]) # Load previous weights checkpoint = None if self.params["training_params"]["load_epoch"] in ("best", "last"): for filename in os.listdir(self.paths["checkpoints"]): # Continue training if self.params["training_params"]["load_epoch"] in filename: checkpoint_path = os.path.join(self.paths["checkpoints"], filename) checkpoint = torch.load(checkpoint_path) self.load_save_info(checkpoint) self.latest_epoch = checkpoint["epoch"] self.best = checkpoint["best"] self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) # Make model compatible with Distributed Data Parallel if used if self.params["training_params"]["use_ddp"]: for model_name in self.models.keys(): self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) # Load model weights from past training for model_name in self.models.keys(): self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)]) # Load optimizer state from past training if not reset_optimizer: self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) # Load optimizer scheduler config from past training if used if "lr_scheduler" in self.params["training_params"] and self.params["training_params"]["lr_scheduler"] and "lr_scheduler_state_dict" in checkpoint.keys(): self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) break # Print the number of trained epoch so far with the model if self.is_master: print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True) # New training if not checkpoint: # Weights initialization for model_name in self.models.keys(): self.models[model_name].apply(self.weights_init) # Handle transfer learning instructions if self.params["model_params"]["transfer_learning"]: # Iterates over models for model_name in self.params["model_params"]["transfer_learning"].keys(): state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name] # Loading pretrained weights file checkpoint = torch.load(path) try: # Load pretrained weights for model self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict) print("transfered weights for {}".format(state_dict_name), flush=True) except RuntimeError as e: print(e, flush=True) # if error, try to load each parts of the model (useful if only few layers are different) for key in checkpoint["{}_state_dict".format(state_dict_name)].keys(): try: self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) except RuntimeError as e: print(e, flush=True) # Set 
parameters no trainable if not learnable: self.set_model_learnable(self.models[model_name], False) # make the model compatible with Distributed Data Parallel if used if self.params["training_params"]["use_ddp"]: for model_name in self.models.keys(): self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) return @staticmethod def set_model_learnable(model, learnable=True): for p in list(model.parameters()): p.requires_grad = learnable def save_model(self, epoch, name, keep_weights=False): """ Save model weights """ if not self.is_master: return to_del = [] for filename in os.listdir(self.paths["checkpoints"]): if name in filename: to_del.append(os.path.join(self.paths["checkpoints"], filename)) path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch)) content = { 'optimizer_state_dict': self.optimizer.state_dict(), 'epoch': epoch, "scaler_state_dict": self.scaler.state_dict(), 'best': self.best, } if self.lr_scheduler: content["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict() content = self.add_save_info(content) for model_name in self.models.keys(): content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict() torch.save(content, path) if not keep_weights: for path_to_del in to_del: if path_to_del != path: os.remove(path_to_del) def reset_optimizer(self): """ Reset optimizer learning rate """ parameters = list() for model_name in self.models.keys(): parameters += list(self.models[model_name].parameters()) self.optimizer = self.params["training_params"]["optimizer"]["class"]\ (parameters, **self.params["training_params"]["optimizer"]["args"]) @staticmethod def weights_init(m): """ Weights initialization for model training from scratch """ if isinstance(m, Conv2d) or isinstance(m, Linear): if m.weight is not None: kaiming_uniform_(m.weight, nonlinearity="relu") if m.bias is not None: zeros_(m.bias) elif isinstance(m, InstanceNorm2d): if m.weight is not None: ones_(m.weight) if m.bias is not None: zeros_(m.bias) def save_params(self): """ Output text file containing a summary of all hyperparameters chosen for the training """ def compute_nb_params(module): return sum([np.prod(p.size()) for p in list(module.parameters())]) def class_to_str_dict(my_dict): for key in my_dict.keys(): if callable(my_dict[key]): my_dict[key] = my_dict[key].__name__ elif isinstance(my_dict[key], np.ndarray): my_dict[key] = my_dict[key].tolist() elif isinstance(my_dict[key], dict): my_dict[key] = class_to_str_dict(my_dict[key]) return my_dict path = os.path.join(self.paths["results"], "params") if os.path.isfile(path): return params = copy.deepcopy(self.params) params = class_to_str_dict(params) total_params = 0 for model_name in self.models.keys(): current_params = compute_nb_params(self.models[model_name]) params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)] total_params += current_params params["model_params"]["total_params"] = "{:,}".format(total_params) params["hardware"] = dict() if self.device != "cpu": for i in range(self.params["training_params"]["nb_gpu"]): params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)) else: params["hardware"]["0"] = "CPU" with open(path, 'w') as f: json.dump(params, f, indent=4) def update_memory_consumption(self): self.max_mem_usage_by_epoch.append(torch.cuda.max_memory_allocated()) torch.cuda.reset_max_memory_allocated() with open(os.path.join(self.paths["results"], 
"memory.txt"), 'a') as f: current = round(self.max_mem_usage_by_epoch[-1]/1e9, 2) max = round(np.max(self.max_mem_usage_by_epoch)/1e9, 2) min = round(np.min(self.max_mem_usage_by_epoch)/1e9, 2) median = round(np.median(self.max_mem_usage_by_epoch)/1e9, 2) mean = round(np.mean(self.max_mem_usage_by_epoch)/1e9, 2) f.write("E{} - Current: {} Go - Max: {} Go - Min: {} Go - Mean: {} Go - Median: {} Go\n".format( self.latest_epoch, current, max, min, mean, median)) @staticmethod def init_metrics(metrics_name): """ Initialization of the metrics specified in metrics_name """ metrics = { "nb_samples": 0, "weights": 0, "names": list(), "ids": list(), } for metric_name in metrics_name: if metric_name == "cer": metrics["nb_chars"] = 0 metrics[metric_name] = list() continue elif metric_name == "wer": metrics["nb_words"] = 0 elif metric_name in ["pred", "proba", "cer_force_len"]: metrics[metric_name] = list() continue elif metric_name == "diff_len": metrics[metric_name] = None continue metrics[metric_name] = 0 return metrics @staticmethod def update_metrics(metrics, batch_metrics): """ Add batch metrics to the metrics """ for key in batch_metrics.keys(): if key in ["diff_len", ]: if metrics[key] is None: metrics[key] = batch_metrics[key] else: metrics[key] = np.concatenate([metrics[key], batch_metrics[key]], axis=0) elif key in ["pred", ]: if len(metrics[key]) == 0: metrics[key] = batch_metrics[key] else: for i in range(len(metrics[key])): metrics[key][i] += batch_metrics[key][i] else: metrics[key] += batch_metrics[key] return metrics def get_display_values(self, metrics, metrics_name, num_batch): """ format metrics values for shell display purposes """ display_values = {} for metric_name in metrics_name: if metric_name in ["cer", "cer_force_len", ]: edit = np.sum(metrics[metric_name]) display_values[metric_name] = round(edit / metrics["nb_chars"], 4) elif metric_name == "wer": display_values[metric_name] = round(metrics[metric_name] / metrics["nb_words"], 4) elif metric_name in ["f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]: display_values[metric_name] = round(metrics[metric_name] / metrics["weights"], 4) elif metric_name in ["diff_len", ]: display_values[metric_name] = np.round(np.mean(np.abs(metrics[metric_name])), 3) elif metric_name in ["time", "pred", "probas", "nb_max_len", "worst_cer", ]: continue elif metric_name in ["loss", "loss_ctc", "loss_ce", "loss_ce_end", "loss_mse"]: display_values[metric_name] = round(metrics[metric_name] / self.latest_batch, 4) else: display_values[metric_name] = round(metrics[metric_name] / metrics["nb_samples"], 4) return display_values def backward_loss(self, loss, retain_graph=False): self.scaler.scale(loss).backward(retain_graph=retain_graph) def step_optimizer(self): self.scaler.step(self.optimizer) self.scaler.update() def train(self): # init tensorboard file and output param summary file if self.is_master: self.writer = SummaryWriter(self.paths["results"]) self.save_params() # init variables self.begin_time = time() focus_metric_name = self.params["training_params"]["focus_metric"] nb_epochs = self.params["training_params"]["max_nb_epochs"] interval_save_weights = self.params["training_params"]["interval_save_weights"] metrics_name = self.params["training_params"]["train_metrics"] display_values = None # perform epochs for num_epoch in range(self.latest_epoch+1, nb_epochs): self.phase = "train" # Check maximum training time stop condition if self.params["training_params"]["max_training_time"] 
and time() - self.begin_time > self.params["training_params"]["max_training_time"]: break # set models trainable for model_name in self.models.keys(): self.models[model_name].train() self.latest_epoch = num_epoch # init epoch metrics values metrics = self.init_metrics(metrics_name) t = tqdm(self.dataset.train_loader) t.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) # iterates over mini-batch data for ind_batch, batch_data in enumerate(t): self.latest_batch = ind_batch + 1 self.total_batch += 1 # train on batch data and compute metrics batch_metrics = self.train_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # Merge metrics if Distributed Data Parallel is used if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # Update learning rate via scheduler if one is used if self.lr_scheduler and ind_batch % self.params["training_params"]["lr_scheduler"]["step_interval"] == 0: self.lr_scheduler.step() # Add batch metrics values to epoch metrics values metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) # log metrics in tensorboard file if self.is_master: for key in display_values.keys(): self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch) self.latest_train_metrics = display_values # evaluate and compute metrics for valid sets if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0: for valid_set_name in self.dataset.valid_loaders.keys(): # evaluate set and compute metrics eval_values = self.evaluate(valid_set_name) self.latest_valid_metrics = eval_values # log valid metrics in tensorboard file if self.is_master: for key in eval_values.keys(): self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch) if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \ (eval_values[focus_metric_name] < self.best and self.params["training_params"]["expected_metric_value"] == "low") or\ (eval_values[focus_metric_name] > self.best and self.params["training_params"]["expected_metric_value"] == "high")): self.save_model(epoch=num_epoch, name="best") self.best = eval_values[focus_metric_name] ## save model weights if self.is_master: self.save_model(epoch=num_epoch, name="last") self.update_memory_consumption() if interval_save_weights and num_epoch % interval_save_weights == 0: self.save_model(epoch=num_epoch, name="weigths", keep_weights=True) self.writer.flush() def evaluate(self, set_name, **kwargs): self.phase = "eval" loader = self.dataset.valid_loaders[set_name] # Set models in eval mode for model_name in self.models.keys(): self.models[model_name].eval() metrics_name = self.params["training_params"]["eval_metrics"] display_values = None # initialize epoch metrics metrics = self.init_metrics(metrics_name) t = tqdm(loader) t.set_description("Evaluation E{}".format(self.latest_epoch)) with torch.no_grad(): # iterate over batch data for ind_batch, batch_data in enumerate(t): self.latest_batch = ind_batch + 1 # eval batch data and compute metrics batch_metrics = self.evaluate_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # merge metrics values if Distributed Data Parallel is used 
if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # add batch metrics to epoch metrics metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) return display_values def predict(self, custom_name, sets_list, metrics_name, output=False): self.phase = "predict" metrics_name = metrics_name.copy() self.dataset.generate_test_loader(custom_name, sets_list) loader = self.dataset.test_loaders[custom_name] # Set models in eval mode for model_name in self.models.keys(): self.models[model_name].eval() pred_time_metric = False if "time" in metrics_name: metrics_name.remove("time") pred_time_metric = True # initialize epoch metrics metrics = self.init_metrics(metrics_name) t = tqdm(loader) t.set_description("Prediction") begin_time = time() with torch.no_grad(): for ind_batch, batch_data in enumerate(t): # iterates over batch data self.latest_batch = ind_batch + 1 # eval batch data and compute metrics batch_metrics = self.evaluate_batch(batch_data, metrics_name) batch_metrics["names"] = batch_data["names"] batch_metrics["ids"] = batch_data["ids"] # merge batch metrics if Distributed Data Parallel is used if self.params["training_params"]["use_ddp"]: batch_metrics = self.merge_ddp_metrics(batch_metrics) # add batch metrics to epoch metrics metrics = self.update_metrics(metrics, batch_metrics) display_values = self.get_display_values(metrics, metrics_name, ind_batch) t.set_postfix(values=str(display_values)) pred_time = time() - begin_time # add time metric values if requested if pred_time_metric: metrics["total_time"] = np.round(pred_time, 3) metrics["sample_time"] = np.round(pred_time / len(self.dataset.test_datasets[custom_name]), 4) # output metrics values if requested if output: for name in ["probas", ]: if name in metrics.keys(): path = os.path.join(self.paths["results"], "{}_{}_{}.txt".format(name, custom_name, self.latest_epoch)) info = "\n".join(metrics[name]) with open(path, "w") as f: f.write(info) del metrics[name] self.output(metrics, custom_name) def launch_ddp(self): """ Initialize Distributed Data Parallel system """ mp.set_start_method('fork', force=True) os.environ['MASTER_ADDR'] = self.ddp_config["address"] os.environ['MASTER_PORT'] = str(self.ddp_config["port"]) dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"]) torch.cuda.set_device(self.ddp_config["rank"]) random.seed(self.manual_seed) np.random.seed(self.manual_seed) torch.manual_seed(self.manual_seed) torch.cuda.manual_seed(self.manual_seed) def merge_ddp_metrics(self, metrics): """ Merge metrics when Distributed Data Parallel is used """ for metric_name in metrics.keys(): if metric_name in ["wer", "wer_force_len", "nb_samples", "nb_words", "nb_chars", "nb_max_len", "f_measure", "precision", "recall", "IoU", "mAP", "pp_f_measure", "pp_precision", "pp_recall", "pp_IoU", "pp_mAP"]: metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name]) elif metric_name in ["loss", "loss_ce", "loss_ctc", "loss_ce_end"]: metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=True) elif metric_name in ["diff_len", "cer", "cer_force_len", "ids"]: metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name]) return metrics def sum_ddp_metric(self, metric, average=False): """ Sum metrics for Distributed Data Parallel """ sum = torch.tensor(metric).to(self.device) 
        dist.all_reduce(sum, op=dist.ReduceOp.SUM)
        if average:
            # true_divide is not in-place; without reassignment the average
            # is silently discarded and the raw sum is returned
            sum = sum / dist.get_world_size()
        return sum.item()

    def cat_ddp_metric(self, metric):
        """
        Concatenate metrics for Distributed Data Parallel
        """
        tensor = torch.tensor(metric).unsqueeze(0).to(self.device)
        res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())]
        dist.all_gather(res, tensor)
        return list(torch.cat(res, dim=0).flatten().cpu().numpy())

    @staticmethod
    def cleanup():
        dist.destroy_process_group()

    def train_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def evaluate_batch(self, batch_data, metric_names):
        raise NotImplementedError

    def output_pred(self, pred, set_name):
        raise NotImplementedError

    def add_checkpoint_info(self, load_mode="last", **kwargs):
        for filename in os.listdir(self.paths["checkpoints"]):
            if load_mode in filename:
                checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
                checkpoint = torch.load(checkpoint_path)
                for key in kwargs.keys():
                    checkpoint[key] = kwargs[key]
                torch.save(checkpoint, checkpoint_path)
                return
        self.save_model(self.latest_epoch, "last")

    def output(self, metrics, set_name):
        """
        Output metrics in text file
        """
        path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(set_name, self.latest_epoch))
        with open(path, "w") as f:
            for metric_name in metrics.keys():
                if metric_name in ["cer", "cer_force_len"]:
                    edit = np.sum(metrics[metric_name])
                    value = round(edit / metrics["nb_chars"], 4)
                elif metric_name in ["wer", ]:
                    value = round(metrics[metric_name] / metrics["nb_words"], 4)
                elif metric_name in ["loss_ce", ]:
                    value = round(metrics[metric_name] / metrics["nb_samples"], 4)
                elif metric_name in ["total_time", "sample_time", "total_output_time", "sample_output_time"]:
                    value = metrics[metric_name]
                elif metric_name in ["nb_samples", "nb_words", "nb_chars", "nb_max_len"]:
                    value = metrics[metric_name]
                elif metric_name in ["diff_len", ]:
                    f.write("{}: {}\n".format(metric_name, sorted(list(metrics[metric_name]))))
                    f.write("{}-mean_abs: {}\n".format(metric_name, np.mean(np.abs(metrics[metric_name]))))
                    continue
                elif metric_name in ["worst_cer", ]:
                    m = metric_name.split("_")[-1]
                    value = [[c, id] for c, id in zip(metrics[m], metrics["ids"])]
                    value = sorted(value, key=lambda x: x[0], reverse=True)
                    value = value[:50]
                else:
                    continue
                f.write("{}: {}\n".format(metric_name, value))

    def load_save_info(self, info_dict):
        """
        Load curriculum info from saved model info
        """
        if "curriculum_config" in info_dict.keys():
            self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]

    def add_save_info(self, info_dict):
        """
        Add curriculum info to model info to be saved
        """
        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
        return info_dict
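# A hypothetical concrete subclass, sketching how the abstract train_batch
# hook above composes with the backward_loss/step_optimizer helpers and the
# GradScaler(enabled=use_amp) created in load_model. The "imgs"/"labels"
# batch keys and self.loss_fn are illustrative placeholders, not part of the
# original class.
from torch.cuda.amp import autocast

class MyTrainingManager(GenericTrainingManager):
    def train_batch(self, batch_data, metric_names):
        x = batch_data["imgs"].to(self.device)
        y = batch_data["labels"].to(self.device)
        self.optimizer.zero_grad()
        # autocast mirrors the scaler: both are no-ops when use_amp is False
        with autocast(enabled=self.params["training_params"]["use_amp"]):
            pred = self.models["model"](x)
            loss = self.loss_fn(pred, y)  # assumed loss function
        self.backward_loss(loss)   # scaler.scale(loss).backward()
        self.step_optimizer()      # scaler.step(optimizer); scaler.update()
        return {"nb_samples": x.size(0), "loss": loss.item()}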
def training(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #===================================# #==============Logging==============# #===================================# logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) handler = TqdmLoggingHandler() handler.setFormatter( logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S")) logger.addHandler(handler) logger.propagate = False #===================================# #============Data Load==============# #===================================# # 1) Data open write_log(logger, "Load data...") gc.disable() with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f: data_ = pickle.load(f) train_src_indices = data_['train_src_indices'] valid_src_indices = data_['valid_src_indices'] train_trg_indices = data_['train_trg_indices'] valid_trg_indices = data_['valid_trg_indices'] src_word2id = data_['src_word2id'] trg_word2id = data_['trg_word2id'] src_vocab_num = len(src_word2id) trg_vocab_num = len(trg_word2id) del data_ gc.enable() write_log(logger, "Finished loading data!") # 2) Dataloader setting dataset_dict = { 'train': CustomDataset(train_src_indices, train_trg_indices, min_len=args.min_len, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len), 'valid': CustomDataset(valid_src_indices, valid_trg_indices, min_len=args.min_len, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len), } dataloader_dict = { 'train': DataLoader(dataset_dict['train'], drop_last=True, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers), 'valid': DataLoader(dataset_dict['valid'], drop_last=False, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers) } write_log( logger, f"Total number of trainingsets iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}" ) #===================================# #===========Train setting===========# #===================================# # 1) Model initiating write_log(logger, 'Instantiating model...') model = Transformer( src_vocab_num=src_vocab_num, trg_vocab_num=trg_vocab_num, pad_idx=args.pad_id, bos_idx=args.bos_id, eos_idx=args.eos_id, d_model=args.d_model, d_embedding=args.d_embedding, n_head=args.n_head, dim_feedforward=args.dim_feedforward, num_common_layer=args.num_common_layer, num_encoder_layer=args.num_encoder_layer, num_decoder_layer=args.num_decoder_layer, src_max_len=args.src_max_len, trg_max_len=args.trg_max_len, dropout=args.dropout, embedding_dropout=args.embedding_dropout, trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing, emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing, parallel=args.parallel) model.train() model = model.to(device) tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1, device) # 2) Optimizer & Learning rate scheduler setting optimizer = optimizer_select(model, args) scheduler = shceduler_select(optimizer, dataloader_dict, args) scaler = GradScaler() # 3) Model resume start_epoch = 0 if args.resume: write_log(logger, 'Resume model...') checkpoint = torch.load( os.path.join(args.save_path, 'checkpoint.pth.tar')) start_epoch = checkpoint['epoch'] + 1 model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) scheduler.load_state_dict(checkpoint['scheduler']) scaler.load_state_dict(checkpoint['scaler']) del checkpoint #===================================# #=========Model Train Start=========# #===================================# best_val_acc = 0 
write_log(logger, 'Training start!')
    best_epoch = 0  # initialized so the logging branch below cannot raise NameError
    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                freq = 0  # print-frequency counter; must exist before the first batch
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':

                    # Loss calculation
                    with autocast():
                        predicted = model(src, trg[:, :-1], tgt_mask, non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted, trg_sequences_target, args.pad_id)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value only while training
                    if i == 0 or freq == args.print_freq or i == len(dataloader_dict['train']) - 1:
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']), loss.item(), acc * 100,
                             optimizer.param_groups[0]['lr'], (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src, trg[:, :-1], tgt_mask, non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += (predicted.max(dim=1)[1] == trg_sequences_target).sum() / len(trg_sequences_target)
                    if args.scheduler == 'reduce_valid':
                        scheduler.step(val_loss)
                    if args.scheduler == 'lambda':
                        scheduler.step()

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger, 'Validation Accuracy: %3.2f%%' % (val_acc * 100))

                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    # Save under args.save_path so the resume code above can find it
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, os.path.join(args.save_path, 'checkpoint.pth.tar'))
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Still epoch {best_epoch} accuracy ({round(float(best_val_acc) * 100, 2)}%) is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(float(best_val_acc), 2)}')
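# The per-update AMP sequence used in the loop above, isolated for reference
# (loss, optimizer, model, and max_norm are placeholders):
scaler.scale(loss).backward()                   # backward on the scaled loss
scaler.unscale_(optimizer)                      # bring gradients back to true scale
clip_grad_norm_(model.parameters(), max_norm)   # clip the *unscaled* gradients
scaler.step(optimizer)                          # skipped internally if grads have inf/NaN
scaler.update()                                 # grow/shrink the scale factor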
class Fp16OptimizerHook(OptimizerHook):
    """FP16 optimizer hook (using PyTorch's implementation).

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
    to take care of the optimization procedure.

    Args:
        loss_scale (float | str | dict): Scale factor configuration.
            If loss_scale is a float, static loss scaling will be used with
            the specified scale. If loss_scale is a string, it must be
            'dynamic', then dynamic loss scaling will be used.
            It can also be a dict containing arguments of GradScaler.
            Defaults to 512. For PyTorch >= 1.6, mmcv uses the official
            implementation of GradScaler. If you use a dict version of
            loss_scale to create GradScaler, please refer to:
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
            for the parameters.

    Examples:
        >>> loss_scale = dict(
        ...     init_scale=65536.0,
        ...     growth_factor=2.0,
        ...     backoff_factor=0.5,
        ...     growth_interval=2000
        ... )
        >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
    """

    def __init__(self,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 loss_scale=512.,
                 distributed=True):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.distributed = distributed
        self._scale_update_param = None
        if loss_scale == 'dynamic':
            self.loss_scaler = GradScaler()
        elif isinstance(loss_scale, float):
            self._scale_update_param = loss_scale
            self.loss_scaler = GradScaler(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            self.loss_scaler = GradScaler(**loss_scale)
        else:
            raise ValueError('loss_scale must be of type float, dict, or '
                             f'"dynamic", got {loss_scale}')

    def before_run(self, runner):
        """Preparing steps before Mixed Precision Training."""
        # wrap model mode to fp16
        wrap_fp16_model(runner.model)
        # resume from state dict
        if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
            scaler_state_dict = runner.meta['fp16']['loss_scaler']
            self.loss_scaler.load_state_dict(scaler_state_dict)

    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
        """Copy gradients from fp16 model to fp32 weight copy."""
        for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
            if fp16_param.grad is not None:
                if fp32_param.grad is None:
                    fp32_param.grad = fp32_param.data.new(fp32_param.size())
                fp32_param.grad.copy_(fp16_param.grad)

    def copy_params_to_fp16(self, fp16_net, fp32_weights):
        """Copy updated params from fp32 weight copy to fp16 model."""
        for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
            fp16_param.data.copy_(fp32_param.data)

    def after_train_iter(self, runner):
        """Backward optimization steps for Mixed Precision Training.

        For dynamic loss scaling, please refer to
        https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.

        1. Scale the loss by a scale factor.
        2. Backward the loss to obtain the gradients.
        3. Unscale the optimizer's gradient tensors.
        4. Call optimizer.step() and update the scale factor.
        5. Save the loss_scaler state_dict for resume purposes.
""" # clear grads of last iteration runner.model.zero_grad() runner.optimizer.zero_grad() self.loss_scaler.scale(runner.outputs['loss']).backward() self.loss_scaler.unscale_(runner.optimizer) # grad clip if self.grad_clip is not None: grad_norm = self.clip_grads(runner.model.parameters()) if grad_norm is not None: # Add grad norm to the logger runner.log_buffer.update({'grad_norm': float(grad_norm)}, runner.outputs['num_samples']) # backward and update scaler self.loss_scaler.step(runner.optimizer) self.loss_scaler.update(self._scale_update_param) # save state_dict of loss_scaler runner.meta.setdefault( 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
class Trainer: def __init__(self, config: DictConfig, model: FlyModel, name: str = "task1", *args, **kwargs): """ Args: config: FlyConfig dictionary model: must be FlyModel dataloader_fn: a Callable function which returns dataloaders """ logger.info("TrainerLoop is initializing!") if not isinstance(model, FlyModel): logger.warn("model is not defined as FlyModel") self.config = config self.model = model self.name = name # class properties self.rank = None self.local_rank = None self.node_rank = None self.world_size = None self.distributed_training = None self.device = None self.fp16 = config.fp16 self.gradient_accumulation_batches = config.gradient_accumulation_batches self.callback_handler = None self.optimizers = [] self.schedulers = [] self.init_distributed_environment() # Model is sent to GPU or CPU self.init_device() # self.optimizers, self.schedulers = self.configure_optimizers() self.model = move_to_device(self.model, self.device) self.model.device = self.device self.init_fp16() if self.distributed_training: self.init_distributed_model(self.model) # make sure the model has access to trainer info self.model.set_trainer(self) self.callback_handler = CallbackHandler(config, trainer=self, callbacks=[], verbose=config.logging.level == "DEBUG") # Configure all callbacks self.configure_callbacks() self.callback_handler.fire_event(Events.INITIALIZE) def init_distributed_environment(self): # For distributed self.rank = int(os.environ.get("RANK", 0)) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) self.distributed_training = (self.world_size > 1) # TODO: add error message when num_gpus is set, but distributed training is False here if self.distributed_training and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend='nccl', init_method='env://') assert torch.distributed.is_initialized() if self.distributed_training and not torch.distributed.is_initialized(): self.node_rank = os.environ.get("NODE_RANK", "N/A") logger.info( f"Initialized Rank:{torch.distributed.get_rank()} Locak-rank: {self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}" ) def init_device(self): # set cuda device if self.config.num_gpus_per_node > 0: torch.cuda.set_device(self.local_rank) self.device = torch.device("cuda", self.local_rank) else: self.device = torch.device("cpu") def init_fp16(self): if self.config.num_gpus_per_node == 0: raise NotImplementedError("For mixed precision training, you need to use GPU!") self.loss_scaler = GradScaler() def init_training_constants(self): self.total_num_update_steps = int(self.config.total_num.update_steps) self.total_num_batches = self.total_num_update_steps * int(self.gradient_accumulation_batches) self.total_num_epochs = int(self.config.total_num.epochs) # check if training in epoch or update_steps if self.total_num_update_steps < 0 and self.total_num_epochs < 0: raise NotImplementedError("config.total_num.updated_steps must be larger than 0") elif self.total_num_update_steps > 0 and self.total_num_epochs > 0: raise NotImplementedError( "Please only set either config.total_num.updated_steps or config.total_num.epochs greater than 0") elif self.total_num_update_steps > 0 and self.total_num_epochs < 0: self.training_in_epoch = False elif self.total_num_update_steps < 0 and self.total_num_epochs > 0: self.training_in_epoch = True # get the number of batches in the dataloader for one epoch try: self.epoch_num_batches = len(self.train_dataloader) except TypeError: 
logger.warning("Cannot determine the length of train_dataloader!") self.epoch_num_batches = None if self.training_in_epoch: if self.epoch_num_batches is not None: self.total_num_batches = self.epoch_num_batches * self.total_num_epochs self.total_num_update_steps = self.total_num_batches // self.gradient_accumulation_batches self.epoch_num_update_steps = self.epoch_num_batches // self.gradient_accumulation_batches else: # this is set to wait until the epoch finishes first self.total_num_update_steps = sys.maxsize def configure_optimizers(self, total_num_update_steps=None, optimizers=None, schedulers=None): if optimizers is not None and schedulers is not None: self.optimizers, self.schedulers = optimizers, schedulers elif total_num_update_steps is not None: self.optimizers, self.schedulers = self.model.configure_optimizers(total_num_update_steps) else: raise ValueError("Please provide the correct argument!") return self.optimizers, self.schedulers def configure_callbacks(self): # Resume callback runs for all ranks if self.config.resume.enabled: self.resume_callback = Resume(self.config) self.add_callback(self.resume_callback) self.log_callback = TrainLogger(self.config) self.add_callback(self.log_callback) self.eval_callback = Evaluation(self.config) self.add_callback(self.eval_callback) # For logging and inference, use rank 0 by default if self.rank == 0: if self.config.console: self.console_callback = Console(self.config) self.add_callback(self.console_callback) if self.config.checkpointing.enabled: self.checkpoint_callback = Checkpoint(self.config) self.add_callback(self.checkpoint_callback) def init_distributed_model(self, model): """ Default distributed training uses reducer for simplicity. """ # Distributed training (should be after apex fp16 initialization) self.reducer = Reducer(model) # for param in self.model.parameters(): # dist.broadcast(param.data, 0) def train(self, train_dataloader, validation_dataloader=None, test_dataloader=None, configure_optimizers=True, name=None, *args, **kwargs): self.total_num_update_steps = 0 self.total_num_batches = 0 self.total_num_epochs = 0 self.epoch_num_batches = 0 self.global_batch_count = 0 self.global_step_count = 0 self.epochs_trained = 0 self.local_step_count = 0 self.train_dataloader = train_dataloader self.validation_dataloader = validation_dataloader self.test_dataloader = test_dataloader self.init_training_constants() if configure_optimizers or len(self.optimizers) == 0: self.configure_optimizers(self.total_num_update_steps) if name is not None: self.name = name # Training begins self.callback_handler.fire_event(Events.TRAIN_BEGIN) while True: self.callback_handler.fire_event(Events.EPOCH_BEGIN) self.train_epoch() self.callback_handler.fire_event(Events.EPOCH_END) self.epochs_trained += 1 if self.training_in_epoch: if self.epochs_trained >= self.total_num_epochs: break else: if self.global_step_count < self.total_num_update_steps: continue else: break # Training ends self.callback_handler.fire_event(Events.TRAIN_END) def train_epoch(self): self.optimizer = self.optimizers[0] self.scheduler = self.schedulers[0] self.local_step_count = 0 if self.train_dataloader is None: return for batch in self.train_dataloader: self.callback_handler.fire_event(Events.BATCH_BEGIN) batch = move_to_device(batch, self.device) output = self.backward_batch(batch) # Update the model if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0: # Update the model with optimizer self.step_update(self.model, self.optimizer, self.scheduler) 
self.global_step_count += 1 self.local_step_count += 1 self.callback_handler.fire_event(Events.BATCH_END) if self.global_step_count >= self.total_num_update_steps: break self.global_batch_count += 1 def backward_batch(self, batch): self.model.train() with torch.cuda.amp.autocast(self.fp16): output = self.model(batch) # get the loss from output if hasattr(output, "loss"): loss = output.loss elif isinstance(output, dict): loss = output["loss"] if self.gradient_accumulation_batches > 1: loss = loss / self.gradient_accumulation_batches self.loss_backward(loss) return output def step_update(self, model, optimizer, scheduler=None): """ self.loss_scaler is defined in `configure_fp16` """ self.callback_handler.fire_event(Events.STEP_BEGIN) # collect gradient if self.distributed_training: self.reducer.reduce() gradient_clip = self.config.optimization.max_gradient_norm # Gradient Clipping if gradient_clip > 0: if self.fp16: self.loss_scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) else: torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) # Update the model if self.fp16: self.loss_scaler.step(optimizer) self.loss_scaler.update() else: optimizer.step() # Step learning rate if scheduler: scheduler.step() # Gradient to zero optimizer.zero_grad() self.callback_handler.fire_event(Events.STEP_END) def loss_backward(self, loss): self.callback_handler.fire_event(Events.BACKWARD_BEGIN) # Loss backward if self.fp16: self.loss_scaler.scale(loss).backward() else: loss.backward() self.callback_handler.fire_event(Events.BACKWARD_END) def validate(self, dataloader): # Start Validation self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.VALIDATE_BEGIN) self.model.validation_loop(dataloader) self.callback_handler.fire_event(Events.VALIDATE_END) def test(self, dataloader): # Start Testing self.model.reset_evaluation_metrics() self.callback_handler.fire_event(Events.TEST_BEGIN) self.model.test_loop(dataloader) self.callback_handler.fire_event(Events.TEST_END) def set_model_state(self, model_state_dict): self.model.load_state_dict(model_state_dict) def get_model_state(self): return self.model.state_dict() def set_trainer_state(self, trainer_state_dict): self.epochs_trained = trainer_state_dict["epochs_trained"] self.global_step_count = trainer_state_dict["global_step_count"] self.local_step_count = trainer_state_dict["local_step_count"] # Resume the training state if self.config.resume.resume: # Scheduler States if self.config.resume.resume_scheduler: for idx, scheduler in enumerate(self.schedulers): try: scheduler.load_state_dict(trainer_state_dict["schedulers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Scheduler {idx}'s State!") if self.config.resume.resume_optimizer: for idx, optimizer in enumerate(self.optimizers): try: optimizer.load_state_dict(trainer_state_dict["optimizers_state_dict"][idx]) except: if self.rank == 0: logger.warning(f"Cannot Load Optimizer {idx}'s State!") # save amp states if self.fp16: self.loss_scaler.load_state_dict(trainer_state_dict["amp_state_dict"]) # Random States if self.config.resume.resume_rng_state: torch.set_rng_state(trainer_state_dict["cpu_rng_state"]) trainer_state_dict["cuda_rng_state"] = trainer_state_dict["cuda_rng_state"][:torch.cuda.device_count()] torch.cuda.set_rng_state_all(trainer_state_dict["cuda_rng_state"]) # All Callbacks for callback in self.callback_handler.callbacks: try: callback.load_state_dict(trainer_state_dict[str(type(callback))]) except: 
logger.error(f"{type(callback)} seems not to exist in the checkpoint state!") def get_trainer_state(self): trainer_state_dict = { "epochs_trained": self.epochs_trained, "global_step_count": self.global_step_count, "local_step_count": self.local_step_count, "optimizers_state_dict": [optimizer.state_dict() for optimizer in self.optimizers], "schedulers_state_dict": [scheduler.state_dict() for scheduler in self.schedulers], "cpu_rng_state": torch.get_rng_state(), "cuda_rng_state": torch.cuda.get_rng_state_all(), } # save amp states if self.fp16: trainer_state_dict["amp_state_dict"] = self.loss_scaler.state_dict() # All Callbacks for callback in self.callback_handler.callbacks: trainer_state_dict[str(type(callback))] = callback.state_dict() return trainer_state_dict def add_callback(self, callback: Callback): self.callback_handler.add_callback(callback) # def get_lr(optimizer): # for param_group in optimizer.param_groups: # return param_group['lr'] # def get_log_variable(x): # if isinstance(x, torch.Tensor): # x = x.detach() # return x.item() # else: # raise NotImplementedError
def main(args): # ensures that weight initializations are all the same torch.manual_seed(args.seed) np.random.seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) logging = utils.Logger(args.global_rank, args.save) writer = utils.Writer(args.global_rank, args.save) # Get data loaders. train_queue, valid_queue, num_classes, _ = datasets.get_loaders(args) args.num_total_iter = len(train_queue) * args.epochs warmup_iters = len(train_queue) * args.warmup_epochs swa_start = len(train_queue) * (args.epochs - 1) arch_instance = utils.get_arch_cells(args.arch_instance) model = AutoEncoder(args, writer, arch_instance) model = model.cuda() logging.info('args = %s', args) logging.info('param size = %fM ', utils.count_parameters_in_M(model)) logging.info('groups per scale: %s, total_groups: %d', model.groups_per_scale, sum(model.groups_per_scale)) if args.fast_adamax: # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster. cnn_optimizer = Adamax(model.parameters(), args.learning_rate, weight_decay=args.weight_decay, eps=1e-3) else: cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate, weight_decay=args.weight_decay, eps=1e-3) cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min) grad_scalar = GradScaler(2**10) num_output = utils.num_output(args.dataset, args) bpd_coeff = 1. / np.log(2.) / num_output # if load checkpoint_file = os.path.join(args.save, 'checkpoint.pt') if args.cont_training: logging.info('loading the model.') checkpoint = torch.load(checkpoint_file, map_location='cpu') init_epoch = checkpoint['epoch'] model.load_state_dict(checkpoint['state_dict']) model = model.cuda() cnn_optimizer.load_state_dict(checkpoint['optimizer']) grad_scalar.load_state_dict(checkpoint['grad_scalar']) cnn_scheduler.load_state_dict(checkpoint['scheduler']) global_step = checkpoint['global_step'] else: global_step, init_epoch = 0, 0 for epoch in range(init_epoch, args.epochs): # update lrs. if args.distributed: train_queue.sampler.set_epoch(global_step + args.seed) valid_queue.sampler.set_epoch(0) if epoch > args.warmup_epochs: cnn_scheduler.step() # Logging. logging.info('epoch %d', epoch) # Training. 
train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar, global_step, warmup_iters, writer, logging) logging.info('train_nelbo %f', train_nelbo) writer.add_scalar('train/nelbo', train_nelbo, global_step) model.eval() # generate samples less frequently eval_freq = 1 if args.epochs <= 50 else 20 if epoch % eval_freq == 0 or epoch == (args.epochs - 1): with torch.no_grad(): num_samples = 16 n = int(np.floor(np.sqrt(num_samples))) for t in [0.7, 0.8, 0.9, 1.0]: logits = model.sample(num_samples, t) output = model.decoder_output(logits) output_img = output.mean if isinstance( output, torch.distributions.bernoulli.Bernoulli ) else output.sample(t) output_tiled = utils.tile_image(output_img, n) writer.add_image('generated_%0.1f' % t, output_tiled, global_step) valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging) logging.info('valid_nelbo %f', valid_nelbo) logging.info('valid neg log p %f', valid_neg_log_p) logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff) logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff) writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch) writer.add_scalar('val/nelbo', valid_nelbo, epoch) writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch) writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch) save_freq = int(np.ceil(args.epochs / 100)) if epoch % save_freq == 0 or epoch == (args.epochs - 1): if args.global_rank == 0: logging.info('saving the model.') torch.save( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': cnn_optimizer.state_dict(), 'global_step': global_step, 'args': args, 'arch_instance': arch_instance, 'scheduler': cnn_scheduler.state_dict(), 'grad_scalar': grad_scalar.state_dict() }, checkpoint_file) # Final validation valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging) logging.info('final valid nelbo %f', valid_nelbo) logging.info('final valid neg log p %f', valid_neg_log_p) writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1) writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1) writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1) writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1) writer.close()
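# The resume path in main() above, as a standalone sketch. The point is that
# GradScaler is stateful: dropping its state on resume ("grad_scalar" is this
# snippet's key for it) resets the loss scale to init_scale, which can cause a
# burst of skipped steps or inf gradients right after resuming. `model`,
# `optimizer`, and `scheduler` here are placeholders for the objects above.
import torch
from torch.cuda.amp import GradScaler

def save_ckpt(path, epoch, model, optimizer, scheduler, scaler, global_step):
    torch.save({'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'grad_scalar': scaler.state_dict(),
                'global_step': global_step}, path)

def load_ckpt(path, model, optimizer, scheduler, scaler):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    scaler.load_state_dict(ckpt['grad_scalar'])
    return ckpt['epoch'], ckpt['global_step']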
def main_worker(rank, size, args_in): global args args = args_in is_root = rank == 0 dist.init_process_group(backend='nccl', init_method=f"tcp://localhost:{args.port}", world_size=size, rank=rank) """ Config writer, seed and device """ CheckpointFunction.use_amp = args.amp writer = config_summary_writer(is_root=is_root, output_dir=args.output_dir) seed = args.seed + rank seed_all(seed) device = config_device() torch.backends.cudnn.benchmark = True """ Load Dataloaders """ if args.ckpt is not None: gpt_ckpt = torch.load(args.ckpt, map_location=device) if is_root: print(f"Loading GPT from checkpoint {args.ckpt} with loss {gpt_ckpt['best_loss']}") dset_configs = gpt_ckpt['dset_configs'] # overwrite args.dataset = dset_configs['dataset'] args.resolution = dset_configs['resolution'] else: gpt_ckpt = None dset_configs = dict(dataset=args.dataset, resolution=args.resolution, n_frames=args.n_frames) train_loader, test_loader, dset = get_distributed_loaders( dset_configs=dset_configs, batch_size=args.batch_size, seed=seed ) if is_root: print(f"dset loader n_batch: train = {len(train_loader)}, test = {len(test_loader)}") """ Load VQ-VAE """ vqvae_ckpt = args.vqvae_ckpt if gpt_ckpt is None else gpt_ckpt['vqvae_ckpt'] if is_root: print(f'Loading VQ-VAE from {vqvae_ckpt}') vqvae_ckpt_loaded = torch.load(vqvae_ckpt, map_location=device) vqvae, vq_hp = load_model( ckpt=vqvae_ckpt_loaded, device=device, freeze_model=True, cond_types=tuple() ) del vqvae_ckpt_loaded latent_shape = vqvae.latent_shape quantized_shape = vqvae.quantized_shape if is_root: print('latent shape', latent_shape) print('quantized shape', quantized_shape) print('total latents', np.prod(latent_shape)) """ Config cond_types""" if gpt_ckpt is not None: cond_hp = gpt_ckpt['cond_hp'] else: cond_hp = dict( n_cond_frames=args.n_cond_frames, class_cond=args.class_cond, cond_init_configs=dict( type='enc_attn', model='resnet_v1', resnet_dim=576, resnet_depth=34, resnet_output_shape=(1, 16, 16), width_multiplier=1, ), ) def load_prior(layer_ckpt): """ Check consistency """ layer_cond_types, _ = config_cond_types( cond_hp=layer_ckpt['cond_hp'], dset=dset) # freeze all previous priors, not the current one layer_prior, layer_hp = load_model( ckpt=layer_ckpt, device=device, freeze_model=False, cond_types=layer_cond_types) layer_codebook = vqvae.codebook return layer_prior, layer_hp, layer_codebook def inputs_fn(batch): with torch.no_grad(): videos = batch['video'].to(device, non_blocking=True) # (b, c, t, h, w) cond = [] if cond_hp['n_cond_frames'] > 0: cond_frames = videos[:, :, :cond_hp['n_cond_frames']] cond.append(cond_frames) if cond_hp['class_cond']: cond.append(batch['label'].to(device, non_blocking=True)) quantized, encodings = vqvae.encode(x=videos, no_flatten=True) # latent_shape = (t, h, w, l) quantized = shift_dim(quantized, 1, -1) # (b, d, t, h, w, l) -> (b, t, h, w, l, d) # channel first -> last encodings = encodings.long() cond = tuple(cond) return dict(encodings=encodings, quantized=quantized, cond=cond, decode_step=None, decode_idx=None) cond_types, cond_hp = config_cond_types( cond_hp=cond_hp, dset=dset ) if is_root: print('cond_types', [(c.name, c.type, c.out_size) for c in cond_types]) """ Load GPT snapshot, if any """ if gpt_ckpt is not None: prior, hp, codebook = load_prior(layer_ckpt=gpt_ckpt) best_loss = gpt_ckpt['best_loss'] optimizer = optim.Adam(prior.parameters(), lr=args.lr) optimizer.load_state_dict(gpt_ckpt['optimizer']) scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.total_iters) 
scheduler.load_state_dict(gpt_ckpt['scheduler']) scaler = GradScaler() scaler.load_state_dict(gpt_ckpt['scaler']) epoch_start = gpt_ckpt['epoch'] iteration_start = gpt_ckpt['iteration'] + 1 del gpt_ckpt else: # TODO: use (self_gen_n_embd*num_self_gen_in_use,) i.e. concat, or use below i.e. sum up y_gen? prior, hp = config_model( configs_str=args.cfg, shape=latent_shape, in_features=vq_hp['embedding_dim'], n_vocab=vq_hp['codes_per_book'], cond_types=cond_types, ) prior = prior.to(device) codebook = vqvae.codebook optimizer = optim.Adam(prior.parameters(), lr=args.lr) scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.total_iters) scaler = GradScaler() best_loss = float('inf') epoch_start = 0 iteration_start = 1 # find_unused_parameters needs to be False for gradient checkpointing to work prior = DistributedDataParallel(prior, device_ids=[rank], find_unused_parameters=False, broadcast_buffers=False) if is_root: for cond_net in prior.cond_nets: print('cond_net size with grad', sum(p.numel() for p in cond_net.parameters() if p.requires_grad)) print('cond_net size', sum(p.numel() for p in cond_net.parameters())) if is_root: if args.amp: print('Training with AMP') # to be saved to model checkpoints default_ckpt_dict = { 'dset_configs': dset_configs, 'cond_hp': cond_hp, 'hp': hp, 'vqvae_ckpt': vqvae_ckpt, } def get_ckpt_dict(**ckpt_dict): return {**ckpt_dict, **default_ckpt_dict} if is_root: total_parameters = sum([np.prod(p.shape) for p in prior.parameters() if p.requires_grad]) print('model size: prior params count with grads = {}'.format(total_parameters)) train_loader = InfDataLoader(train_loader, epoch_start) # training and validation, all in latent space train_for = functools.partial( train, train_loader=train_loader, inputs_fn=inputs_fn, prior=prior, optimizer=optimizer, scheduler=scheduler, scaler=scaler, writer=writer, is_root=is_root, size=size, device=device, ) validate_for = functools.partial( validate, test_loader=test_loader, inputs_fn=inputs_fn, prior=prior, writer=writer, is_root=is_root, size=size, device=device, ) # end to end sampling in pixel space sample_fn = functools.partial( sample, cond_hp=cond_hp, vae=vqvae, prior=prior, codebook=codebook, device=device, temperature=args.temperature, rank=rank, size=size, ) # takes in n_samples, batch, returns samples of size min(n_samples, batch_size * size (roughly, not verified)) # tensor (n, c, t, h, w) in [0, 1] save_samples_for = functools.partial( save_samples, sample_fn=sample_fn, loader=test_loader, writer=writer, is_root=is_root, size=size, ) iteration = iteration_start log_mem_usage, log_time_usage = True, True time_start = time.time() while iteration <= args.total_iters: train_loss, iteration = train_for(iteration=iteration) # average gen_loss if iteration % args.test_every == 0: test_loss = validate_for(iteration=iteration) if is_root: writer.add_scalar('test/gen_loss_gap', test_loss - train_loss, iteration * args.batch_size) is_best = test_loss < best_loss best_loss = min(test_loss, best_loss) ckpt_dict = get_ckpt_dict( epoch=train_loader.epoch, iteration=iteration, n_obs=iteration * args.batch_size, state_dict=prior.module.state_dict(), optimizer=optimizer.state_dict(), scheduler=scheduler.state_dict(), scaler=scaler.state_dict(), best_loss=best_loss, ) save_checkpoint(ckpt_dict, is_best=is_best, is_root=is_root, output_dir=args.output_dir) if iteration % args.generate_every == 0 and save_samples_for: save_samples_for(iteration=iteration) iteration += 1 if is_root: print(f'Final iteration: {iteration}, best loss: 
{best_loss}') print(f'Logs saved under {args.output_dir}') writer.close()
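# The DDP + AMP flow used by main_worker above, as a sketch (placeholder loss;
# `prior` is assumed to already be DistributedDataParallel-wrapped, as in the
# snippet). Each process keeps its own GradScaler; in practice the scalers
# stay consistent across ranks because DDP all-reduces gradients during
# backward, so every rank observes the same inf/nan status. Only rank 0
# should write the scaler state to the checkpoint, as done above.
import torch
from torch.cuda.amp import GradScaler, autocast

def train_step(prior, optimizer, scaler, batch):
    optimizer.zero_grad()
    with autocast():
        loss = prior(batch).mean()              # placeholder loss
    scaler.scale(loss).backward()               # DDP all-reduces scaled grads
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()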
min_iter = 0  # default when no saved iteration count exists (avoids a NameError below)

try:
    G.load_state_dict(torch.load('./saved_models/AEI_G_latest.pth',
                                 map_location=torch.device('cpu')), strict=False)
    D.load_state_dict(torch.load('./saved_models/AEI_D_latest.pth',
                                 map_location=torch.device('cpu')), strict=False)
    opt_G.load_state_dict(
        torch.load('./saved_models/AEI_optG_latest.pth', map_location=torch.device('cpu')))
    opt_D.load_state_dict(
        torch.load('./saved_models/AEI_optD_latest.pth', map_location=torch.device('cpu')))
    scaler.load_state_dict(
        torch.load('./saved_models/AEI_scaler_latest.pth', map_location=torch.device('cpu')))
except Exception as e:
    print(e)

try:
    with open('./saved_models/AEI_niter.pkl', 'rb') as f:
        min_iter = pickle.load(f)
except Exception as e:
    print(e)

writer = SummaryWriter('runs/FaceShifterAEInet', purge_step=min_iter)

TrainFaceSources = [
    '/home/olivier/Images/FaceShifter/celeba-256/',
    '/home/olivier/Images/FaceShifter/Perso/',
    '/home/olivier/Images/FaceShifter/VGGFaceTrain/',
    '/home/olivier/Images/FaceShifter/FFHQ/',
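# Counterpart save for the per-file checkpoints loaded above (a sketch; this
# snippet does not show the actual training loop). Each stateful object,
# including the GradScaler, gets its own file so any of them can be reloaded
# independently with the try/except blocks above.
import pickle
import torch

def save_latest(G, D, opt_G, opt_D, scaler, n_iter):
    torch.save(G.state_dict(), './saved_models/AEI_G_latest.pth')
    torch.save(D.state_dict(), './saved_models/AEI_D_latest.pth')
    torch.save(opt_G.state_dict(), './saved_models/AEI_optG_latest.pth')
    torch.save(opt_D.state_dict(), './saved_models/AEI_optD_latest.pth')
    torch.save(scaler.state_dict(), './saved_models/AEI_scaler_latest.pth')
    with open('./saved_models/AEI_niter.pkl', 'wb') as f:
        pickle.dump(n_iter, f)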
def main(cfg: DictConfig) -> None: if cfg.trainer.print_torch_setup is True: print_torch_setup() if cfg.trainer.seed is not None: random.seed(cfg.trainer.seed) torch.manual_seed(cfg.trainer.seed) torch.backends.cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') assert torch.cuda.is_available(), 'This code requires a GPU to train' torch.backends.cudnn.benchmark = True assert cfg.trainer.output_dir, 'You need to specify an output directory' mkdir(cfg.trainer.output_dir) experiment_name = time.strftime("%Y%m%d-%H%M%S") print(f'The current experiment will be tracked as {experiment_name}') output_dir = os.path.join(cfg.trainer.output_dir, experiment_name) print(f'Results will be saved in {output_dir}') writer = SummaryWriter(output_dir) # this is just a workaround for now # hparams logging to a file and as text into tensorboard # it is certainly not perfect... :/ hparams = flatten_dict(OmegaConf.to_container(cfg, resolve=True)) hparams_as_str = [ str(k) + ' >>> ' + str(v) + '\n' for k, v in hparams.items() ] # TODO: this seems to not work properly! # writer.add_hparams(hparams, metric_dict={'acc': 1}, run_name=experiment_name) with open(os.path.join(output_dir, 'hparams.txt'), 'w', encoding='utf-8') as hparams_file: for line in hparams_as_str: hparams_file.write(line) writer.add_text('hparams', '\r\n'.join(hparams_as_str), global_step=0) device = torch.device(cfg.trainer.device) assert device.type == 'cuda', 'Only GPU based training is supported' dataset = instantiate(cfg.dataset.train) assert cfg.dataset.val_split is not None, 'Handling a separate validation set is not implemented as of now!' 
train_size = int((1 - cfg.dataset.val_split) * len(dataset)) val_size = len(dataset) - train_size train_dataset, val_dataset = torch.utils.data.random_split( dataset, [train_size, val_size]) train_sampler_weights = dataset.make_weights_for_dataset_sampling( train_dataset) sampler = WeightedRandomSampler( train_sampler_weights, num_samples=cfg.dataset.train_samples_per_epoch, replacement=True) train_collate_fn = dataset.get_collate_fn( mode='train', channels_last=cfg.trainer.channels_last) train_dataloader = instantiate(cfg.dataloader.train, dataset=train_dataset, collate_fn=train_collate_fn, sampler=sampler) val_collate_fn = dataset.get_collate_fn( mode='val', channels_last=cfg.trainer.channels_last) val_dataloader = instantiate(cfg.dataloader.val, dataset=val_dataset, collate_fn=val_collate_fn) # this handler moves a batch to the GPU as uint8, casts it to a float after transferring it # and normalizes the images to_device_handler = ToDeviceFunction(device=device, mean=cfg.dataset.mean, std=cfg.dataset.std) # the prefetch loader prefetches the next batch onto the GPU which makes up a couple # of percent in the training loop train_dataloader = PrefetchLoader(loader=train_dataloader, to_device_handler=to_device_handler) # val_dataloader = PrefetchLoader(loader=val_dataloader, # to_device_handler=to_device_handler) model = instantiate(cfg.models.model, device=device).to(device) if cfg.trainer.channels_last is True: model = model.to(memory_format=torch.channels_last) if cfg.trainer.anomaly_detection is True: torch.autograd.set_detect_anomaly(mode=True) params_to_optimize = [{ "params": [p for p in model.parameters() if p.requires_grad] }] optimizer = instantiate(cfg.optimizer, params_to_optimize) scaler = GradScaler(enabled=cfg.trainer.amp) if cfg.trainer.resume is not None: if os.path.isfile(cfg.trainer.resume): print("Trying to load checkpoint '{}'".format(cfg.trainer.resume)) if cfg.trainer.from_u2net_checkpoint is True: checkpoint = torch.load(cfg.trainer.resume, map_location=device) model.load_state_dict(checkpoint) else: checkpoint = torch.load(cfg.trainer.resume, map_location=device) model.load_state_dict(checkpoint['model']) if cfg.trainer.weights_only is False: cfg.trainer.start_epoch = checkpoint['epoch'] optimizer.load_state_dict(checkpoint['optimizer']) scaler.load_state_dict(checkpoint['scaler']) print( f'Loaded checkpoint {cfg.trainer.resume}. Resuming training at epoch {cfg.trainer.start_epoch}' ) else: warnings.warn(f'Checkpoint f{cfg.trainer.resume} not found!') print("Start training...") start_time = time.time() if cfg.trainer.dry_run is True: print("Doing dry run, running val on train dataset...") # validate_one_epoch(writer, model, train_dataloader, device, 0, cfg.trainer.print_freq) return for epoch in range(cfg.trainer.start_epoch, cfg.trainer.epochs): train_one_epoch(writer, device, model, optimizer, scaler, train_dataloader, epoch, cfg) # validate_one_epoch(writer, model, val_dataloader, epoch, cfg) checkpoint = { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scaler': scaler.state_dict(), 'epoch': epoch, 'cfg': cfg } save_on_master(checkpoint, os.path.join(output_dir, 'model_{}.pth'.format(epoch))) save_on_master(checkpoint, os.path.join(output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
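# Why GradScaler(enabled=cfg.trainer.amp) above is convenient: with
# enabled=False, both the scaler and autocast become no-ops, so the exact same
# training code runs in full precision or in AMP depending on one config flag.
# Minimal sketch; `amp` and the Linear model are placeholders.
import torch
from torch.cuda.amp import GradScaler, autocast

amp = False                                     # e.g. cfg.trainer.amp
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=amp)                # no-op scaler when amp=False

x = torch.randn(4, 8, device='cuda')
with autocast(enabled=amp):                     # no-op context when amp=False
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()                   # loss passes through unscaled if disabled
scaler.step(optimizer)                          # plain optimizer.step() if disabled
scaler.update()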
class DynamicIterBasedRunner(IterBasedRunner): """Dynamic Iterbased Runner. In this Dynamic Iterbased Runner, we will pass the ``reducer`` to the ``train_step`` so that the models can be trained with dynamic architecture. More details and clarification can be found in this [tutorial](docs/tutorials/ddp_train_gans.md). # noqa Args: is_dynamic_ddp (bool, optional): Whether to adopt the dynamic ddp. Defaults to False. pass_training_status (bool, optional): Whether to pass the training status. Defaults to False. fp16_loss_scaler (dict | None, optional): Config for fp16 GradScaler from ``torch.cuda.amp``. Defaults to None. use_apex_amp (bool, optional): Whether to use apex.amp to start mixed precision training. Defaults to False. """ def __init__(self, *args, is_dynamic_ddp=False, pass_training_status=False, fp16_loss_scaler=None, use_apex_amp=False, **kwargs): super().__init__(*args, **kwargs) if is_module_wrapper(self.model): _model = self.model.module else: _model = self.model self.is_dynamic_ddp = is_dynamic_ddp self.pass_training_status = pass_training_status # add a flag for checking if `self.optimizer` comes from `_model` self.optimizer_from_model = False # add support for optimizer is None. # sanity check for whether `_model` contains self-defined optimizer if hasattr(_model, 'optimizer'): assert self.optimizer is None, ( 'Runner and model cannot contain optimizer at the same time.') self.optimizer_from_model = True self.optimizer = _model.optimizer # add fp16 grad scaler, using pytorch official GradScaler self.with_fp16_grad_scaler = False if fp16_loss_scaler is not None: self.loss_scaler = GradScaler(**fp16_loss_scaler) self.with_fp16_grad_scaler = True mmcv.print_log('Use FP16 grad scaler in Training', 'mmgen') # flag to use amp in apex (NVIDIA) self.use_apex_amp = use_apex_amp def call_hook(self, fn_name): """Call all hooks. Args: fn_name (str): The function name in each hook to be called, such as "before_train_epoch". """ for hook in self._hooks: if hasattr(hook, fn_name): getattr(hook, fn_name)(self) def train(self, data_loader, **kwargs): if is_module_wrapper(self.model): _model = self.model.module else: _model = self.model self.model.train() self.mode = 'train' # check if self.optimizer from model and track it if self.optimizer_from_model: self.optimizer = _model.optimizer self.data_loader = data_loader self._epoch = data_loader.epoch self.call_hook('before_fetch_train_data') data_batch = next(self.data_loader) self.call_hook('before_train_iter') # prepare input args for train_step # running status if self.pass_training_status: running_status = dict(iteration=self.iter, epoch=self.epoch) kwargs['running_status'] = running_status # ddp reducer for tracking dynamic computational graph if self.is_dynamic_ddp: kwargs.update(dict(ddp_reducer=self.model.reducer)) if self.with_fp16_grad_scaler: kwargs.update(dict(loss_scaler=self.loss_scaler)) if self.use_apex_amp: kwargs.update(dict(use_apex_amp=True)) outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) # the loss scaler should be updated after ``train_step`` if self.with_fp16_grad_scaler: self.loss_scaler.update() # further check for the cases where the optimizer is built in # `train_step`. 
if self.optimizer is None: if hasattr(_model, 'optimizer'): self.optimizer_from_model = True self.optimizer = _model.optimizer # check if self.optimizer from model and track it if self.optimizer_from_model: self.optimizer = _model.optimizer if not isinstance(outputs, dict): raise TypeError('model.train_step() must return a dict') if 'log_vars' in outputs: self.log_buffer.update(outputs['log_vars'], outputs['num_samples']) self.outputs = outputs self.call_hook('after_train_iter') self._inner_iter += 1 self._iter += 1 def run(self, data_loaders, workflow, max_iters=None, **kwargs): """Start running. Args: data_loaders (list[:obj:`DataLoader`]): Dataloaders for training and validation. workflow (list[tuple]): A list of (phase, iters) to specify the running order and iterations. E.g, [('train', 10000), ('val', 1000)] means running 10000 iterations for training and 1000 iterations for validation, iteratively. """ assert isinstance(data_loaders, list) assert mmcv.is_list_of(workflow, tuple) assert len(data_loaders) == len(workflow) if max_iters is not None: warnings.warn( 'setting max_iters in run is deprecated, ' 'please set max_iters in runner_config', DeprecationWarning) self._max_iters = max_iters assert self._max_iters is not None, ( 'max_iters must be specified during instantiation') work_dir = self.work_dir if self.work_dir is not None else 'NONE' self.logger.info('Start running, host: %s, work_dir: %s', get_host_info(), work_dir) self.logger.info('workflow: %s, max: %d iters', workflow, self._max_iters) self.call_hook('before_run') iter_loaders = [IterLoader(x, self) for x in data_loaders] self.call_hook('before_epoch') while self.iter < self._max_iters: for i, flow in enumerate(workflow): self._inner_iter = 0 mode, iters = flow if not isinstance(mode, str) or not hasattr(self, mode): raise ValueError( 'runner has no method named "{}" to run a workflow'. format(mode)) iter_runner = getattr(self, mode) for _ in range(iters): if mode == 'train' and self.iter >= self._max_iters: break iter_runner(iter_loaders[i], **kwargs) time.sleep(1) # wait for some hooks like loggers to finish self.call_hook('after_epoch') self.call_hook('after_run') def resume(self, checkpoint, resume_optimizer=True, resume_loss_scaler=True, map_location='default'): """Resume model from checkpoint. Args: checkpoint (str): Checkpoint to resume from. resume_optimizer (bool, optional): Whether resume the optimizer(s) if the checkpoint file includes optimizer(s). Default to True. resume_loss_scaler (bool, optional): Whether to resume the loss scaler (GradScaler) from ``torch.cuda.amp``. Defaults to True. map_location (str, optional): Same as :func:`torch.load`. Default to 'default'. 
""" if map_location == 'default': device_id = torch.cuda.current_device() checkpoint = self.load_checkpoint( checkpoint, map_location=lambda storage, loc: storage.cuda(device_id)) else: checkpoint = self.load_checkpoint(checkpoint, map_location=map_location) self._epoch = checkpoint['meta']['epoch'] self._iter = checkpoint['meta']['iter'] self._inner_iter = checkpoint['meta']['iter'] if 'optimizer' in checkpoint and resume_optimizer: if isinstance(self.optimizer, Optimizer): self.optimizer.load_state_dict(checkpoint['optimizer']) elif isinstance(self.optimizer, dict): for k in self.optimizer.keys(): self.optimizer[k].load_state_dict( checkpoint['optimizer'][k]) else: raise TypeError( 'Optimizer should be dict or torch.optim.Optimizer ' f'but got {type(self.optimizer)}') if 'loss_scaler' in checkpoint and resume_loss_scaler: self.loss_scaler.load_state_dict(checkpoint['loss_scaler']) if self.use_apex_amp: from apex import amp amp.load_state_dict(checkpoint['amp']) self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') def save_checkpoint(self, out_dir, filename_tmpl='iter_{}.pth', meta=None, save_optimizer=True, create_symlink=True): """Save checkpoint to file. Args: out_dir (str): Directory to save checkpoint files. filename_tmpl (str, optional): Checkpoint file template. Defaults to 'iter_{}.pth'. meta (dict, optional): Metadata to be saved in checkpoint. Defaults to None. save_optimizer (bool, optional): Whether save optimizer. Defaults to True. create_symlink (bool, optional): Whether create symlink to the latest checkpoint file. Defaults to True. """ if meta is None: meta = dict(iter=self.iter + 1, epoch=self.epoch + 1) elif isinstance(meta, dict): meta.update(iter=self.iter + 1, epoch=self.epoch + 1) else: raise TypeError( f'meta should be a dict or None, but got {type(meta)}') if self.meta is not None: meta.update(self.meta) filename = filename_tmpl.format(self.iter + 1) filepath = osp.join(out_dir, filename) optimizer = self.optimizer if save_optimizer else None _loss_scaler = self.loss_scaler if self.with_fp16_grad_scaler else None save_checkpoint(self.model, filepath, optimizer=optimizer, loss_scaler=_loss_scaler, save_apex_amp=self.use_apex_amp, meta=meta) # in some environments, `os.symlink` is not supported, you may need to # set `create_symlink` to False if create_symlink: dst_file = osp.join(out_dir, 'latest.pth') if platform.system() != 'Windows': mmcv.symlink(filename, dst_file) else: shutil.copy(filepath, dst_file) def register_lr_hook(self, lr_config): if lr_config is None: return if isinstance(lr_config, dict): assert 'policy' in lr_config policy_type = lr_config.pop('policy') # If the type of policy is all in lower case, e.g., 'cyclic', # then its first letter will be capitalized, e.g., to be 'Cyclic'. # This is for the convenient usage of Lr updater. # Since this is not applicable for ` # CosineAnnealingLrUpdater`, # the string will not be changed if it contains capital letters. if policy_type == policy_type.lower(): policy_type = policy_type.title() hook_type = policy_type + 'LrUpdaterHook' lr_config['type'] = hook_type hook = mmcv.build_from_cfg(lr_config, HOOKS) else: hook = lr_config self.register_hook(hook)
class BaseTrainer: def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer): self.color_tool = colorful self.color_tool.use_style("solarized") model = DistributedDataParallel(model.to(rank), device_ids=[rank]) self.model = model self.optimizer = optimizer self.loss_function = loss_function # DistributedDataParallel (DDP) self.rank = rank self.dist = dist # Automatic mixed precision (AMP) self.use_amp = config["meta"]["use_amp"] self.scaler = GradScaler(enabled=self.use_amp) # Acoustics self.acoustic_config = config["acoustics"] # Supported STFT n_fft = self.acoustic_config["n_fft"] hop_length = self.acoustic_config["hop_length"] win_length = self.acoustic_config["win_length"] self.torch_stft = partial(stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.torch_istft = partial(istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.librosa_stft = partial(librosa.stft, n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.librosa_istft = partial(librosa.istft, hop_length=hop_length, win_length=win_length) # Trainer.train in the config self.train_config = config["trainer"]["train"] self.epochs = self.train_config["epochs"] self.save_checkpoint_interval = self.train_config[ "save_checkpoint_interval"] self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"] assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be large than one." # Trainer.validation in the config self.validation_config = config["trainer"]["validation"] self.validation_interval = self.validation_config[ "validation_interval"] self.save_max_metric_score = self.validation_config[ "save_max_metric_score"] assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be large than one." # Trainer.visualization in the config self.visualization_config = config["trainer"]["visualization"] # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args: self.start_epoch = 1 self.best_score = -np.inf if self.save_max_metric_score else np.inf self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute( ) / config["meta"]["experiment_name"] self.checkpoints_dir = self.save_dir / "checkpoints" self.logs_dir = self.save_dir / "logs" if resume: self._resume_checkpoint() # Debug validation, which skips training self.only_validation = only_validation if config["meta"]["preloaded_model_path"]: self._preload_model(Path(config["preloaded_model_path"])) if self.rank == 0: prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume) self.writer = SummaryWriter(self.logs_dir.as_posix(), max_queue=5, flush_secs=30) self.writer.add_text( tag="Configuration", text_string=f"<pre> \n{toml.dumps(config)} \n</pre>", global_step=1) print(self.color_tool.cyan("The configurations are as follows: ")) print(self.color_tool.cyan("=" * 40)) print(self.color_tool.cyan(toml.dumps(config)[:-1])) # except "\n" print(self.color_tool.cyan("=" * 40)) with open( (self.save_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle: toml.dump(config, handle) self._print_networks([self.model]) def _preload_model(self, model_path): """ Preload model parameters (in "*.tar" format) at the start of experiment. Args: model_path (Path): The file path of the *.tar file """ model_path = model_path.expanduser().absolute() assert model_path.exists( ), f"The file {model_path.as_posix()} is not exist. 
please check path." model_checkpoint = torch.load(model_path.as_posix(), map_location="cpu") self.model.load_state_dict(model_checkpoint["model"], strict=False) self.model.to(self.rank) if self.rank == 0: print( f"Model preloaded successfully from {model_path.as_posix()}.") def _resume_checkpoint(self): """ Resume the experiment from the latest checkpoint. """ latest_model_path = self.checkpoints_dir.expanduser().absolute( ) / "latest_model.tar" assert latest_model_path.exists( ), f"{latest_model_path} does not exist, can not load latest checkpoint." # Make sure all processes (GPUs) do not start loading before the saving is finished. # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work self.dist.barrier() # Load it on the CPU and later use .to(device) on the model # Maybe slightly slow than use map_location="cuda:<...>" # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion checkpoint = torch.load(latest_model_path.as_posix(), map_location="cpu") self.start_epoch = checkpoint["epoch"] + 1 self.best_score = checkpoint["best_score"] self.optimizer.load_state_dict(checkpoint["optimizer"]) self.scaler.load_state_dict(checkpoint["scaler"]) if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): self.model.module.load_state_dict(checkpoint["model"]) else: self.model.load_state_dict(checkpoint["model"]) # self.model.to(self.rank) if self.rank == 0: print( f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch." ) def _save_checkpoint(self, epoch, is_best_epoch=False): """ Save checkpoint to "<save_dir>/<config name>/checkpoints" directory, which consists of: - epoch - best metric score in historical epochs - optimizer parameters - model parameters Args: is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True), the checkpoint of model will be saved as "<save_dir>/checkpoints/best_model.tar". """ print(f"\t Saving {epoch} epoch model checkpoint...") state_dict = { "epoch": epoch, "best_score": self.best_score, "optimizer": self.optimizer.state_dict(), "scaler": self.scaler.state_dict() } if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): state_dict["model"] = self.model.module.state_dict() else: state_dict["model"] = self.model.state_dict() # Saved in "latest_model.tar" # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc. # New checkpoint will overwrite the older one. torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix()) # "model_{epoch_number}.pth" # Contains only model. torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix()) # If the model get a best metric score (means "is_best_epoch=True") in the current epoch, # the model checkpoint will be saved as "best_model.tar" # The newer best-scored checkpoint will overwrite the older one. 
if is_best_epoch: print( self.color_tool.red( f"\t Found a best score in the {epoch} epoch, saving...")) torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix()) def _is_best_epoch(self, score, save_max_metric_score=True): """ Check if the current model got the best metric score """ if save_max_metric_score and score >= self.best_score: self.best_score = score return True elif not save_max_metric_score and score <= self.best_score: self.best_score = score return True else: return False @staticmethod def _print_networks(models: list): print( f"This project contains {len(models)} models, the number of the parameters is: " ) params_of_all_networks = 0 for idx, model in enumerate(models, start=1): params_of_network = 0 for param in model.parameters(): params_of_network += param.numel() print(f"\tNetwork {idx}: {params_of_network / 1e6} million.") params_of_all_networks += params_of_network print( f"The amount of parameters in the project is {params_of_all_networks / 1e6} million." ) def _set_models_to_train_mode(self): self.model.train() def _set_models_to_eval_mode(self): self.model.eval() def spec_audio_visualization(self, noisy, enhanced, clean, name, epoch, mark=""): self.writer.add_audio(f"{mark}_Speech/{name}_Noisy", noisy, epoch, sample_rate=16000) self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000) self.writer.add_audio(f"{mark}_Speech/{name}_Clean", clean, epoch, sample_rate=16000) # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech noisy_mag, _ = librosa.magphase( self.librosa_stft(noisy, n_fft=320, hop_length=160, win_length=320)) enhanced_mag, _ = librosa.magphase( self.librosa_stft(enhanced, n_fft=320, hop_length=160, win_length=320)) clean_mag, _ = librosa.magphase( self.librosa_stft(clean, n_fft=320, hop_length=160, win_length=320)) fig, axes = plt.subplots(3, 1, figsize=(6, 6)) for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]): axes[k].set_title(f"mean: {np.mean(mag):.3f}, " f"std: {np.std(mag):.3f}, " f"max: {np.max(mag):.3f}, " f"min: {np.min(mag):.3f}") librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000) plt.tight_layout() self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch) def metrics_visualization(self, noisy_list, clean_list, enhanced_list, metrics_list, epoch, num_workers=10, mark=""): """ Get metrics on validation dataset by paralleling. Notes: 1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are used for checking if the current epoch is a "best epoch." 2. If you want to use a new metric, you must register it in "util.metrics" file. """ assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence." # Check if the metric is registered in "util.metrics" file. for i in metrics_list: assert i in metrics.REGISTERED_METRICS.keys( ), f"{i} is not registered, please check 'util.metrics' file." 
stoi_mean = 0.0 wb_pesq_mean = 0.0 for metric_name in metrics_list: score_on_noisy = Parallel(n_jobs=num_workers)( delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, noisy_list)) score_on_enhanced = Parallel(n_jobs=num_workers)( delayed(metrics.REGISTERED_METRICS[metric_name])(ref, est) for ref, est in zip(clean_list, enhanced_list)) # Add the mean value of the metric to tensorboard mean_score_on_noisy = np.mean(score_on_noisy) mean_score_on_enhanced = np.mean(score_on_enhanced) self.writer.add_scalars(f"{mark}_Validation/{metric_name}", { "Noisy": mean_score_on_noisy, "Enhanced": mean_score_on_enhanced }, epoch) if metric_name == "STOI": stoi_mean = mean_score_on_enhanced if metric_name == "WB_PESQ": wb_pesq_mean = transform_pesq_range(mean_score_on_enhanced) return (stoi_mean + wb_pesq_mean) / 2 def train(self): for epoch in range(self.start_epoch, self.epochs + 1): if self.rank == 0: print( self.color_tool.yellow( f"{'=' * 15} {epoch} epoch {'=' * 15}")) print("[0 seconds] Begin training...") # [debug validation] Only run validation (only use the first GPU (process)) # inference + calculating metrics + saving checkpoints if self.only_validation and self.rank == 0: self._set_models_to_eval_mode() metric_score = self._validation_epoch(epoch) if self._is_best_epoch( metric_score, save_max_metric_score=self.save_max_metric_score): self._save_checkpoint(epoch, is_best_epoch=True) # Skip the following regular training, saving checkpoints, and validation continue # Regular training timer = ExecutionTime() self._set_models_to_train_mode() self._train_epoch(epoch) # Regular save checkpoints if self.rank == 0 and self.save_checkpoint_interval != 0 and ( epoch % self.save_checkpoint_interval == 0): self._save_checkpoint(epoch) # Regular validation if self.rank == 0 and (epoch % self.validation_interval == 0): print( f"[{timer.duration()} seconds] Training has finished, validation is in progress..." ) self._set_models_to_eval_mode() metric_score = self._validation_epoch(epoch) if self._is_best_epoch( metric_score, save_max_metric_score=self.save_max_metric_score): self._save_checkpoint(epoch, is_best_epoch=True) print(f"[{timer.duration()} seconds] This epoch is finished.") def _train_epoch(self, epoch): raise NotImplementedError def _validation_epoch(self, epoch): raise NotImplementedError
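# BaseTrainer leaves _train_epoch abstract. A minimal subclass sketch -- an
# assumption about the intended shape, not the project's actual trainer. It
# uses the attributes defined above (self.scaler, self.use_amp,
# self.clip_grad_norm_value, self.optimizer, self.loss_function) and assumes a
# hypothetical self.train_dataloader yielding (noisy, clean) pairs. With
# use_amp=False the scaler calls are no-ops, so the same loop covers both modes.
import torch
from torch.cuda.amp import autocast

class SimpleTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        for noisy, clean in self.train_dataloader:   # hypothetical loader
            noisy, clean = noisy.to(self.rank), clean.to(self.rank)
            self.optimizer.zero_grad()
            with autocast(enabled=self.use_amp):
                enhanced = self.model(noisy)
                loss = self.loss_function(enhanced, clean)
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)     # unscale before clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.clip_grad_norm_value)
            self.scaler.step(self.optimizer)
            self.scaler.update()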
def train(model: Model, state: dict, train_data_path: str, train_rgb_json: str, val_data_path: str, val_rgb_json: str, transform_file: str, growing_parameters: dict, lr: float, iterations: int, val_iterations: int, verbose: bool, train_segment_masks_path: str = '', val_segment_masks_path: str = '', lambda_ccl=0.0, loss_type='L2', ccl_version='linear', alpha=5, gamma=.5, regularization_l2: float = 0., warmup=5000, milestones=[], optimizer_name: str = 'sgd', print_every: int = 250, debug=False): model.train() torch.backends.cudnn.benchmark = True if debug: print_every = 10 sparse_growing_parameters = load_growing_parameters(growing_parameters) filled_growing_parameters = fill_growing_parameters( sparse_growing_parameters, iterations) assert os.path.isfile(transform_file) sys.path.insert(0, os.path.dirname(transform_file)) transforms = __import__( os.path.splitext(os.path.basename(transform_file))[0]) model_dir = os.path.dirname(state['path']) writer = SummaryWriter(log_dir=os.path.join(model_dir, 'logs')) if loss_type == 'L2': criterion = L2Loss(weighted=False) elif loss_type == 'L2W': criterion = L2Loss(weighted=True, alpha=alpha, gamma=gamma) elif loss_type == 'L1': criterion = L1Loss(weighted=False) elif loss_type == 'L1W': criterion = L1Loss(weighted=True, alpha=alpha, gamma=gamma) elif loss_type == 'L2+CCL': criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version) elif loss_type == 'L2W+CCL': criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, weighted=True, alpha=alpha, gamma=gamma) elif loss_type == 'L2+CCL-gt': criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, ccl_target='gt', weighted=False) elif loss_type == 'L2W+CCL-gt': criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, ccl_target='gt', weighted=True, alpha=alpha, gamma=gamma) elif loss_type == 'L1+CCL': criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version) elif loss_type == 'L1W+CCL': criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, weighted=True, alpha=alpha, gamma=gamma) elif loss_type == 'L1+CCL-gt': criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, ccl_target='gt', weighted=False) elif loss_type == 'L1W+CCL-gt': criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version, ccl_target='gt', weighted=True, alpha=alpha, gamma=gamma) else: raise NotImplementedError() if torch.cuda.is_available(): model = model.cuda() criterion = criterion.cuda() if optimizer_name == 'adam': optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=regularization_l2) elif optimizer_name == 'sgd': optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9) else: raise NotImplementedError(f'Optimizer {optimizer_name} not available') if 'optimizer' in state: print('loading optimizer...') optimizer.load_state_dict(state['optimizer']) scaler = GradScaler(enabled=True) if 'scaler' in state: print('loading scaler...') scaler.load_state_dict(state['scaler']) def schedule(train_iter): if warmup and train_iter <= warmup: return 0.9 * train_iter / warmup + 0.1 return 0.1**len([m for m in milestones if m <= train_iter]) scheduler = LambdaLR(optimizer, schedule) if 'scheduler' in state: print('loading scheduler...') scheduler.load_state_dict(state['scheduler']) iteration = state.get('iteration', 0) if iteration >= iterations: print('Training already done.') return if train_segment_masks_path or val_segment_masks_path: trainset = ImagenetColorSegmentData(train_data_path, train_segment_masks_path, 
rgb_json=train_rgb_json, transform=None, transform_l=to_tensor_l, transform_ab=to_tensor_ab) testset = ImagenetColorSegmentData( val_data_path, val_segment_masks_path, rgb_json=val_rgb_json, transform=transforms.get_val_transform(1024), transform_l=to_tensor_l, transform_ab=to_tensor_ab) else: trainset = ImagenetData(train_data_path, rgb_json=train_rgb_json, transform=None, transform_l=to_tensor_l, transform_ab=to_tensor_ab) testset = ImagenetData(val_data_path, rgb_json=val_rgb_json, transform=transforms.get_val_transform(1024), transform_l=to_tensor_l, transform_ab=to_tensor_ab) trainset_infer = ImagenetData(train_data_path, rgb_json=train_rgb_json, transform=transforms.get_val_transform(1024), transform_l=to_tensor_l, transform_ab=to_tensor_ab, training=False) testset_infer = ImagenetData(val_data_path, rgb_json=val_rgb_json, transform=transforms.get_val_transform(1024), transform_l=to_tensor_l, transform_ab=to_tensor_ab, training=False) sampler = SavableShuffleSampler(trainset, shuffle=not debug) if 'sampler' in state: print('loading sampler...') sampler.load_state_dict(state['sampler']) if len(sampler) > len(trainset): sampler = SavableShuffleSampler(trainset, shuffle=not debug) print('recreate the sampler, trainset changed...') print(f' Loss: {loss_type}') print(criterion) print( f' Optimizer: {optimizer.__class__.__name__} (LR:{optimizer.param_groups[0]["lr"]:.6f})' ) print(f' Iteration: {iteration}/{iterations}') print(f' Warmup: {warmup}') print(f' Milestones: {milestones}') print(f' Growing: {sparse_growing_parameters}') print(f' Traindata: {len(trainset)} images') print(f' Testdata: {len(testset)} images') print(f' Sampler idx: {sampler.index}') print(f'Current step: {scheduler._step_count}') batch_size, input_size = filled_growing_parameters[iteration] trainset.transform = transforms.get_transform(input_size[0]) trainloader = get_trainloader(trainset, batch_size, sampler) running_psnr, img_per_sec = 0.0, 0.0 running_loss, avg_running_loss = defaultdict(float), defaultdict(float) tic = time.time() changed_batch_size = True psnr = PSNR() pbar = tqdm(total=iterations, initial=iteration) if iteration == 0: for name, param in model.named_parameters(): writer.add_histogram(name, param, global_step=iteration) while iteration < iterations: loss_str = ' - '.join( [f'{key}: {val:.5f} ' for key, val in avg_running_loss.items()]) pbar.set_description( f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}] loss: {loss_str} - {img_per_sec:.2f} img/s' ) for data in trainloader: if iteration in sparse_growing_parameters and not changed_batch_size: # change batch size and input size batch_size, input_size = sparse_growing_parameters[iteration] trainset.transform = transforms.get_transform(input_size[0]) # recreate the loader, otherwise the transform is not propagated in multiprocessing to the workers trainloader = get_trainloader(trainset, batch_size, sampler) changed_batch_size = True break else: changed_batch_size = False if torch.cuda.is_available(): data = tuple([el.cuda(non_blocking=True) for el in data]) # get data if len(data) == 4: inputs, labels, segment_masks, _ = data else: inputs, labels = data # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize with autocast(): outputs = model(inputs) crit_labels = [labels, segment_masks ] if train_segment_masks_path else [labels] loss, loss_dict = criterion(outputs, *crit_labels) _psnr = psnr(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() 
scheduler.step() del outputs del inputs del labels del data # print statistics for k, v, in loss_dict.items(): running_loss[k] += v.item() running_psnr += _psnr.item() iteration += 1 if iteration % print_every == 0 or iteration == iterations: img_per_sec = print_every * batch_size / (time.time() - tic) for k, v in running_loss.items(): avg_running_loss[k] = running_loss[k] / print_every writer.add_scalar(f'train/{k}', avg_running_loss[k], global_step=iteration) avg_running_psnr = running_psnr / print_every writer.add_scalar('train/PSNR', avg_running_psnr, global_step=iteration) writer.add_scalar('Performance/Images per second', img_per_sec, global_step=iteration) writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], global_step=iteration) if loss_type in ['L1+CCL', 'L2+CCL']: writer.add_scalar('Parameters/lambda CCL', lambda_ccl, global_step=iteration) loss_str = ' - '.join([ f'{key}: {val:.5} ' for key, val in avg_running_loss.items() ]) pbar.set_description( f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}] loss: {loss_str} - {img_per_sec:.2f} img/s' ) running_loss = defaultdict(float) running_psnr = 0.0 state.update({ 'iteration': iteration, 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'scaler': scaler.state_dict(), 'sampler': sampler.state_dict() }) model.save(state, iteration) delete_older_then_n(state['path'], 10) tic = time.time() if iteration == iterations or iteration % val_iterations == 0: # run validation torch.backends.cudnn.benchmark = False model = model.eval() test_loader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, prefetch_factor=1) with torch.no_grad(): metric_results = get_validation_metrics( test_loader, model, criterion, ccl_version=ccl_version) for k, v in metric_results.items(): writer.add_scalar(f'validation/{k}', v, global_step=iteration) # images from validation predicted_images = infer( model=model, dataset=testset_infer, target_path=os.path.join(model_dir, f'predictions-{iteration}'), batch_size=1, img_limit=20, transform=transforms.get_val_transform(1024), debug=True, tensorboard=True) for i, img in enumerate(predicted_images): writer.add_image(f'example-{i}', img, global_step=iteration, dataformats='HWC') # images from training predicted_images = infer( model=model, dataset=trainset_infer, target_path=os.path.join( model_dir, f'predictions-training-{iteration}'), batch_size=1, img_limit=20, transform=transforms.get_val_transform(1024), debug=True, tensorboard=True) for i, img in enumerate(predicted_images): writer.add_image(f'example-train-{i}', img, global_step=iteration, dataformats='HWC') for name, param in model.named_parameters(): writer.add_histogram(name, param, global_step=iteration) model = model.train() torch.backends.cudnn.benchmark = True tic = time.time() pbar.update(1) if iteration == iterations: break pbar.close() writer.close() print('Finished Training')
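The warmup-then-milestone LR schedule above is compact but easy to misread, so here is a minimal self-contained sketch of the same LambdaLR pattern; the warmup length, milestones, and toy model below are stand-ins, not values taken from the code above.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

warmup, milestones = 5000, [60000, 80000]  # hypothetical values

def schedule(train_iter):
    # Linear warmup from 10% to 100% of the base LR, then a 10x decay
    # for every milestone that has been passed.
    if warmup and train_iter <= warmup:
        return 0.9 * train_iter / warmup + 0.1
    return 0.1 ** len([m for m in milestones if m <= train_iter])

model = torch.nn.Linear(8, 8)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = LambdaLR(optimizer, schedule)

for it in range(10):
    optimizer.step()   # stands in for scaler.step(optimizer) above
    scheduler.step()   # note: advances even on iterations where the
                       # scaler skipped the optimizer step due to inf/NaN
print(optimizer.param_groups[0]['lr'])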
class Trainer(object): def __init__( self, diffusion_model, folder, *, ema_decay = 0.995, image_size = 128, train_batch_size = 32, train_lr = 2e-5, train_num_steps = 100000, gradient_accumulate_every = 2, amp = False, step_start_ema = 2000, update_ema_every = 10, save_and_sample_every = 1000, results_folder = './results' ): super().__init__() self.model = diffusion_model self.ema = EMA(ema_decay) self.ema_model = copy.deepcopy(self.model) self.update_ema_every = update_ema_every self.step_start_ema = step_start_ema self.save_and_sample_every = save_and_sample_every self.batch_size = train_batch_size self.image_size = diffusion_model.image_size self.gradient_accumulate_every = gradient_accumulate_every self.train_num_steps = train_num_steps self.ds = Dataset(folder, image_size) self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True)) self.opt = Adam(diffusion_model.parameters(), lr=train_lr) self.step = 0 self.amp = amp self.scaler = GradScaler(enabled = amp) self.results_folder = Path(results_folder) self.results_folder.mkdir(exist_ok = True) self.reset_parameters() def reset_parameters(self): self.ema_model.load_state_dict(self.model.state_dict()) def step_ema(self): if self.step < self.step_start_ema: self.reset_parameters() return self.ema.update_model_average(self.ema_model, self.model) def save(self, milestone): data = { 'step': self.step, 'model': self.model.state_dict(), 'ema': self.ema_model.state_dict(), 'scaler': self.scaler.state_dict() } torch.save(data, str(self.results_folder / f'model-{milestone}.pt')) def load(self, milestone): data = torch.load(str(self.results_folder / f'model-{milestone}.pt')) self.step = data['step'] self.model.load_state_dict(data['model']) self.ema_model.load_state_dict(data['ema']) self.scaler.load_state_dict(data['scaler']) def train(self): while self.step < self.train_num_steps: for i in range(self.gradient_accumulate_every): data = next(self.dl).cuda() with autocast(enabled = self.amp): loss = self.model(data) self.scaler.scale(loss / self.gradient_accumulate_every).backward() print(f'{self.step}: {loss.item()}') self.scaler.step(self.opt) self.scaler.update() self.opt.zero_grad() if self.step % self.update_ema_every == 0: self.step_ema() if self.step != 0 and self.step % self.save_and_sample_every == 0: milestone = self.step // self.save_and_sample_every batches = num_to_groups(36, self.batch_size) all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches)) all_images = torch.cat(all_images_list, dim=0) all_images = (all_images + 1) * 0.5 utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6) self.save(milestone) self.step += 1 print('training completed')
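The core AMP pattern in Trainer.train() above is: accumulate scaled gradients over N micro-batches, then step the optimizer and update the scaler once. A minimal sketch of just that pattern, with a stand-in model, data stream, and hyperparameters:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=2e-5)
scaler = GradScaler(enabled=True)
accum = 2  # gradient_accumulate_every

def stream():
    while True:
        yield torch.randn(8, 16)

dl = stream()
for step in range(4):
    for _ in range(accum):
        x = next(dl).cuda()
        with autocast(enabled=True):
            loss = model(x).square().mean()
        # Divide by the accumulation factor so the summed gradients match
        # one large batch; scale before backward so FP16 grads don't underflow.
        scaler.scale(loss / accum).backward()
    scaler.step(opt)   # silently skips the step if any grad is inf/NaN
    scaler.update()    # adjusts the scale factor for the next iteration
    opt.zero_grad()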
class Trainer(): def __init__(self, name='default', results_dir='results', models_dir='models', base_dir='./', optimizer="adam", latent_dim=256, image_size=128, fmap_max=512, transparent=False, greyscale=False, batch_size=4, gp_weight=10, gradient_accumulate_every=1, attn_res_layers=[], disc_output_size=5, antialias=False, lr=2e-4, lr_mlp=1., ttur_mult=1., save_every=1000, evaluate_every=1000, trunc_psi=0.6, aug_prob=None, aug_types=['translation', 'cutout'], dataset_aug_prob=0., calculate_fid_every=None, is_ddp=False, rank=0, world_size=1, log=False, amp=False, *args, **kwargs): self.GAN_params = [args, kwargs] self.GAN = None self.name = name base_dir = Path(base_dir) self.base_dir = base_dir self.results_dir = base_dir / results_dir self.models_dir = base_dir / models_dir self.config_path = self.models_dir / name / '.config.json' assert is_power_of_two( image_size ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' assert all( map(is_power_of_two, attn_res_layers) ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)' self.optimizer = optimizer self.latent_dim = latent_dim self.image_size = image_size self.fmap_max = fmap_max self.transparent = transparent self.greyscale = greyscale assert (int(self.transparent) + int(self.greyscale) ) < 2, 'you can only set either transparency or greyscale' self.aug_prob = aug_prob self.aug_types = aug_types self.lr = lr self.ttur_mult = ttur_mult self.batch_size = batch_size self.gradient_accumulate_every = gradient_accumulate_every self.gp_weight = gp_weight self.evaluate_every = evaluate_every self.save_every = save_every self.steps = 0 self.generator_top_k_gamma = 0.99 self.generator_top_k_frac = 0.5 self.attn_res_layers = attn_res_layers self.disc_output_size = disc_output_size self.antialias = antialias self.d_loss = 0 self.g_loss = 0 self.last_gp_loss = None self.last_recon_loss = None self.last_fid = None self.init_folders() self.loader = None self.dataset_aug_prob = dataset_aug_prob self.calculate_fid_every = calculate_fid_every self.is_ddp = is_ddp self.is_main = rank == 0 self.rank = rank self.world_size = world_size self.syncbatchnorm = is_ddp self.amp = amp self.G_scaler = GradScaler(enabled=self.amp) self.D_scaler = GradScaler(enabled=self.amp) @property def image_extension(self): return 'jpg' if not self.transparent else 'png' @property def checkpoint_num(self): return floor(self.steps // self.save_every) def init_GAN(self): args, kwargs = self.GAN_params # set some global variables before instantiating GAN global norm_class global Blur norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d Blur = nn.Identity if not self.antialias else Blur # handle bugs when # switching from multi-gpu back to single gpu if self.syncbatchnorm and not self.is_ddp: import torch.distributed as dist os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=0, world_size=1) # instantiate GAN self.GAN = LightweightGAN(optimizer=self.optimizer, lr=self.lr, latent_dim=self.latent_dim, attn_res_layers=self.attn_res_layers, image_size=self.image_size, ttur_mult=self.ttur_mult, fmap_max=self.fmap_max, disc_output_size=self.disc_output_size, transparent=self.transparent, greyscale=self.greyscale, rank=self.rank, *args, **kwargs) if self.is_ddp: ddp_kwargs = { 'device_ids': [self.rank], 'output_device': self.rank, 'find_unused_parameters': True } self.G_ddp = DDP(self.GAN.G, **ddp_kwargs) self.D_ddp = DDP(self.GAN.D, **ddp_kwargs) self.D_aug_ddp = 
DDP(self.GAN.D_aug, **ddp_kwargs) def write_config(self): self.config_path.write_text(json.dumps(self.config())) def load_config(self): config = self.config( ) if not self.config_path.exists() else json.loads( self.config_path.read_text()) self.image_size = config['image_size'] self.transparent = config['transparent'] self.syncbatchnorm = config['syncbatchnorm'] self.disc_output_size = config['disc_output_size'] self.greyscale = config.pop('greyscale', False) self.attn_res_layers = config.pop('attn_res_layers', []) self.optimizer = config.pop('optimizer', 'adam') self.fmap_max = config.pop('fmap_max', 512) del self.GAN self.init_GAN() def config(self): return { 'image_size': self.image_size, 'transparent': self.transparent, 'greyscale': self.greyscale, 'syncbatchnorm': self.syncbatchnorm, 'disc_output_size': self.disc_output_size, 'optimizer': self.optimizer, 'attn_res_layers': self.attn_res_layers } def set_data_src(self, folder): self.dataset = ImageDataset(folder, self.image_size, transparent=self.transparent, greyscale=self.greyscale, aug_prob=self.dataset_aug_prob) sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None dataloader = DataLoader( self.dataset, num_workers=math.ceil(NUM_CORES / self.world_size), batch_size=math.ceil(self.batch_size / self.world_size), sampler=sampler, shuffle=not self.is_ddp, drop_last=True, pin_memory=True) self.loader = cycle(dataloader) # auto set augmentation prob for user if dataset is detected to be low num_samples = len(self.dataset) if not exists(self.aug_prob) and num_samples < 1e5: self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6) print( f'autosetting augmentation probability to {round(self.aug_prob * 100)}%' ) def train(self): assert exists( self.loader ), 'You must first initialize the data source with `.set_data_src(<folder of images>)`' device = torch.device(f'cuda:{self.rank}') if not exists(self.GAN): self.init_GAN() self.GAN.train() total_disc_loss = torch.zeros([], device=device) total_gen_loss = torch.zeros([], device=device) batch_size = math.ceil(self.batch_size / self.world_size) image_size = self.GAN.image_size latent_dim = self.GAN.latent_dim aug_prob = default(self.aug_prob, 0) aug_types = self.aug_types aug_kwargs = {'prob': aug_prob, 'types': aug_types} G = self.GAN.G if not self.is_ddp else self.G_ddp D = self.GAN.D if not self.is_ddp else self.D_ddp D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp apply_gradient_penalty = self.steps % 4 == 0 # amp related contexts and functions amp_context = autocast if self.amp else null_context # train discriminator self.GAN.D_opt.zero_grad() for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) image_batch = next(self.loader).cuda(self.rank) image_batch.requires_grad_() with amp_context(): with torch.no_grad(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug(generated_images, detach=True, **aug_kwargs) real_output, real_output_32x32, real_aux_loss = D_aug( image_batch, calc_aux_loss=True, **aug_kwargs) real_output_loss = real_output fake_output_loss = fake_output divergence = hinge_loss(real_output_loss, fake_output_loss) divergence_32x32 = hinge_loss(real_output_32x32, fake_output_32x32) disc_loss = divergence + divergence_32x32 aux_loss = real_aux_loss disc_loss = disc_loss + aux_loss if apply_gradient_penalty: outputs = [real_output, real_output_32x32] outputs = 
list(map(self.D_scaler.scale, outputs)) if self.amp else outputs scaled_gradients = torch_grad( outputs=outputs, inputs=image_batch, grad_outputs=list( map( lambda t: torch.ones(t.size(), device=image_batch.device), outputs)), create_graph=True, retain_graph=True, only_inputs=True)[0] inv_scale = (1. / self.D_scaler.get_scale()) if self.amp else 1. gradients = scaled_gradients * inv_scale with amp_context(): gradients = gradients.reshape(batch_size, -1) gp = self.gp_weight * ( (gradients.norm(2, dim=1) - 1)**2).mean() if not torch.isnan(gp): disc_loss = disc_loss + gp self.last_gp_loss = gp.clone().detach().item() with amp_context(): disc_loss = disc_loss / self.gradient_accumulate_every disc_loss.register_hook(raise_if_nan) self.D_scaler.scale(disc_loss).backward() total_disc_loss += divergence self.last_recon_loss = aux_loss.item() self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every) self.D_scaler.step(self.GAN.D_opt) self.D_scaler.update() # train generator self.GAN.G_opt.zero_grad() for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]): latents = torch.randn(batch_size, latent_dim).cuda(self.rank) with amp_context(): generated_images = G(latents) fake_output, fake_output_32x32, _ = D_aug( generated_images, **aug_kwargs) fake_output_loss = fake_output.mean( dim=1) + fake_output_32x32.mean(dim=1) epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset) k_frac = max(self.generator_top_k_gamma**epochs, self.generator_top_k_frac) k = math.ceil(batch_size * k_frac) if k != batch_size: fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False) loss = fake_output_loss.mean() gen_loss = loss gen_loss = gen_loss / self.gradient_accumulate_every gen_loss.register_hook(raise_if_nan) self.G_scaler.scale(gen_loss).backward() total_gen_loss += loss self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every) self.G_scaler.step(self.GAN.G_opt) self.G_scaler.update() # calculate moving averages if self.is_main and self.steps % 10 == 0 and self.steps > 20000: self.GAN.EMA() if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2: self.GAN.reset_parameter_averaging() # save from NaN errors if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)): print( f'NaN detected for generator or discriminator. 
Loading from checkpoint #{self.checkpoint_num}' ) self.load(self.checkpoint_num) raise NanException del total_disc_loss del total_gen_loss # periodically save results if self.is_main: if self.steps % self.save_every == 0: self.save(self.checkpoint_num) if self.steps % self.evaluate_every == 0 or ( self.steps % 100 == 0 and self.steps < 20000): self.evaluate(floor(self.steps / self.evaluate_every)) if exists( self.calculate_fid_every ) and self.steps % self.calculate_fid_every == 0 and self.steps != 0: num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size) fid = self.calculate_fid(num_batches) self.last_fid = fid with open( str(self.results_dir / self.name / f'fid_scores.txt'), 'a') as f: f.write(f'{self.steps},{fid}\n') self.steps += 1 @torch.no_grad() def evaluate(self, num=0, num_image_tiles=8, trunc=1.0): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size # latents and noise latents = torch.randn((num_rows**2, latent_dim)).cuda(self.rank) # regular generated_images = self.generate_truncated(self.GAN.G, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) # moving averages generated_images = self.generate_truncated(self.GAN.GE, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows) @torch.no_grad() def calculate_fid(self, num_batches): from pytorch_fid import fid_score torch.cuda.empty_cache() real_path = str(self.results_dir / self.name / 'fid_real') + '/' fake_path = str(self.results_dir / self.name / 'fid_fake') + '/' # remove any existing files used for fid calculation and recreate directories rmtree(real_path, ignore_errors=True) rmtree(fake_path, ignore_errors=True) os.makedirs(real_path) os.makedirs(fake_path) for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'): real_batch = next(self.loader) for k in range(real_batch.size(0)): torchvision.utils.save_image( real_batch[k, :, :, :], real_path + '{}.png'.format(k + batch_num * self.batch_size)) # generate a bunch of fake images in results / name / fid_fake self.GAN.eval() ext = self.image_extension latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'): # latents and noise latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank) # moving averages generated_images = self.generate_truncated(self.GAN.GE, latents) for j in range(generated_images.size(0)): torchvision.utils.save_image( generated_images[j, :, :, :], str( Path(fake_path) / f'{str(j + batch_num * self.batch_size)}-ema.{ext}')) return fid_score.calculate_fid_given_paths([real_path, fake_path], 256, True, 2048) @torch.no_grad() def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8): generated_images = evaluate_in_chunks(self.batch_size, G, style) return generated_images.clamp_(0., 1.) 
@torch.no_grad() def generate_interpolation(self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False): self.GAN.eval() ext = self.image_extension num_rows = num_image_tiles latent_dim = self.GAN.latent_dim image_size = self.GAN.image_size # latents and noise latents_low = torch.randn(num_rows**2, latent_dim).cuda(self.rank) latents_high = torch.randn(num_rows**2, latent_dim).cuda(self.rank) ratios = torch.linspace(0., 8., num_steps) frames = [] for ratio in tqdm(ratios): interp_latents = slerp(ratio, latents_low, latents_high) generated_images = self.generate_truncated(self.GAN.GE, interp_latents) images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows) pil_image = transforms.ToPILImage()(images_grid.cpu()) if self.transparent: background = Image.new('RGBA', pil_image.size, (255, 255, 255)) pil_image = Image.alpha_composite(background, pil_image) frames.append(pil_image) frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True) if save_frames: folder_path = (self.results_dir / self.name / f'{str(num)}') folder_path.mkdir(parents=True, exist_ok=True) for ind, frame in enumerate(frames): frame.save(str(folder_path / f'{str(ind)}.{ext}')) def print_log(self): data = [('G', self.g_loss), ('D', self.d_loss), ('GP', self.last_gp_loss), ('SS', self.last_recon_loss), ('FID', self.last_fid)] data = [d for d in data if exists(d[1])] log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data)) print(log) def model_name(self, num): return str(self.models_dir / self.name / f'model_{num}.pt') def init_folders(self): (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) def clear(self): rmtree(str(self.models_dir / self.name), True) rmtree(str(self.results_dir / self.name), True) rmtree(str(self.config_path), True) self.init_folders() def save(self, num): save_data = { 'GAN': self.GAN.state_dict(), 'version': __version__, 'G_scaler': self.G_scaler.state_dict(), 'D_scaler': self.D_scaler.state_dict() } torch.save(save_data, self.model_name(num)) self.write_config() def load(self, num=-1): self.load_config() name = num if num == -1: file_paths = [ p for p in Path(self.models_dir / self.name).glob('model_*.pt') ] saved_nums = sorted( map(lambda x: int(x.stem.split('_')[1]), file_paths)) if len(saved_nums) == 0: return name = saved_nums[-1] print(f'continuing from previous epoch - {name}') self.steps = name * self.save_every load_data = torch.load(self.model_name(name)) if 'version' in load_data and self.is_main: print(f"loading from version {load_data['version']}") try: self.GAN.load_state_dict(load_data['GAN']) except Exception as e: print( 'unable to load save model. please try downgrading the package to the version specified by the saved model' ) raise e if 'G_scaler' in load_data: self.G_scaler.load_state_dict(load_data['G_scaler']) if 'D_scaler' in load_data: self.D_scaler.load_state_dict(load_data['D_scaler'])
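The gradient-penalty block in train() above relies on a non-obvious AMP trick: torch.autograd.grad is taken through scaler-scaled outputs, and the result is multiplied by the inverse scale to recover correctly sized gradients. A minimal sketch of that trick in isolation; the toy discriminator and penalty weight are stand-ins:

import torch
from torch.autograd import grad as torch_grad
from torch.cuda.amp import GradScaler, autocast

D = torch.nn.Linear(32, 1).cuda()  # stand-in discriminator
scaler = GradScaler(enabled=True)

images = torch.randn(4, 32, device='cuda', requires_grad=True)
with autocast():
    out = D(images)

# Scale the outputs first so intermediate grads stay representable in
# FP16, then undo the scaling on the result.
scaled = scaler.scale(out)
scaled_grads = torch_grad(outputs=scaled,
                          inputs=images,
                          grad_outputs=torch.ones_like(scaled),
                          create_graph=True)[0]
grads = scaled_grads * (1.0 / scaler.get_scale())
gp = 10.0 * ((grads.norm(2, dim=1) - 1) ** 2).mean()  # 10.0 = gp_weight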
from typing import Optional, Tuple

def load_checkpoint(
    path: str, device: torch.device
) -> Tuple[
    TEDD1104,
    str,
    torch.optim.Optimizer,
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    float,
    int,
    int,
    float,
    int,
    bool,
    Optional[GradScaler],
]:
    """
    Restore checkpoint

    Input:
    - path: path of the checkpoint to restore
    - device: device to load the model onto

    Output:
    - model: restored TEDD1104 model
    - optimizer_name: name of the optimizer used for training: SGD or Adam
    - optimizer: optimizer used for training
    - scheduler: learning rate scheduler used for training
    - running_loss: running loss at the time the checkpoint was saved
    - total_batches: number of batches processed so far
    - total_training_examples: number of training examples seen so far
    - acc_dev: accuracy of the model on the development set
    - epoch: number of epochs used to train the model
    - fp16: True if the model uses fp16, else False
    - scaler: if the model uses FP16, the scaler used for training, else None
    """
    # map_location ensures tensors land on the requested device even if the
    # checkpoint was saved from a different one
    checkpoint = torch.load(path, map_location=device)

    dict_hyperparams = checkpoint["hyper_params"]
    model_weights = checkpoint["model"]
    optimizer_name = checkpoint["optimizer_name"]
    optimizer_state = checkpoint["optimizer"]
    acc_dev = checkpoint["acc_dev"]
    epoch = checkpoint["epoch"]
    scaler_state = checkpoint["scaler"]
    fp16 = dict_hyperparams["fp16"]

    model: TEDD1104 = TEDD1104(
        resnet=dict_hyperparams["resnet"],
        pretrained_resnet=dict_hyperparams["pretrained_resnet"],
        sequence_size=dict_hyperparams["sequence_size"],
        embedded_size=dict_hyperparams["embedded_size"],
        hidden_size=dict_hyperparams["hidden_size"],
        num_layers_lstm=dict_hyperparams["num_layers_lstm"],
        bidirectional_lstm=dict_hyperparams["bidirectional_lstm"],
        layers_out=dict_hyperparams["layers_out"],
        dropout_cnn=dict_hyperparams["dropout_cnn"],
        dropout_cnn_out=dict_hyperparams["dropout_cnn_out"],
        dropout_lstm=dict_hyperparams["dropout_lstm"],
        dropout_lstm_out=dict_hyperparams["dropout_lstm_out"],
    ).to(device=device)

    if optimizer_name == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif optimizer_name == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    else:
        raise ValueError(
            f"The optimizer you are trying to load is unknown: "
            f"Optimizer name {optimizer_name}. Available optimizers: SGD, Adam"
        )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True)

    model.load_state_dict(model_weights)
    optimizer.load_state_dict(optimizer_state)

    try:
        scheduler_state = checkpoint["scheduler"]
        scheduler.load_state_dict(scheduler_state)
    except KeyError:
        logging.warning("Legacy checkpoint, a new scheduler will be created")

    try:
        running_loss = checkpoint["running_loss"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, running loss will be initialized with 0.0 value"
        )
        running_loss = 0.0

    try:
        total_training_examples = checkpoint["total_training_examples"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, total training examples will be initialized with 0 value"
        )
        total_training_examples = 0

    try:
        total_batches = checkpoint["total_batches"]
    except KeyError:
        logging.warning(
            "Legacy checkpoint, total batches will be initialized with 0 value"
        )
        total_batches = 0

    scaler: Optional[GradScaler]
    if fp16:
        scaler = GradScaler()
        scaler.load_state_dict(scaler_state)
    else:
        scaler = None

    return (
        model,
        optimizer_name,
        optimizer,
        scheduler,
        running_loss,
        total_batches,
        total_training_examples,
        acc_dev,
        epoch,
        fp16,
        scaler,
    )
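load_checkpoint reads more keys than its docstring originally listed, so the save-side contract is worth making explicit. A hypothetical counterpart (not part of the source; names mirror the keys read above):

import torch

def save_checkpoint(path, model, optimizer_name, optimizer, scheduler,
                    acc_dev, epoch, hyper_params, scaler=None,
                    running_loss=0.0, total_training_examples=0,
                    total_batches=0):
    # Every key below is read back by load_checkpoint; omitting one turns
    # the file into a "legacy checkpoint" on the load side.
    torch.save(
        {
            "hyper_params": hyper_params,      # must include "fp16"
            "model": model.state_dict(),
            "optimizer_name": optimizer_name,  # "SGD" or "Adam"
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "acc_dev": acc_dev,
            "epoch": epoch,
            "running_loss": running_loss,
            "total_training_examples": total_training_examples,
            "total_batches": total_batches,
            # GradScaler carries state (current scale, growth tracker);
            # without it, resumed AMP training restarts from the default scale.
            "scaler": scaler.state_dict() if scaler is not None else None,
        },
        path,
    )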