class AMPGradAccumulateOptimizerHook(OptimizerHook):
    """Optimizer hook that pairs a ``GradScaler`` with gradient accumulation.

    The scaled loss is back-propagated on every iteration; the optimizer is
    only stepped (through the scaler) on every ``accumulation``-th iteration.
    """

    def __init__(self, *wargs, **kwargs):
        # ``accumulation`` is consumed here; all remaining arguments are
        # forwarded untouched to the parent hook.
        self.accumulation = kwargs.pop('accumulation', 1)
        self.scaler = GradScaler()
        super().__init__(*wargs, **kwargs)

    def before_run(self, runner):
        """Check that the wrapped model advertises AMP and clear gradients."""
        wrapped_model = runner.model.module
        amp_enabled = getattr(wrapped_model, 'use_amp', False)
        assert amp_enabled, 'model should support AMP when using this optimizer hook!'
        runner.model.zero_grad()
        runner.optimizer.zero_grad()

    def before_train_iter(self, runner):
        """Clear gradients at the start of each accumulation window."""
        if runner.iter % self.accumulation != 0:
            return
        runner.model.zero_grad()
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        """Back-propagate the scaled loss; step on the window's last iteration."""
        self.scaler.scale(runner.outputs['loss']).backward()

        closes_window = (runner.iter + 1) % self.accumulation == 0
        if not closes_window:
            return

        current_scale = self.scaler.get_scale()
        if self.grad_clip is not None:
            # Gradients have to be unscaled before their norm is clipped.
            self.scaler.unscale_(runner.optimizer)
            clipped_norm = self.clip_grads(runner.model.parameters())
            if clipped_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update(
                    {'grad_norm': float(clipped_norm)},
                    runner.outputs['num_samples'])
        # NOTE(review): logged on every optimizer step, independent of
        # clipping -- confirm this matches the intended nesting.
        runner.log_buffer.update(
            {'grad_scale': float(current_scale)},
            runner.outputs['num_samples'])
        self.scaler.step(runner.optimizer)
        self.scaler.update()
class Trainer():
    """Drives training, evaluation and checkpointing of a ``LightweightGAN``.

    The trainer owns the output folders, the cycled data loader, one
    ``GradScaler`` per network (generator / discriminator) and the step
    counter that decides when to save, evaluate and compute FID.
    """

    def __init__(self,
                 name='default',
                 results_dir='results',
                 models_dir='models',
                 base_dir='./',
                 optimizer="adam",
                 latent_dim=256,
                 image_size=128,
                 fmap_max=512,
                 transparent=False,
                 greyscale=False,
                 batch_size=4,
                 gp_weight=10,
                 gradient_accumulate_every=1,
                 attn_res_layers=[],
                 disc_output_size=5,
                 antialias=False,
                 lr=2e-4,
                 lr_mlp=1.,
                 ttur_mult=1.,
                 save_every=1000,
                 evaluate_every=1000,
                 trunc_psi=0.6,
                 aug_prob=None,
                 aug_types=['translation', 'cutout'],
                 dataset_aug_prob=0.,
                 calculate_fid_every=None,
                 is_ddp=False,
                 rank=0,
                 world_size=1,
                 log=False,
                 amp=False,
                 *args,
                 **kwargs):
        """Store hyper-parameters and create the output folders.

        The GAN itself is built lazily by ``init_GAN``; extra positional and
        keyword arguments are kept and forwarded to ``LightweightGAN`` there.
        ``lr_mlp``, ``trunc_psi`` and ``log`` are accepted but not stored.

        Raises:
            AssertionError: if ``image_size`` or any attention resolution is
                not a power of two, or if both ``transparent`` and
                ``greyscale`` are set.
        """
        # Copy the list arguments so instances never alias the shared
        # default list objects (or a list the caller mutates later).
        attn_res_layers = list(attn_res_layers)
        aug_types = list(aug_types)

        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(
            image_size
        ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)
                ) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        # Top-k schedule for the generator loss (see ``train``).
        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        # Most recent values, reported by ``print_log``.
        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = GradScaler(enabled=self.amp)
        self.D_scaler = GradScaler(enabled=self.amp)

    @property
    def image_extension(self):
        """File extension for saved samples: ``png`` when transparent, else ``jpg``."""
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        """Index of the checkpoint that corresponds to the current step."""
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        """Instantiate ``LightweightGAN`` (and the DDP wrappers when ``is_ddp``)."""
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN
        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu
        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            dist.init_process_group('nccl', rank=0, world_size=1)

        # instantiate GAN
        self.GAN = LightweightGAN(optimizer=self.optimizer,
                                  lr=self.lr,
                                  latent_dim=self.latent_dim,
                                  attn_res_layers=self.attn_res_layers,
                                  image_size=self.image_size,
                                  ttur_mult=self.ttur_mult,
                                  fmap_max=self.fmap_max,
                                  disc_output_size=self.disc_output_size,
                                  transparent=self.transparent,
                                  greyscale=self.greyscale,
                                  rank=self.rank,
                                  *args,
                                  **kwargs)

        if self.is_ddp:
            ddp_kwargs = {
                'device_ids': [self.rank],
                'output_device': self.rank,
                'find_unused_parameters': True
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        """Serialize ``config()`` as JSON to ``config_path``."""
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        """Restore architecture settings from disk (if present) and rebuild the GAN.

        When no config file exists the current in-memory settings are used.
        Keys missing from an older config file fall back to a default;
        ``fmap_max`` falls back to the value this trainer was constructed
        with instead of silently resetting to 512.
        """
        if self.config_path.exists():
            config = json.loads(self.config_path.read_text())
        else:
            config = self.config()

        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', self.fmap_max)
        del self.GAN
        self.init_GAN()

    def config(self):
        """Return the architecture settings that must match a saved model."""
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers,
            # Persisted so that ``load_config`` can restore it.
            'fmap_max': self.fmap_max
        }

    def set_data_src(self, folder):
        """Build the dataset and an endlessly cycling loader over ``folder``.

        If ``aug_prob`` was not given and the dataset has fewer than 1e5
        samples, an augmentation probability is chosen automatically.
        """
        self.dataset = ImageDataset(folder,
                                    self.image_size,
                                    transparent=self.transparent,
                                    greyscale=self.greyscale,
                                    aug_prob=self.dataset_aug_prob)
        sampler = DistributedSampler(self.dataset,
                                     rank=self.rank,
                                     num_replicas=self.world_size,
                                     shuffle=True) if self.is_ddp else None
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True)
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f'autosetting augmentation probability to {round(self.aug_prob * 100)}%'
            )

    def train(self):
        """Run one training step: a discriminator update, then a generator update.

        Afterwards the method handles EMA bookkeeping, NaN recovery and the
        periodic save / evaluate / FID work, and increments ``self.steps``.

        Raises:
            AssertionError: if ``set_data_src`` has not been called.
            NanException: if either accumulated loss is NaN (after the last
                checkpoint has been reloaded).
        """
        assert exists(
            self.loader
        ), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        # Per-process batch size.
        batch_size = math.ceil(self.batch_size / self.world_size)

        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        # The gradient penalty is only applied on every fourth step.
        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions
        amp_context = autocast if self.amp else null_context

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for _ in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[D_aug, G]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            # Needed so the penalty can differentiate w.r.t. the real images.
            image_batch.requires_grad_()

            with amp_context():
                with torch.no_grad():
                    generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images,
                                                          detach=True,
                                                          **aug_kwargs)

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs)

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32,
                                              fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                # Under AMP the outputs are scaled before differentiation and
                # the resulting gradients are unscaled by hand below.
                outputs = list(map(self.D_scaler.scale,
                                   outputs)) if self.amp else outputs

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(),
                                                 device=image_batch.device),
                            outputs)),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]

                inv_scale = (1. /
                             self.D_scaler.get_scale()) if self.amp else 1.
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * (
                        (gradients.norm(2, dim=1) - 1)**2).mean()

                    # A NaN penalty is dropped instead of poisoning the loss.
                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.D_scaler.scale(disc_loss).backward()
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() /
                            self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # train generator
        self.GAN.G_opt.zero_grad()

        for _ in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[G, D_aug]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs)
                fake_output_loss = fake_output.mean(
                    dim=1) + fake_output_32x32.mean(dim=1)

                # Top-k: the kept fraction decays with the number of epochs
                # seen, but never below ``generator_top_k_frac``.
                epochs = (self.steps * batch_size *
                          self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma**epochs,
                             self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k,
                                                                largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.G_scaler.scale(gen_loss).backward()
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() /
                            self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages
        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors
        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}'
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results
        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                    self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(
                    self.calculate_fid_every
            ) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(str(self.results_dir / self.name / 'fid_scores.txt'),
                          'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        """Save a sample grid from the generator and one from its EMA copy.

        Both grids use the same latents and hold ``num_image_tiles ** 2``
        images; they are written as ``<num>.<ext>`` and ``<num>-ema.<ext>``
        under the results folder. ``trunc`` is accepted but not used.
        """
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # latents and noise
        latents = torch.randn((num_rows**2, latent_dim)).cuda(self.rank)

        # regular
        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}.{ext}'),
                                     nrow=num_rows)

        # moving averages
        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}-ema.{ext}'),
                                     nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        """Dump real and generated images to disk and return their FID.

        ``num_batches`` batches of real images come from the loader and the
        same number of batches are generated with the EMA generator.
        """
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / 'fid_real') + '/'
        fake_path = str(self.results_dir / self.name / 'fid_fake') + '/'

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving reals'):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :], real_path +
                    '{}.png'.format(k + batch_num * self.batch_size))

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving generated'):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path) /
                        f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([real_path, fake_path],
                                                   256, True, 2048)

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        """Run ``G`` on ``style`` in chunks of ``batch_size``; clamp to [0, 1].

        ``trunc_psi`` and ``num_image_tiles`` are accepted but not used.
        """
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self,
                               num=0,
                               num_image_tiles=8,
                               trunc=1.0,
                               num_steps=100,
                               save_frames=False):
        """Render an animated GIF interpolating between two random latent sets.

        The GIF is written to ``<results>/<name>/<num>.gif``; with
        ``save_frames`` each frame is also saved into a ``<num>`` folder.
        ``trunc`` is accepted but not used.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # latents and noise
        latents_low = torch.randn(num_rows**2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows**2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE,
                                                       interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images,
                                                      nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                # Composite onto a white background.
                background = Image.new('RGBA', pil_image.size,
                                       (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'),
                       save_all=True,
                       append_images=frames[1:],
                       duration=80,
                       loop=0,
                       optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        """Print the most recent loss / FID values that are available."""
        data = [('G', self.g_loss), ('D', self.d_loss),
                ('GP', self.last_gp_loss), ('SS', self.last_recon_loss),
                ('FID', self.last_fid)]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

    def model_name(self, num):
        """Return the checkpoint path for checkpoint index ``num``."""
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        """Create the per-run results and models directories if missing."""
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        """Delete this run's models and results, then recreate empty folders."""
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        # The config file lives inside the models folder removed above;
        # errors from this call are ignored.
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        """Write checkpoint ``num`` (GAN weights, version, scaler states) and the config."""
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        """Load checkpoint ``num``, or the latest one when ``num`` is ``-1``.

        The config is reloaded (rebuilding the GAN) first. When ``num`` is
        ``-1`` and no checkpoint exists the method returns without loading.
        """
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p
                for p in Path(self.models_dir / self.name).glob('model_*.pt')
            ]
            saved_nums = sorted(
                map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        # Resume the step counter from the checkpoint index.
        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception:
            print(
                'unable to load save model. please try downgrading the package to the version specified by the saved model'
            )
            raise

        # Scaler states are absent from checkpoints saved without them.
        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])
def train_epoch(self, model: Reader, optimizer: torch.optim.Optimizer,
                scaler: GradScaler, train: DataLoader, val: DataLoader,
                scheduler: torch.optim.lr_scheduler.LambdaLR) -> float:
    """
    Performs one training epoch.

    Batches that fail with a CUDA out-of-memory error are skipped; any
    partially accumulated gradients are cleared before moving on.

    :param model: The model you are training.
    :type model: Reader
    :param optimizer: Use this optimizer for training.
    :type optimizer: torch.optim.Optimizer
    :param scaler: Scaler for gradients when the mixed precision is used.
    :type scaler: GradScaler
    :param train: The train dataset loader.
    :type train: DataLoader
    :param val: The validation dataset loader.
    :type val: DataLoader
    :param scheduler: Learning rate scheduler. May be None.
    :type scheduler: torch.optim.lr_scheduler.LambdaLR
    :return: Best achieved exact match among validations.
    :rtype: float
    :raise RuntimeError: Any runtime error other than CUDA out of memory.
    """

    model.train()

    loss_sum = 0
    samples = 0  # number of batches that contributed to loss_sum
    startTime = time.time()

    total_tokens = 0
    optimizer.zero_grad()

    # when resuming, start the progress bar at the stored position (once)
    initStep = 0
    if self.resumeSkip is not None:
        initStep = self.resumeSkip
        self.resumeSkip = None

    iterator = tqdm(enumerate(train), total=len(train), initial=initStep)

    bestExactMatch = 0.0

    for current_it, batch in iterator:
        batch: ReaderBatch
        lastScale = scaler.get_scale()
        self.n_iter += 1

        batchOnDevice = batch.to(self.device)

        try:
            with torch.cuda.amp.autocast(
                    enabled=self.config["mixed_precision"]):
                startScores, endScores, jointScore, selectionScore = self._useModel(
                    model, batchOnDevice)

                # according to the config we can get following loss combinations
                #   join components
                #   independent components
                #   join components with HardEM
                #   independent components with HardEM

                logSpanProb = None
                if not self.config["independent_components_in_loss"]:
                    # joined components in loss
                    logSpanProb = Reader.scores2logSpanProb(
                        startScores, endScores, jointScore, selectionScore)

                # User may want to use hardEMLoss with certain probability.
                # In the original article it is not written clearly and it seems like it is the other way around.
                # After I had consulted it with authors the idea became clear.

                if self.config["hard_em_steps"] > 0 and \
                        random.random() <= min(self.update_it / self.config["hard_em_steps"],
                                               self.config["max_hard_em_prob"]):
                    # loss is calculated for the max answer span with max probability
                    if self.config["independent_components_in_loss"]:
                        loss = Reader.hardEMIndependentComponentsLoss(
                            startScores, endScores, jointScore,
                            selectionScore, batchOnDevice.answersMask)
                    else:
                        loss = Reader.hardEMLoss(logSpanProb,
                                                 batchOnDevice.answersMask)
                else:
                    # loss is calculated for all answer spans
                    if self.config["independent_components_in_loss"]:
                        loss = Reader.marginalCompoundLossWithIndependentComponents(
                            startScores, endScores, jointScore,
                            selectionScore, batchOnDevice.answersMask)
                    else:
                        loss = Reader.marginalCompoundLoss(
                            logSpanProb, batchOnDevice.answersMask)

                if self.config[
                        "use_auxiliary_loss"] and batch.isGroundTruth:
                    # we must be sure that user wants it and that the true passage is ground truth
                    loss += Reader.auxiliarySelectedLoss(selectionScore)
                loss_sum += loss.item()
                # counted together with loss_sum so that skipped batches do
                # not dilute the reported average loss
                samples += 1
            scaler.scale(loss).backward()

        # Catch out-of-memory errors
        except RuntimeError as e:
            if "CUDA out of memory." in str(e):
                torch.cuda.empty_cache()
                logging.error(e)
                tb = traceback.format_exc()
                logging.error(tb)
                # an interrupted backward pass may have left partial
                # gradients; they must not leak into the next update
                optimizer.zero_grad()
                continue
            else:
                raise

        # update parameters

        # gradients must be unscaled before clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            filter(lambda p: p.requires_grad, model.parameters()),
            self.config["max_grad_norm"])

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        self.update_it += 1

        if scheduler is not None and scaler.get_scale() >= lastScale:
            # we should not perform scheduler step when the optimizer step was omitted; the scaler
            # lowers the scale in that case, whereas an unchanged or grown scale means
            # the step was performed
            scheduler.step()

        if self.update_it % self.config["validate_after_steps"] == 0:
            valLoss, exactMatch, passageMatch, samplesWithLoss = self.validate(
                model, val)

            logging.info(
                f"Steps:{self.update_it}, Training loss: {loss_sum / samples:.5f}, Validation loss: {valLoss} (samples with loss {samplesWithLoss} [{samplesWithLoss / len(val):.1%}]), Exact match: {exactMatch:.5f}, Passage match: {passageMatch:.5f}"
            )

            bestExactMatch = max(exactMatch, bestExactMatch)
            if self.update_it > self.config["first_save_after_updates_K"]:
                checkpoint = Checkpoint(
                    model.module if isinstance(model, DataParallel) else
                    model, optimizer, scheduler, train.sampler.actPerm,
                    current_it + 1, self.config, self.update_it)
                checkpoint.save(f"{self.config['save_dir']}/Reader_train"
                                f"_{get_timestamp()}"
                                f"_{socket.gethostname()}"
                                f"_{valLoss}"
                                f"_S_{self.update_it}"
                                f"_E_{current_it}.pt")

            # switch back to training mode after validation
            model.train()

        # statistics & logging
        total_tokens += batch.inputSequences.numel()
        if (self.n_iter + 1) % 50 == 0 or current_it == len(iterator) - 1:
            iterator.set_description(
                f"Steps: {self.update_it} Tokens/s: {total_tokens / (time.time() - startTime)}, Training loss: {loss_sum / samples}"
            )

        if self.config["max_steps"] <= self.update_it:
            break

    # no batch may have contributed (empty loader or every batch skipped)
    epochLoss = loss_sum / samples if samples > 0 else float("nan")
    logging.info(
        f"End of epoch training loss: {epochLoss:.5f}, best validation exact match: {bestExactMatch}"
    )
    return bestExactMatch
class Solver(object): def __init__(self, config): self.model = None self.args = config self.criterion = None self.optimizer = None self.scheduler = None self.device = None self.cuda = config.cuda self.train_loader = None self.test_loader = None self.infer_loader = None self.es = EarlyStopping(patience=self.args.es_patience) self.scaler = GradScaler(enabled=self.args.half) if not self.args.save_dir: self.writer = SummaryWriter() else: self.writer = SummaryWriter(log_dir="runs/" + self.args.save_dir) self.train_batch_plot_idx = 0 self.test_batch_plot_idx = 0 def load_data(self): if self.args.dataset.name not in datasets: print( f"This dataset is not implemented ({self.args.dataset.name}), go ahead and commit it" ) exit() train_cache_index = 0 train_data_transformations = [] for idx, transformation in enumerate( self.args.transformations.train.data): if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() if hasattr(transformation, 'cache_point'): train_cache_index = idx + 1 train_data_transformations.append( transformations[transformation.name]( **transformation.parameters)) train_target_transformations = [] for transformation in self.args.transformations.train.target: if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() train_target_transformations.append( transformations[transformation.name]( **transformation.parameters)) train_both_transformations = [] for transformation in self.args.transformations.train.both: if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() train_both_transformations.append( transformations[transformation.name]( **transformation.parameters)) train_output_transformations = [] for transformation in self.args.transformations.train.output: if 
transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() train_output_transformations.append( transformations[transformation.name]( **transformation.parameters)) train_data_transform = transforms.Compose( train_data_transformations ) if len(train_data_transformations) > 0 else None train_target_transform = transforms.Compose( train_target_transformations ) if len(train_target_transformations) > 0 else None train_both_transform = transforms.Compose( train_both_transformations ) if len(train_both_transformations) > 0 else None self.train_output_transform = transforms.Compose( train_output_transformations ) if len(train_output_transformations) > 0 else None test_cache_index = 0 test_data_transformations = [] for idx, transformation in enumerate( self.args.transformations.test.data): if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() if hasattr(transformation, 'cache_point'): test_cache_index = idx + 1 test_data_transformations.append( transformations[transformation.name]( **transformation.parameters)) test_target_transformations = [] for transformation in self.args.transformations.test.target: if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() test_target_transformations.append( transformations[transformation.name]( **transformation.parameters)) test_both_transformations = [] for transformation in self.args.transformations.test.both: if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() test_both_transformations.append( transformations[transformation.name]( **transformation.parameters)) test_output_transformations = [] for transformation in self.args.transformations.test.output: 
if transformation.name not in transformations: print( f"This transformation is not implemented ({transformation.name}), go ahead and commit it" ) exit() test_output_transformations.append( transformations[transformation.name]( **transformation.parameters)) test_data_transform = transforms.Compose( test_data_transformations ) if len(test_data_transformations) > 0 else None test_target_transform = transforms.Compose( test_target_transformations ) if len(test_target_transformations) > 0 else None test_both_transform = transforms.Compose( test_both_transformations ) if len(test_both_transformations) > 0 else None self.test_output_transform = transforms.Compose( test_output_transformations ) if len(test_output_transformations) > 0 else None parameters = OmegaConf.to_container( self.args.dataset.train_loader_params, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} if self.args.dataset.name in ['CIFAR-10', 'CIFAR-100', 'ImageNet2012']: parameters["transform"] = train_data_transform parameters["target_transform"] = train_target_transform else: parameters["data_transform"] = train_data_transform parameters["target_transform"] = train_target_transform parameters["both_transform"] = train_both_transform parameters['cache_index'] = train_cache_index self.train_set = datasets[self.args.dataset.name](**parameters) parameters = OmegaConf.to_container( self.args.dataset.test_loader_params, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} if self.args.dataset.name in ['CIFAR-10', 'CIFAR-100', 'ImageNet2012']: parameters["transform"] = test_data_transform parameters["target_transform"] = test_target_transform else: parameters["data_transform"] = test_data_transform parameters["target_transform"] = test_target_transform parameters["both_transform"] = test_both_transform parameters['cache_index'] = test_cache_index self.test_set = datasets[self.args.dataset.name](**parameters) if hasattr(self.args.dataset, 'mixup_args') 
and self.args.dataset.mixup_args != None: collate_fn = FastCollateMixup(**self.args.dataset.mixup_args) else: collate_fn = None self.train_loader = torch.utils.data.DataLoader( dataset=self.train_set, batch_size=self.args.dataset.train_batch_size, shuffle=self.args.dataset.shuffle, num_workers=self.args.dataset.num_workers_train, collate_fn=collate_fn, drop_last=True, persistent_workers=self.args.dataset.num_workers_train > 0) self.test_loader = torch.utils.data.DataLoader( dataset=self.test_set, batch_size=self.args.dataset.test_batch_size, shuffle=False, num_workers=self.args.dataset.num_workers_test, persistent_workers=self.args.dataset.num_workers_test > 0) if self.args.infer_only is True: parameters = OmegaConf.to_container( self.args.dataset.infer_loader_params, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} parameters["data_transform"] = test_data_transform parameters["target_transform"] = test_target_transform parameters["both_transform"] = test_both_transform self.infer_set = datasets[self.args.dataset.name](**parameters) self.infer_loader = torch.utils.data.DataLoader( dataset=self.infer_set, batch_size=self.args.dataset.test_batch_size, shuffle=False, num_workers=self.args.dataset.num_workers_test, persistent_workers=self.args.dataset.num_workers_test > 0) def init_model(self): if self.cuda: self.device = torch.device('cuda' + ":" + str(self.args.cuda_device)) cudnn.benchmark = True # The flag below controls whether to allow TF32 on matmul. This flag defaults to True. torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
torch.backends.cudnn.allow_tf32 = True else: self.device = torch.device('cpu') parameters = OmegaConf.to_container(self.args.model.parameters, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} try: self.model = getattr(models, self.args.model.name) except: print( f"This model is not implemented ({self.args.model.name}), go ahead and commit it" ) exit() self.model = self.model(**parameters) self.save_dir = os.path.join(self.args.storage_dir, "model_weights", self.args.save_dir) if not os.path.isdir(self.save_dir): os.makedirs(self.save_dir) if self.args.initialization == 1: # xavier init for m in self.model.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu')) elif self.args.initialization == 2: # he initialization for m in self.model.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal(m.weight, mode='fan_in') elif self.args.initialization == 3: # selu init for m in self.model.modules(): if isinstance(m, nn.Conv2d): fan_in = m.kernel_size[0] * \ m.kernel_size[1] * m.in_channels nn.init.normal(m.weight, 0, torch.sqrt(1. / fan_in)) elif isinstance(m, nn.Linear): fan_in = m.in_features nn.init.normal(m.weight, 0, torch.sqrt(1. 
/ fan_in)) elif self.args.initialization == 4: # orthogonal initialization for m in self.model.modules(): if isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.orthogonal_(m.weight) if self.args.initialization_batch_norm: # batch norm initialization for m in self.model.modules(): if isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) if len(self.args.load_model) > 0: print("Loading model from " + self.args.load_model) self.model.load_state_dict(torch.load(self.args.load_model)) # for param in self.model.parameters(): # param.requires_grad = True # for param in self.model.patch_embed.parameters(): # param.requires_grad = True # for param in self.model.norm.parameters(): # param.requires_grad = True # for param in self.model.avgpool.parameters(): # param.requires_grad = True # for param in self.model.head.parameters(): # param.requires_grad = True self.model = self.model.to(self.device) def init_optimizer(self): parameters = OmegaConf.to_container(self.args.optimizer.parameters, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} parameters["params"] = self.model.parameters() try: self.optimizer = getattr(torch_optimizer, self.args.optimizer.name) except Exception as e: try: self.optimizer = getattr(optim, self.args.optimizer.name) except: print( f"This optimizer is not implemented ({self.args.optimizer.name}), go ahead and commit it" ) exit() self.optimizer = self.optimizer(**parameters) if self.args.optimizer.use_SAM: self.optimizer = optimizers['SAM'](base_optimizer=self.optimizer, rho=self.args.optimizer.SAM_rho) if self.args.optimizer.use_lookahead: self.optimizer = torch_optimizer.Lookahead( self.optimizer, k=self.args.optimizer.lookahead_k, alpha=self.args.optimizer.lookahead_alpha) def init_scheduler(self): if self.args.scheduler.name not in schedulers: print( f"This loss is not implemented ({self.args.scheduler.name}), go ahead and commit it" ) exit() parameters = 
OmegaConf.to_container(self.args.scheduler.parameters, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} parameters["optimizer"] = self.optimizer self.scheduler = schedulers[self.args.scheduler.name](**parameters) def init_criterion(self): if self.args.loss.name not in losses: print( f"This loss is not implemented ({self.args.loss.name}), go ahead and commit it" ) exit() parameters = OmegaConf.to_container(self.args.loss.parameters, resolve=True) parameters = {k: v for k, v in parameters.items() if v is not None} self.criterion = losses[self.args.loss.name]['constructor']( **parameters) def init_metrics(self): self.metrics = { 'train': { 'batch': [], 'epoch': [] }, 'test': { 'batch': [], 'epoch': [] }, 'solver': { 'batch': [], 'epoch': [] }, } for metric in self.args.metrics.train: if metric.name not in metrics: print( f"This metric is not implemented ({metric.name}), go ahead and commit it" ) exit() metric_func = metrics[metric.name]['constructor']( **metric.parameters) metric_object = Metric(metric.name, metric_func, solver_metric=False, aggregator=metric.aggregator) for level in metric.levels: self.metrics['train'][level].append(metric_object) for metric in self.args.metrics.test: if metric.name not in metrics: print( f"This metric is not implemented ({metric.name}), go ahead and commit it" ) exit() metric_func = metrics[metric.name]['constructor']( **metric.parameters) metric_object = Metric(metric.name, metric_func, solver_metric=False, aggregator=metric.aggregator) for level in metric.levels: self.metrics['test'][level].append(metric_object) for metric in self.args.metrics.solver: if metric.name not in metrics: print( f"This metric is not implemented ({metric.name}), go ahead and commit it" ) exit() metric_func = metrics[metric.name]['constructor']( **metric.parameters) metric_object = Metric(metric.name, metric_func, solver_metric=True, aggregator=metric.aggregator) for level in metric.levels: 
self.metrics['solver'][level].append(metric_object) def disable_bn(self): for module in self.model.modules(): if isinstance(module, nn.modules.batchnorm._NormBase) or isinstance( module, nn.LayerNorm): module.eval() def enable_bn(self): self.model.train() def train(self): print("train:") self.model.train() total_loss = 0 correct = 0 total = 0 accumulation_data = [] accumulation_target = [] predictions = [] targets = [] for batch_num, (data, target) in enumerate(self.train_loader): if isinstance(data, list): data = [i.to(self.device) for i in data] else: data = data.to(self.device) if isinstance(target, list): target = [i.to(self.device) for i in target] else: target = target.to(self.device) if self.args.optimizer.use_SAM: accumulation_data.append(data) accumulation_target.append(target) while True: with autocast(enabled=self.args.half): output = self.model(data) if self.train_output_transform is not None: output = self.train_output_transform(output) loss = self.criterion(output, target) loss = loss / self.args.dataset.update_every if self.args.optimizer.grad_penalty is not None and self.args.optimizer.grad_penalty > 0.0: # Creates gradients scaled_grad_params = torch.autograd.grad( outputs=self.scaler.scale(loss), inputs=self.model.parameters(), create_graph=True) #Creates unscaled grad_params before computing the penalty. scaled_grad_params are # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_: inv_scale = 1. 
/ self.scaler.get_scale() grad_params = [p * inv_scale for p in scaled_grad_params] # Computes the penalty term and adds it to the loss with autocast(): grad_norm = 0 for grad in grad_params: grad_norm += grad.pow(2).sum() grad_norm = grad_norm.sqrt() loss = loss + (grad_norm * self.args.optimizer.grad_penalty) self.scaler.scale(loss).backward() def sam_closure(): self.disable_bn() for i in range(len(accumulation_data)): with autocast(enabled=self.args.half): output = self.model(accumulation_data[i]) if self.train_output_transform is not None: output = self.train_output_transform(output) loss = self.criterion(output, accumulation_target[i]) loss = loss / self.args.dataset.update_every if self.args.optimizer.grad_penalty is not None and self.args.optimizer.grad_penalty is not False and self.args.optimizer.grad_penalty > 0.0: # Creates gradients scaled_grad_params = torch.autograd.grad( outputs=self.scaler.scale(loss), inputs=self.model.parameters(), create_graph=True) #Creates unscaled grad_params before computing the penalty. scaled_grad_params are # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_: inv_scale = 1. 
/ self.scaler.get_scale() grad_params = [ p * inv_scale for p in scaled_grad_params ] # Computes the penalty term and adds it to the loss with autocast(): grad_norm = 0 for grad in grad_params: grad_norm += grad.pow(2).sum() grad_norm = grad_norm.sqrt() loss = loss + ( grad_norm * self.args.optimizer.grad_penalty) self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.optimizer.max_norm) self.enable_bn() if self.args.optimizer.batch_replay: found_inf = False for _, param in self.model.named_parameters(): if param.grad.isnan().any() or param.grad.isinf().any( ): found_inf = True break if found_inf: self.scaler.update() self.optimizer.zero_grad() if type(self.args.optimizer.batch_replay ) == int or type( self.args.optimizer.batch_replay) == float: self.args.optimizer.batch_replay -= 1 else: break else: break if self.train_batch_plot_idx % self.args.dataset.update_every == 0: self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.optimizer.max_norm) self.scaler.step(self.optimizer, closure=sam_closure if self.args.optimizer.use_SAM else None) self.scaler.update() self.optimizer.zero_grad() accumulation_data = [] accumulation_target = [] predictions.extend(output) targets.extend(target) metrics_results = {} for metric in self.metrics['train']['batch']: metrics_results["Train/Batch-" + metric.name] = metric.calculate(output, target, level='batch') for metric in self.metrics['solver']['batch']: metrics_results["Solver/Batch-" + metric.name] = metric.calculate(solver=self, level='batch') print_metrics(self.writer, metrics_results, self.get_train_batch_plot_idx()) if self.args.progress_bar: progress_bar(batch_num, len(self.train_loader)) if self.args.scheduler.name == "OneCycleLR": self.scheduler.step() return torch.stack(predictions), torch.stack(targets) def test(self): print("test:") self.model.eval() total_loss = 0 correct = 0 total = 0 
predictions = [] targets = [] with torch.no_grad(): for batch_num, (data, target) in enumerate(self.test_loader): if isinstance(data, list): data = [i.to(self.device) for i in data] else: data = data.to(self.device) if isinstance(target, list): target = [i.to(self.device) for i in target] else: target = target.to(self.device) with autocast(enabled=self.args.half): output = self.model(data) if self.test_output_transform is not None: output = self.test_output_transform(output) loss = self.criterion(output, target) predictions.extend(output) targets.extend(target) metrics_results = {} for metric in self.metrics['test']['batch']: metrics_results["Test/Batch-" + metric.name] = metric.calculate( output, target, level='batch') print_metrics(self.writer, metrics_results, self.get_test_batch_plot_idx()) if self.args.progress_bar: progress_bar(batch_num, len(self.test_loader)) return torch.stack(predictions), torch.stack(targets) def infer(self): print("infer:") self.model.eval() predictions = [] filenames = [] with torch.no_grad(): for batch_num, (filename, data) in enumerate(self.infer_loader): if isinstance(data, list): data = [i.to(self.device) for i in data] else: data = data.to(self.device) with autocast(enabled=self.args.half): output = self.model(data) if self.test_output_transform is not None: output = self.test_output_transform(output) predictions.extend(output) filenames.extend(filename) return filenames, torch.stack(predictions) def save(self, epoch, metric, tag=None): if tag != None: tag = "_" + tag else: tag = "" model_out_path = os.path.join( self.save_dir, "model_{}_{}{}.pth".format(epoch, metric, tag)) torch.save(self.model.state_dict(), model_out_path) print("Checkpoint saved to {}".format(model_out_path)) def run(self): if self.args.seed is not None: reset_seed(self.args.seed) self.load_data() self.init_model() self.init_optimizer() self.init_scheduler() self.init_criterion() self.init_metrics() try: if self.args.infer_only == True: filenames, predictions 
= self.infer( ) # If its the "separated" dataset, we need to average the scores of the 2/3 different projections predictions = predictions.argmax(-1) + 1 save_path = os.path.join(self.save_dir, "predictions.csv") pd.DataFrame({ 'Patient': filenames, 'Class': predictions.cpu().numpy() }).to_csv(save_path, header=False, index=False) exit() best_metrics = {} higher_is_better = metrics[self.args.optimized_metric.split('/') [-1]]['higher_is_better'] for epoch in range(1, self.args.epochs + 1): print("\n===> epoch: %d/%d" % (epoch, self.args.epochs)) self.epoch = epoch metrics_results = {} predictions, targets = self.train() for metric in self.metrics['train']['epoch']: metric_name = "Train/" + metric.name metrics_results[metric_name] = metric.calculate( predictions, targets, level='epoch') if self.epoch % self.args.test_every == 0: predictions, targets = self.test() for metric in self.metrics['test']['epoch']: metric_name = "Test/" + metric.name metrics_results[metric_name] = metric.calculate( predictions, targets, level='epoch') for metric in self.metrics['solver']['epoch']: metric_name = "Solver/" + metric.name metrics_results[metric_name] = metric.calculate( solver=self, level='epoch') print_metrics(self.writer, metrics_results, self.epoch) if self.epoch % self.args.test_every == 0: save_best_metric = False if self.args.optimized_metric not in best_metrics: best_metrics[ self.args.optimized_metric] = metrics_results[ self.args.optimized_metric] save_best_metric = True if higher_is_better: if best_metrics[ self.args.optimized_metric] < metrics_results[ self.args.optimized_metric]: best_metrics[ self.args.optimized_metric] = metrics_results[ self.args.optimized_metric] save_best_metric = True else: if best_metrics[ self.args.optimized_metric] > metrics_results[ self.args.optimized_metric]: best_metrics[ self.args.optimized_metric] = metrics_results[ self.args.optimized_metric] save_best_metric = True if save_best_metric: self.save(epoch, 
best_metrics[self.args.optimized_metric]) print("===> BEST " + self.args.optimized_metric + " PERFORMANCE: %.5f" % best_metrics[self.args.optimized_metric]) if self.args.save_model and epoch % self.args.save_interval == 0: self.save(epoch, 0) if self.args.scheduler.name == "MultiStepLR": self.scheduler.step() elif self.args.scheduler.name == "ReduceLROnPlateau": self.scheduler.step( metrics_results[self.args.scheduler_metric]) elif self.args.scheduler.name == "OneCycleLR": pass else: self.scheduler.step() if self.es.step(metrics_results[self.args.es_metric]): print("Early stopping") raise KeyboardInterrupt except KeyboardInterrupt: pass print("===> BEST " + self.args.optimized_metric + " PERFORMANCE: %.5f" % best_metrics[self.args.optimized_metric]) files = os.listdir(self.save_dir) paths = [ os.path.join(self.save_dir, basename) for basename in files if "_0" not in basename ] if len(paths) > 0: src = max(paths, key=os.path.getctime) copyfile( src, os.path.join("runs", self.args.save_dir, os.path.basename(src))) with open("runs/" + self.args.save_dir + "/README.md", 'a+') as f: f.write("\n## " + self.args.optimized_metric + "\n %.5f" % (best_metrics[self.args.optimized_metric])) tensorboard_export_dump(self.writer) print("Saved best accuracy checkpoint") return best_metrics[self.args.optimized_metric] def get_train_batch_plot_idx(self): self.train_batch_plot_idx += 1 return self.train_batch_plot_idx - 1 def get_test_batch_plot_idx(self): self.test_batch_plot_idx += 1 return self.test_batch_plot_idx - 1
class Trainer():
    """Run training / evaluation epochs for one network and track epoch statistics.

    All collaborators (network, dataloader, optimizer, criterion, save
    directory, ...) are read from the ``options`` object handed to
    ``__init__``. Per-epoch loss, accuracy, predictions and labels are
    accumulated on the instance and cleared by ``reset_epoch_stats``.
    """

    # Class-level defaults. ``reset_epoch_stats`` (called from ``__init__``)
    # rebinds every one of these on the instance, so the two list objects
    # below are never appended to through an instance created normally.
    images_evaluated: int = 0
    accumulated_batches: int = 0
    accumulated_loss: float = 0.0
    accumulated_accuracy: float = 0.0
    all_predictions = []
    all_labels = []

    def __init__(self, options: TrainerOptions):
        """Copy the configuration from ``options`` and reset the epoch statistics.

        When ``options.distributed`` is set, ``config_distributed`` is called.
        When ``options.mixedprecision`` is set, a ``GradScaler`` is created.
        """
        self.net = options.net
        self.dataloader = options.dataloader
        self.optimizer = options.optimizer
        self.criterion = options.criterion
        self.save_dir = options.save_dir
        self.freeze = options.freeze
        self.accumulate_over_n_batches = options.accumulate_over_n_batches
        self.distributed = options.distributed
        self.gpu_rank = options.gpu_rank
        self.n_gpus = options.n_gpus
        self.test_time_bn = options.test_time_bn
        self.dtype = options.dtype
        if self.distributed:
            self.config_distributed(self.n_gpus, self.gpu_rank)
        self.mixedprecision = options.mixedprecision
        if self.mixedprecision:
            self.grad_scaler = GradScaler(init_scale=8192, growth_interval=4)
        self.multilabel = options.multilabel
        self.regression = options.regression
        self.reset_epoch_stats()

    def config_distributed(self, n_gpus, gpu_rank=None):
        """Synchronise the network and record the GPU count and this process' rank.

        ``n_gpus=None`` falls back to ``torch.cuda.device_count()``.
        ``gpu_rank`` must not be ``None`` (checked with ``assert``).
        """
        self.sync_networks_distributed_if_needed()
        self.n_gpus = torch.cuda.device_count() if n_gpus is None else n_gpus
        assert gpu_rank is not None
        self.gpu_rank = gpu_rank

    def sync_networks_distributed_if_needed(self, check=True):
        """Broadcast the network state from rank 0 when running distributed."""
        if self.distributed:
            self.sync_network_distributed(self.net, check)

    def sync_network_distributed(self, net, check=True):
        """Broadcast parameters and BatchNorm2d running statistics of ``net`` from rank 0.

        ``check`` is accepted but not used by this method.
        """
        for _, param in net.named_parameters():
            dist.broadcast(param.data, 0)
        for mod in net.modules():
            if isinstance(mod, torch.nn.BatchNorm2d):
                dist.broadcast(mod.running_mean, 0)
                dist.broadcast(mod.running_var, 0)

    def prepare_network_for_training(self):
        """Enable gradients, clear them, and put the network in train mode.

        Modules listed in ``self.freeze`` are switched back to eval mode.
        """
        torch.set_grad_enabled(True)
        self.optimizer.zero_grad()
        self.net.train()
        for mod in self.freeze:
            mod.eval()

    def prepare_network_for_evaluation(self):
        """Disable gradients and put the network in eval mode."""
        torch.set_grad_enabled(False)
        self.net.eval()
        self.prepare_batchnorm_for_evaluation(self.net)

    def prepare_batchnorm_for_evaluation(self, net):
        """Set every BatchNorm2d in ``net`` to train mode if ``test_time_bn``, else eval mode."""
        for mod in net.modules():
            if isinstance(mod, torch.nn.BatchNorm2d):
                if self.test_time_bn:
                    mod.train()
                else:
                    mod.eval()

    def reset_epoch_stats(self):
        """Zero all per-epoch counters and drop the collected predictions / labels."""
        self.accumulated_loss = 0
        self.accumulated_accuracy = 0
        self.batches_evaluated = 0
        self.images_evaluated = 0
        self.accumulated_batches = 0
        self.all_predictions = []
        self.all_labels = []

    def save_batch_stats(self, loss, accuracy, predictions, labels):
        """Add one batch to the epoch statistics.

        Loss and accuracy are weighted by ``len(labels)`` so that the epoch
        averages are per-sample averages.
        """
        n_samples = len(labels)
        self.accumulated_loss += float(loss) * n_samples
        self.accumulated_accuracy += accuracy * n_samples
        self.batches_evaluated += 1
        self.images_evaluated += n_samples
        self.all_predictions.append(predictions)
        # https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189
        # | fix RuntimeError: received 0 items of ancdata
        self.all_labels.append(labels.copy())

    def stack_epoch_predictions(self):
        """Replace the per-batch lists by the stacked (and gathered) epoch tensors."""
        self.all_predictions, self.all_labels = self.epoch_predictions_and_labels(
            gather=True)

    def correct_loss_for_multigpu(self):
        """Recompute accumulated loss and accuracy sample by sample.

        Iterates over ``all_predictions`` / ``all_labels`` as they are at call
        time and sets ``images_evaluated`` to the number of predictions.
        """
        self.accumulated_loss = 0.0
        self.accumulated_accuracy = 0.0
        for pred, label in zip(self.all_predictions, self.all_labels):
            self.accumulated_loss += float(
                self.criterion(pred[None], label[None]))
            self.accumulated_accuracy += self.accuracy_with_predictions(
                pred[None], label)
        self.images_evaluated = len(self.all_predictions)

    def epoch_predictions_and_labels(self, gather=False):
        """Return ``(predictions, labels)`` for the epoch as torch tensors.

        The per-batch arrays are stacked with ``np.vstack``. With
        ``gather=True`` in distributed mode the stacked arrays are gathered
        from all processes first. Empty tensors are returned when nothing has
        been collected.
        """
        if len(self.all_predictions) == 0:
            return torch.FloatTensor(), torch.LongTensor()

        preds = np.vstack(self.all_predictions)
        labels = np.vstack(self.all_labels)
        if self.distributed and gather:
            preds = list(self.gather(preds))
            labels = list(self.gather(labels))
        preds = torch.from_numpy(np.array(preds))
        labels = torch.from_numpy(
            np.array(labels).astype(self.all_labels[0].dtype))
        # reshape to correct shapes: 1-D per-batch arrays become flat tensors
        if len(self.all_labels[0].shape) == 1:
            labels = labels.flatten()
        if len(self.all_predictions[0].shape) == 1:
            preds = preds.flatten()
        return preds.float(), labels

    def gather(self, results):
        """All-gather ``results`` over ``n_gpus`` processes; return them concatenated on axis 0.

        The data is moved to CUDA as float32 for the collective call.
        """
        results = torch.tensor(results, dtype=torch.float32).cuda()
        tensor_list = [
            results.new_empty(results.shape) for _ in range(self.n_gpus)
        ]
        dist.all_gather(tensor_list, results)
        cpu_list = [tensor.cpu().numpy() for tensor in tensor_list]
        return np.concatenate(cpu_list, axis=0)

    def average_epoch_loss(self):
        """Return the per-sample mean loss, or -1 when no sample was evaluated."""
        if self.images_evaluated == 0:
            return -1
        return self.accumulated_loss / self.images_evaluated

    def average_epoch_accuracy(self):
        """Return the per-sample mean accuracy, or -1 when no sample was evaluated."""
        if self.images_evaluated == 0:
            return -1
        return self.accumulated_accuracy / self.images_evaluated

    def train_epoch(self, batch_callback) -> typing.Tuple[np.array, np.array]:
        """Train over the whole dataloader and return ``(predictions, labels)``.

        ``batch_callback`` is forwarded to ``train_full_dataloader``.
        """
        self.sync_networks_distributed_if_needed()
        self.prepare_network_for_training()
        self.reset_epoch_stats()
        self.train_full_dataloader(batch_callback)
        self.stack_epoch_predictions()
        if self.distributed:
            self.correct_loss_for_multigpu()
        return self.all_predictions, self.all_labels

    def validation_epoch(self, batch_callback):
        """Evaluate over the whole dataloader and return ``(predictions, labels)``."""
        self.prepare_network_for_evaluation()
        self.reset_epoch_stats()
        self.evaluate_full_dataloader(batch_callback)
        self.stack_epoch_predictions()
        if self.distributed:
            self.correct_loss_for_multigpu()
        return self.all_predictions, self.all_labels

    def train_full_dataloader(self, batch_callback):
        """Train on every batch, record its stats, then call ``batch_callback``.

        The callback receives ``(trainer, batches_evaluated, loss, accuracy)``.
        """
        for x, y in self.dataloader:
            loss, accuracy, predictions = self.train_on_batch(x, y)
            self.save_batch_stats(loss, accuracy, predictions, y.cpu().numpy())
            batch_callback(self, self.batches_evaluated, loss, accuracy)

    def evaluate_full_dataloader(self, batch_callback):
        """Forward every batch, record its stats, then call ``batch_callback``.

        The callback receives ``(trainer, batches_evaluated, loss, accuracy)``.
        """
        for x, y in self.dataloader:
            loss, accuracy, predictions = self.forward_batch(x, y)
            self.save_batch_stats(loss, accuracy, predictions, y.cpu().numpy())
            batch_callback(self, self.batches_evaluated, loss, accuracy)

    def forward_batch(self, x, y):
        """Forward one batch; return ``(loss, accuracy, predictions as numpy)``.

        The forward pass runs under ``autocast`` when mixed precision is on.
        """
        if self.mixedprecision:
            with autocast():
                output, loss = self.forward_batch_with_loss(x, y)
        else:
            output, loss = self.forward_batch_with_loss(x, y)
        output = output.detach().cpu()
        accuracy = self.accuracy_with_predictions(output, y.cpu())
        # NOTE: removed a `del output` here, could cause memory issues
        return loss, accuracy, output.numpy()

    def forward_batch_with_loss(self, x, y):
        """Move ``x`` and ``y`` to CUDA, run the network, and return ``(output, loss)``."""
        output = self.net.forward(x.cuda())
        label = y.cuda()
        loss = self.criterion(output, label)
        return output, loss

    def train_on_batch(self, x, y):
        """Forward and backward one batch, stepping the optimizer when due.

        Returns ``(unscaled loss as float, accuracy, predictions)``. The loss
        used for backward is divided by ``accumulate_over_n_batches`` and by
        ``n_gpus``.
        """
        loss, accuracy, predictions = self.forward_batch(x, y)
        full_loss = float(loss)
        loss = loss / self.accumulate_over_n_batches / self.n_gpus
        if self.mixedprecision:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        self.accumulated_batches += 1
        self.step_optimizer_if_needed()
        return full_loss, accuracy, predictions

    def step_optimizer_if_needed(self):
        """Step and zero the optimizer once ``accumulate_over_n_batches`` batches are accumulated."""
        if self.accumulated_batches != self.accumulate_over_n_batches:
            return

        self.distribute_gradients_if_needed()
        if self.mixedprecision:
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
            # prohibit scales larger than 65536, training crashes,
            # maybe due to gradient accumulation?
            if self.grad_scaler.get_scale() > 65536.0:
                self.grad_scaler.update(new_scale=65536.0)
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()
        self.accumulated_batches = 0

    def distribute_gradients_if_needed(self):
        """Sum every existing parameter gradient across processes when distributed."""
        if self.distributed:
            for _, param in self.net.named_parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)

    def save_checkpoint(self, name, epoch, additional=None):
        """Save network and optimizer state under ``save_dir``.

        Two files are written: ``checkpoint_<name>_<epoch>_network`` and
        ``checkpoint_<name>_last``. Entries of ``additional`` (a mapping, or
        ``None`` for no extras) are merged into the saved state. Any exception
        raised while saving is reported with a printed warning rather than
        propagated.
        """
        state = {
            'checkpoint': epoch,
            'state_dict': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        # ``None`` sentinel instead of a shared mutable ``{}`` default.
        if additional is not None:
            state.update(additional)
        epoch_file = 'checkpoint_' + name + '_' + str(epoch) + '_network'
        print('Saving', epoch_file)
        try:
            torch.save(state, self.save_dir / pathlib.Path(epoch_file))
            torch.save(
                state,
                self.save_dir / pathlib.Path('checkpoint_' + name + '_last'))
        except Exception as e:
            print('WARNING: Network not stored', e)

    def _checkpoint_file(self, name, epoch=-1):
        """Return the per-epoch checkpoint path if ``epoch > -1``, else the ``_last`` path."""
        if epoch > -1:
            suffix = '_' + str(epoch) + '_network'
        else:
            suffix = '_last'
        return self.save_dir / pathlib.Path('checkpoint_' + name + suffix)

    def checkpoint_available_for_name(self, name, epoch=-1):
        """Return whether the checkpoint file for ``name`` / ``epoch`` exists.

        For ``epoch > -1`` the path and the check result are also printed.
        """
        path = self._checkpoint_file(name, epoch)
        if epoch > -1:
            print(path)
            print(os.path.isfile(path))
        return os.path.isfile(path)

    def load_network_checkpoint(self, name):
        """Load ``checkpoint_<name>`` from ``save_dir`` and restore it into this trainer."""
        state = torch.load(self.save_dir / pathlib.Path('checkpoint_' + name))
        self.load_state_dict(state)

    def load_checkpoint(self, name, epoch=-1):
        """Load and return the checkpoint state without applying it.

        Tensors are kept on the storage they are deserialised to
        (``map_location`` returns the storage unchanged).
        """
        return torch.load(self._checkpoint_file(name, epoch),
                          map_location=lambda storage, loc: storage)

    def load_state_dict(self, state):
        """Restore optimizer (if present in ``state``) and network weights."""
        try:
            self.optimizer.load_state_dict(state['optimizer'])
        except KeyError:
            print('WARNING: Optimizer not restored')
        self.net.load_state_dict(state['state_dict'])

    def load_checkpoint_if_available(self, name, epoch=-1):
        """Load and apply a checkpoint when it exists.

        Returns ``(True, state)`` on success and ``(False, None)`` otherwise.
        """
        if self.checkpoint_available_for_name(name, epoch):
            state = self.load_checkpoint(name, epoch)
            self.load_state_dict(state)
            return True, state
        return False, None

    def accuracy_with_predictions(self, predictions, labels):
        """Return the fraction of correct predictions in the batch.

        * regression: always 0.
        * multilabel: a sample counts as correct only if every rounded
          sigmoid output matches its label.
        * single output column: rounded sigmoid compared with the labels.
        * otherwise: argmax over softmax compared with the labels.
        """
        if self.regression:
            return 0
        if self.multilabel:
            equal = np.equal(np.round(torch.sigmoid(predictions.float())),
                             labels.numpy() == 1)
            equal_c = np.sum(equal, axis=1)
            equal = (equal_c == labels.shape[1]).sum()
        elif predictions.shape[1] == 1:
            equal = np.equal(np.round(torch.sigmoid(predictions.float())),
                             labels)
        else:
            equal = np.equal(
                np.argmax(torch.softmax(predictions.float(), dim=1), axis=1),
                labels)
        return float(equal.sum()) / float(predictions.shape[0])
class TestGradientScalingAMP(unittest.TestCase):
    """Checks one optimizer step of a 1x1 linear model under autocast + GradScaler."""

    def setUp(self):
        """Build the CUDA input/target, the linear model, the optimizer args and the scaler."""
        weight, bias = 3.0, 5.0
        self.x = torch.tensor([2.0]).cuda().half()
        self.error = 1.0
        # target is the model output shifted by ``error``
        self.target = torch.tensor([self.x * weight + bias + self.error
                                    ]).cuda()
        self.loss_fn = torch.nn.L1Loss()

        linear = torch.nn.Linear(1, 1)
        linear.weight.data = torch.tensor([[weight]])
        linear.bias.data = torch.tensor([bias])
        linear.cuda()
        self.model = linear
        self.params = list(linear.parameters())

        self.namespace_dls = argparse.Namespace(
            optimizer="adam",
            lr=[0.1],
            adam_betas="(0.9, 0.999)",
            adam_eps=1e-8,
            weight_decay=0.0,
            threshold_loss_scale=1,
            min_loss_scale=1e-4,
        )
        self.scaler = GradScaler(init_scale=1, growth_interval=1)

    def run_iter(self, model, params, optimizer):
        """Run a single scaled training step and assert loss, grad norm, weights and scale.

        ``params`` is accepted but not used here.
        """
        optimizer.zero_grad()
        with autocast():
            prediction = model(self.x)
            loss = self.loss_fn(prediction, self.target)
        self.scaler.scale(loss).backward()

        expected_loss = torch.tensor(1.0, device="cuda:0", dtype=torch.float16)
        self.assertEqual(loss, expected_loss)

        self.scaler.unscale_(optimizer)
        grad_norm = optimizer.clip_grad_norm(0)
        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)

        self.scaler.step(optimizer)
        self.scaler.update()

        expected_weight = torch.tensor([[3.1]],
                                       device="cuda:0",
                                       requires_grad=True)
        expected_bias = torch.tensor([5.1],
                                     device="cuda:0",
                                     requires_grad=True)
        self.assertEqual(model.weight, expected_weight)
        self.assertEqual(model.bias, expected_bias)
        self.assertEqual(self.scaler.get_scale(), 2.0)

    def test_automatic_mixed_precision(self):
        """Run one iteration on a deep copy of the fixture model."""
        model_copy = copy.deepcopy(self.model)
        model_params = list(model_copy.parameters())
        optimizer = build_optimizer(self.namespace_dls, model_params)
        self.run_iter(model_copy, model_params, optimizer)
def main(args):
    """Train, validate and test a model according to ``args``.

    Seeds numpy and torch, prepares the output / checkpoint directories,
    builds loaders, model, optimizer, scheduler, loss and optional critic,
    then runs ``args.epochs`` epochs of train + validate before saving the
    model and running the test pass. Several derived values (``checkpoint_dir``,
    ``cuda``, ``device``, ``output_logits``, ``global_step``) are written back
    onto ``args``.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # delete args.output_dir if the flag is set and the directory exists
    if args.clear_output_dir and args.output_dir.exists():
        rmtree(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.checkpoint_dir = args.output_dir / 'checkpoints'
    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if args.cuda else "cpu")

    train_loader, val_loader, test_loader = get_loaders(args)

    summary = Summary(args)

    scaler = GradScaler(enabled=args.mixed_precision)

    args.output_logits = (args.loss in ['bce', 'binarycrossentropy']
                          and args.model != 'identity')

    model = get_model(args, summary)
    if args.weights_dir is not None:
        model = utils.load_weights(args, model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer,
                       step_size=5,
                       gamma=args.gamma,
                       verbose=args.verbose == 2)

    loss_function = utils.get_loss_function(name=args.loss)

    if args.critic is None:
        critic = None
    else:
        critic = Critic(args, summary=summary)

    utils.save_args(args)

    args.global_step = 0

    # keyword arguments shared by the validation and test passes
    eval_kwargs = dict(model=model,
                       loss_function=loss_function,
                       summary=summary,
                       critic=critic)

    for epoch in range(args.epochs):
        print(f'Epoch {epoch + 1:03d}/{args.epochs:03d}')

        tic = time()
        train_results = train(args,
                              model=model,
                              data=train_loader,
                              optimizer=optimizer,
                              loss_function=loss_function,
                              scaler=scaler,
                              summary=summary,
                              epoch=epoch,
                              critic=critic)
        val_results = validate(args,
                               data=val_loader,
                               epoch=epoch,
                               **eval_kwargs)
        toc = time()

        scheduler.step()

        summary.scalar('elapse', toc - tic, step=epoch, mode=0)
        summary.scalar('lr', scheduler.get_last_lr()[0], step=epoch, mode=0)
        summary.scalar('gradient_scale',
                       scaler.get_scale(),
                       step=epoch,
                       mode=0)

        print(f'Train\t\tLoss: {train_results["Loss"]:.04f}\n'
              f'Validation\tLoss: {val_results["Loss"]:.04f}\t'
              f'MAE: {val_results["MAE"]:.04f}\t'
              f'PSNR: {val_results["PSNR"]:.02f}\t'
              f'SSIM: {val_results["SSIM"]:.04f}\n')

    # NOTE(review): the source's indentation was lost; the model is assumed to
    # be saved once after the last epoch rather than inside the loop -- confirm.
    utils.save_model(args, model)

    test(args, data=test_loader, epoch=args.epochs, **eval_kwargs)

    summary.close()
def run_epoch(model,
              optimizer,
              train_ldr,
              logger,
              debug_mode: bool,
              tbX_writer,
              iter_count: int,
              avg_loss: float,
              local_rank: int,
              loss_name: str,
              save_path: str,
              gcs_ckpt_handler,
              scaler: GradScaler = None) -> tuple:
    """
    Performs a forwards and backward pass through the model for every batch in `train_ldr`

    Args:
        model: wrapped model; the underlying network is accessed as `model.module`
        optimizer: optimizer stepped once per batch through `scaler`
        train_ldr: iterable of batches; must support `len()` and expose `dataset.preproc`
        logger: logger used on rank 0, or None to disable logging
        debug_mode (bool): enables the extra batch / gradient logging when logging is on
        tbX_writer: tensorboard writer receiving loss and grad-norm scalars on rank 0
        iter_count (int): count of iterations
        avg_loss (float): running exponentially-weighted average loss
        local_rank (int): accepted for interface compatibility; not used in this function
        loss_name (str): one of "native", "awni" or "naren"
        save_path (str): path to directory where model is saved
        gcs_ckpt_handler: facilities saving files to google cloud storage
        scaler (GradScaler): gradient scaler to prevent gradient underflow when autocast
            uses float16 precision for forward pass. `None` is treated as a disabled scaler.
    Returns:
        Tuple[int, float]: train state of # batch iterations and average loss
    Raises:
        ValueError: if `loss_name` is not one of the supported names
    """
    # booleans and constants for logging
    is_rank_0 = (torch.distributed.get_rank() == 0)
    use_log = (logger is not None and is_rank_0)
    log_modulus = 100  # limits certain logging function to report less frequently
    exp_w = 0.985  # exponential weight for exponential moving average loss
    avg_grad_norm = 0
    model_t, data_t = 0.0, 0.0
    end_t = time.time()

    # progress bar for rank_0 process
    tq = tqdm.tqdm(train_ldr) if is_rank_0 else train_ldr

    # counter for model checkpointing
    batch_counter = 0

    # The signature defaults `scaler` to None; fall back to a disabled scaler
    # so the scale/unscale/step/update calls below all work without amp.
    if scaler is None:
        scaler = GradScaler(enabled=False)

    # if scaler is enabled, amp is being used
    use_amp = scaler.is_enabled()
    print(f"Amp is being used: {use_amp}")

    # training loop
    for batch in tq:
        if use_log:
            logger.info(
                f"train: ====== Iteration: {iter_count} in run_epoch =======")

        ############## Mid-epoch checkpoint ###############
        # `max(1, ...)` keeps the modulus non-zero when the loader has fewer
        # batches than requested checkpoints per epoch. The handler is only
        # touched on rank 0, as in the short-circuit of the original condition.
        if is_rank_0 \
            and batch_counter % max(1, len(train_ldr) // gcs_ckpt_handler.chkpt_per_epoch) == 0 \
            and batch_counter != 0:
            preproc = train_ldr.dataset.preproc
            save(model.module, preproc, save_path, tag='ckpt')
            gcs_ckpt_handler.upload_to_gcs("ckpt_model_state_dict.pth")
            gcs_ckpt_handler.upload_to_gcs("ckpt_preproc.pyc")
            # save the run_state
            ckpt_state_path = os.path.join(save_path, "ckpt_run_state.pickle")
            write_pickle(ckpt_state_path, {'run_state': (iter_count, avg_loss)})
            gcs_ckpt_handler.upload_to_gcs("ckpt_run_state.pickle")
            # checkpoint tensorboard
            gcs_ckpt_handler.upload_tensorboard_ckpt()

        batch_counter += 1
        ####################################################

        # convert the temporary generator batch to a permanent list
        batch = list(batch)

        # save the batch information
        if use_log:
            if debug_mode:
                save_batch_log_stats(batch, logger)
                log_batchnorm_mean_std(model.module.state_dict(), logger)

        start_t = time.time()
        optimizer.zero_grad(
            set_to_none=True)  # set grads to None for modest perf improvement

        # will autocast to lower precision if amp is used. otherwise, it's no-operation
        with autocast(enabled=use_amp):
            # unpack the batch
            inputs, labels, input_lens, label_lens = model.module.collate(
                *batch)
            inputs = inputs.cuda()
            out, rnn_args = model(inputs, softmax=False)

            # use the loss function defined in `loss_name`
            if loss_name == "native":
                loss = native_loss(out, labels, input_lens, label_lens,
                                   model.module.blank)
            elif loss_name == "awni":
                loss = awni_loss(out, labels, input_lens, label_lens,
                                 model.module.blank)
            elif loss_name == "naren":
                loss = naren_loss(out, labels, input_lens, label_lens,
                                  model.module.blank)
            else:
                # previously fell through and failed later with a NameError
                raise ValueError(f"unsupported loss_name: {loss_name!r}")

        # backward pass
        loss = loss.cuda()  # amp needs the loss to be on cuda
        scaler.scale(loss).backward()

        if use_log:
            if debug_mode:
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))
                log_param_grad_norms(model.module.named_parameters(), logger)

        # gradient clipping and optimizer step, scaling disabled if amp is not used
        scaler.unscale_(optimizer)
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 200).item()
        scaler.step(optimizer)
        scaler.update()

        # logging in rank_0 process
        if is_rank_0:
            # calculate timers
            prev_end_t = end_t
            end_t = time.time()
            model_t += end_t - start_t
            data_t += start_t - prev_end_t

            # creating scalars from grad_norm and loss for weighted
            # TODO, needed with pytorch 0.4, may not be necessary anymore
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()
            if isinstance(loss, torch.Tensor):
                loss = loss.item()

            # calculating the weighted average of loss and grad_norm
            if iter_count == 0:
                avg_loss = loss
                avg_grad_norm = grad_norm
            else:
                avg_loss = exp_w * avg_loss + (1 - exp_w) * loss
                avg_grad_norm = exp_w * avg_grad_norm + (1 - exp_w) * grad_norm

            # writing to the tensorboard log files
            tbX_writer.add_scalars('train/loss', {"loss": loss}, iter_count)
            tbX_writer.add_scalars('train/loss', {"avg_loss": avg_loss},
                                   iter_count)

            # adding this to suppress a tbX WARNING about inf values
            # TODO, this may or may not be a good idea as it masks inf in tensorboard
            if grad_norm == float('inf') or math.isnan(grad_norm):
                tbX_grad_norm = 1
            else:
                tbX_grad_norm = grad_norm
            tbX_writer.add_scalars('train/grad', {"grad_norm": tbX_grad_norm},
                                   iter_count)

            # progress bar update
            tq.set_postfix(it=iter_count,
                           grd_nrm=grad_norm,
                           lss=loss,
                           lss_av=avg_loss,
                           t_mdl=model_t,
                           t_data=data_t,
                           scl=scaler.get_scale())

        if use_log:
            logger.info(f'train: loss is inf: {loss == float("inf")}')
            logger.info(
                f"train: iter={iter_count}, loss={round(loss,3)}, grad_norm={round(grad_norm,3)}"
            )

        if iter_count % log_modulus == 0:
            if use_log:
                log_cpu_mem_disk_usage(logger)

        # checks for nan gradients
        if check_nan_params_grads(model.module.parameters()):
            print("\n~~~ NaN value detected in gradients or parameters ~~~\n")
            if use_log:
                logger.error(
                    f"train: labels: {[labels]}, label_lens: {label_lens} state_dict: {model.module.state_dict()}"
                )
                log_model_grads(model.module.named_parameters(), logger)
                save_batch_log_stats(batch, logger)
                log_param_grad_norms(model.module.named_parameters(), logger)
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))

        iter_count += 1

    return iter_count, avg_loss
class Trainer():
    """Trainer for joint semantic segmentation and self-supervised monocular depth.

    The trainer builds the data loaders, the model, an optional EMA ("mean
    teacher") copy of the model, the optimizer, the LR scheduler and a
    ``GradScaler`` for mixed-precision training.  Each training step runs
    several separate backward passes (monodepth, pseudo-depth, supervised
    segmentation and, optionally, the unlabeled mix branch) that all go through
    the same scaler before a single optimizer step.
    """

    def __init__(self, cfg, writer, img_writer, logger, run_id):
        """Build loaders, model, optimizer, scaler and loss functions from ``cfg``.

        Args:
            cfg: Nested configuration dict with (at least) the sections
                ``"data"``, ``"model"`` and ``"training"``.  It is modified in
                place (shared options are copied into the sections, the
                generated depth directory is extended or cleared).
            writer: Scalar summary writer; ``add_scalar`` and
                ``file_writer.get_logdir()`` are used.
            img_writer: Image summary writer; ``add_image`` is used.
            logger: Logger providing ``info``.
            run_id: String identifying the run; used in image tags.
        """
        # Copy shared config fields
        if "monodepth_options" in cfg:
            cfg["data"].update(cfg["monodepth_options"])
            cfg["model"].update(cfg["monodepth_options"])
            cfg["training"]["monodepth_loss"].update(cfg["monodepth_options"])
        if "generated_depth_dir" in cfg["data"]:
            dataset_name = (f"{cfg['data']['dataset']}_"
                            f"{cfg['data']['width']}x{cfg['data']['height']}")
            depth_teacher = cfg["data"].get("depth_teacher", None)
            # NOTE(review): 'detph_estimator_weights' looks like a typo of
            # 'depth_estimator_weights'; as written the lookup presumably always
            # yields None, so this assert never fires.  Left unchanged because
            # correcting the key could start rejecting configs that run today.
            assert not (depth_teacher and cfg['model'].get('detph_estimator_weights') is not None)
            if depth_teacher is not None:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + depth_teacher + "/"
            else:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + cfg['model']['depth_estimator_weights'] + "/"

        # Setup seeds.  The same fallback is used for both consumers so that a
        # config without "seed" but with dataset_seed == "same" does not fail.
        seed = cfg.get("seed", 1337)
        setup_seeds(seed)
        if cfg["data"]["dataset_seed"] == "same":
            cfg["data"]["dataset_seed"] = seed

        # Setup device
        torch.backends.cudnn.benchmark = cfg["training"].get("benchmark", True)
        self.cfg = cfg
        self.writer = writer
        self.img_writer = img_writer
        self.logger = logger
        self.run_id = run_id
        self.mIoU = 0
        self.fwAcc = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.setup_segmentation_unlabeled()

        # Depth is required by the depth-based mix masks only.  Uses .get() like
        # setup_segmentation_unlabeled() so both agree on optional keys.
        unlabeled_cfg = self.cfg["training"].get("unlabeled_segmentation")
        self.unlabeled_require_depth = (
            unlabeled_cfg is not None
            and unlabeled_cfg.get("mix_mask") in ("depth", "depthcomp", "depthhist")
        )

        # Prepare depth estimates
        do_precalculate_depth = (self.cfg["training"]["segmentation_lambda"] != 0
                                 and self.unlabeled_require_depth
                                 and self.cfg['model']['segmentation_name'] != 'mtl_pad')
        use_depth_teacher = cfg["data"].get("depth_teacher", None) is not None
        if do_precalculate_depth or use_depth_teacher:
            assert not (do_precalculate_depth and use_depth_teacher)
            if not self.cfg["training"].get("disable_depth_estimator", False):
                print("Prepare depth estimates")
                depth_estimator = DepthEstimator(cfg)
                depth_estimator.prepare_depth_estimates()
                del depth_estimator
                torch.cuda.empty_cache()
        else:
            self.cfg["data"]["generated_depth_dir"] = None

        # Setup Dataloader
        load_labels, load_sequence = True, True
        if self.cfg["training"]["monodepth_lambda"] == 0:
            load_sequence = False
        if self.cfg["training"]["segmentation_lambda"] == 0:
            load_labels = False
        train_data_cfg = deepcopy(self.cfg["data"])
        if not do_precalculate_depth and not use_depth_teacher:
            train_data_cfg["generated_depth_dir"] = None
        self.train_loader = build_loader(train_data_cfg, "train", load_labels=load_labels,
                                         load_sequence=load_sequence)
        if self.cfg["training"].get("minimize_entropy_unlabeled", False) or self.enable_unlabled_segmentation:
            unlabeled_segmentation_cfg = deepcopy(self.cfg["data"])
            if not self.only_unlabeled and self.mix_use_gt:
                unlabeled_segmentation_cfg["load_onehot"] = True
            if self.only_unlabeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": False})
            elif self.only_labeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": False, "load_labeled": True})
            else:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": True})
            if self.mix_video:
                assert not self.mix_use_gt and not self.only_labeled and not self.only_unlabeled, \
                    "Video sample indices are not compatible with non-video indices."
            unlabeled_segmentation_cfg.update({"only_sequences_with_segmentation": not self.mix_video,
                                               "restrict_to_subset": None})
            self.unlabeled_loader = build_loader(unlabeled_segmentation_cfg, "train",
                                                 load_labels=load_labels if not self.mix_video else False,
                                                 load_sequence=load_sequence)
        else:
            self.unlabeled_loader = None
        self.val_loader = build_loader(self.cfg["data"], "val", load_labels=load_labels,
                                       load_sequence=load_sequence)
        self.n_classes = self.train_loader.n_classes

        # monodepth dataloader settings uses drop_last=True and shuffle=True even for val
        self.train_data_loader = data.DataLoader(
            self.train_loader,
            batch_size=self.cfg["training"]["batch_size"],
            num_workers=self.cfg["training"]["n_workers"],
            shuffle=self.cfg["data"]["shuffle_trainset"],
            pin_memory=True,
            # Setting to false will cause crash at the end of epoch
            drop_last=True,
        )
        if self.unlabeled_loader is not None:
            self.unlabeled_data_loader = infinite_iterator(data.DataLoader(
                self.unlabeled_loader,
                batch_size=self.cfg["training"]["batch_size"],
                num_workers=self.cfg["training"]["n_workers"],
                shuffle=self.cfg["data"]["shuffle_trainset"],
                pin_memory=True,
                # Setting to false will cause crash at the end of epoch
                drop_last=True,
            ))

        self.val_batch_size = self.cfg["training"]["val_batch_size"]
        self.val_data_loader = data.DataLoader(
            self.val_loader,
            batch_size=self.val_batch_size,
            num_workers=self.cfg["training"]["n_workers"],
            pin_memory=True,
            # If using a dataset with odd number of samples (CamVid), the memory consumption suddenly increases for the
            # last batch. This can be circumvented by dropping the last batch. Only do that if it is necessary for your
            # system as it will result in an incomplete validation set.
            # drop_last=True,
        )

        # Setup Model
        self.model = get_model(cfg["model"], self.n_classes).to(self.device)
        # print(self.model)
        assert not (self.enable_unlabled_segmentation and self.cfg["training"]["save_monodepth_ema"])
        if self.enable_unlabled_segmentation and not self.only_labeled:
            print("Create segmentation ema model.")
            self.ema_model = self.create_ema_model(self.model).to(self.device)
        elif self.cfg["training"]["save_monodepth_ema"]:
            print("Create depth ema model.")
            # TODO: Try to remove unnecessary components and fit into gpu for better performance
            self.ema_model = self.create_ema_model(self.model)  # .to(self.device)
        else:
            self.ema_model = None

        # Setup optimizer, lr_scheduler and loss function
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items()
                            if k not in ["name", "backbone_lr", "pose_lr", "depth_lr", "segmentation_lr"]}
        train_params = get_train_params(self.model, self.cfg)
        self.optimizer = optimizer_cls(train_params, **optimizer_params)

        self.scheduler = get_scheduler(self.optimizer, self.cfg["training"]["lr_schedule"])

        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler(enabled=self.cfg["training"]["amp"])

        self.loss_fn = get_segmentation_loss_function(self.cfg)
        self.monodepth_loss_calculator_train = get_monodepth_loss(self.cfg, is_train=True)
        self.monodepth_loss_calculator_val = get_monodepth_loss(self.cfg, is_train=False,
                                                                batch_size=self.val_batch_size)

        if cfg["training"]["early_stopping"] is None:
            logger.info("Using No Early Stopping")
            self.earlyStopping = None
        else:
            self.earlyStopping = EarlyStopping(
                patience=round(cfg["training"]["early_stopping"]["patience"] / cfg["training"]["val_interval"]),
                min_delta=cfg["training"]["early_stopping"]["min_delta"],
                cumulative_delta=cfg["training"]["early_stopping"]["cum_delta"],
                logger=logger
            )

    def extract_monodepth_ema_params(self, model, ema_model):
        """Return the paired (model, ema) parameters of the depth sub-models.

        The encoder is included unless ``cfg["model"]["freeze_backbone"]`` is set.
        """
        model_names = ["depth"]
        if not self.cfg["model"]["freeze_backbone"]:
            model_names.append("encoder")
        return extract_ema_params(model, ema_model, model_names)

    def extract_pad_ema_params(self, model, ema_model):
        """Return the paired (model, ema) parameters for the 'mtl_pad' architecture."""
        model_names = ["depth", "encoder", "mtl_decoder"]
        return extract_ema_params(model, ema_model, model_names)

    def create_ema_model(self, model):
        """Create an EMA copy of ``model`` (pose disabled) initialised with its weights.

        Args:
            model: The student model whose parameters are copied.

        Returns:
            A freshly built model whose tracked parameters are detached from
            autograd and hold a copy of the corresponding ``model`` parameters.
        """
        ema_cfg = deepcopy(self.cfg["model"])
        ema_cfg["disable_pose"] = True
        ema_model = get_model(ema_cfg, self.n_classes)
        if self.cfg["training"]["save_monodepth_ema"]:
            mp, mcp = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            mp, mcp = self.extract_pad_ema_params(model, ema_model)
        else:
            mp, mcp = list(model.parameters()), list(ema_model.parameters())
        for param in mcp:
            param.detach_()
        assert len(mp) == len(mcp), f"len(mp)={len(mp)}; len(mcp)={len(mcp)}"
        for src, dst in zip(mp, mcp):
            dst.data[:] = src.to(dst.device, non_blocking=True).data[:].clone()
        return ema_model

    def update_ema_variables(self, ema_model, model, alpha_teacher, iteration):
        """Update the EMA parameters in place and return ``ema_model``.

        Args:
            ema_model: Teacher model to update.
            model: Student model providing the new values.
            alpha_teacher: Upper bound for the EMA decay.
            iteration: Current iteration; early on the decay is lowered to
                ``1 - 1 / (iteration + 1)`` so the EMA behaves like a true average.
        """
        if self.cfg["training"]["save_monodepth_ema"]:
            model_params, ema_params = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            model_params, ema_params = self.extract_pad_ema_params(model, ema_model)
        else:
            model_params, ema_params = model.parameters(), ema_model.parameters()
        # Use the "true" average until the exponential average is more correct
        alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher)
        for ema_param, param in zip(ema_params, model_params):
            ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + \
                                (1 - alpha_teacher) * param.to(ema_param.device, non_blocking=True)[:].data[:]
        return ema_model

    def save_resume(self, step):
        """Save model/optimizer/scheduler state to ``best_model.pkl`` in the log dir.

        Returns:
            The path the checkpoint was written to.

        Raises:
            NotImplementedError: If an EMA model is in use.
        """
        if self.ema_model is not None:
            raise NotImplementedError("ema model not supported")
        state = {
            "epoch": step + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_iou": self.best_iou,
        }
        save_path = os.path.join(
            self.writer.file_writer.get_logdir(),
            "best_model.pkl"
        )
        torch.save(state, save_path)
        return save_path

    def save_monodepth_models(self):
        """Save the state dict of each monodepth sub-model as ``<name>.pth`` in the log dir.

        The EMA model is saved when ``cfg["training"]["save_monodepth_ema"]`` is
        set, otherwise the trained model.
        """
        if self.cfg["training"]["save_monodepth_ema"]:
            print("Save ema monodepth models.")
            assert self.ema_model is not None
            model_to_save = self.ema_model
        else:
            model_to_save = self.model
        models = ["depth", "pose_encoder", "pose"]
        if not self.cfg["model"]["freeze_backbone"]:
            models.append("encoder")
        for model_name in models:
            save_path = os.path.join(self.writer.file_writer.get_logdir(), "{}.pth".format(model_name))
            to_save = model_to_save.models[model_name].state_dict()
            torch.save(to_save, save_path)

    def load_resume(self, strict=True, load_model_only=False):
        """Load the checkpoint at ``cfg["training"]["resume"]`` if that file exists.

        Args:
            strict: Forwarded to ``load_state_dict`` of the model.
            load_model_only: If True, optimizer/scheduler state, start iteration
                and best IoU are not restored.
        """
        resume_path = self.cfg["training"]["resume"]
        if os.path.isfile(resume_path):
            self.logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(resume_path)
            )
            checkpoint = torch.load(resume_path)
            self.model.load_state_dict(checkpoint["model_state"], strict=strict)
            if not load_model_only:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
                self.start_iter = checkpoint["epoch"]
                self.best_iou = checkpoint["best_iou"]
            self.logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    resume_path, checkpoint["epoch"]
                )
            )
        else:
            self.logger.info("No checkpoint found at '{}'".format(resume_path))

    def tensorboard_training_images(self):
        """Log the first ``n_tensorboard_trainimgs`` training images and labels."""
        n_images = self.cfg["training"]["n_tensorboard_trainimgs"]
        num_saved = 0
        if n_images == 0:
            return
        for inputs in self.train_data_loader:
            images = inputs[("color_aug", 0, 0)]
            labels = inputs["lbl"]
            for img, label in zip(images.numpy(), labels.numpy()):
                if num_saved < n_images:
                    num_saved += 1
                    self.img_writer.add_image(
                        "trainset_{}/{}_0image".format(self.run_id.replace('/', '_'), num_saved), img,
                        global_step=0)
                    colored_image = self.val_loader.decode_segmap_tocolor(label)
                    self.img_writer.add_image(
                        "trainset_{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), num_saved),
                        colored_image,
                        global_step=0, dataformats="HWC")
            if num_saved >= n_images:
                break

    def _train_batchnorm(self, model, train, only_encoder=False):
        """Switch all ``nn.BatchNorm2d`` modules (optionally encoder only) to train/eval mode."""
        if only_encoder:
            modules = model.models["encoder"].modules()
        else:
            modules = model.modules()
        for m in modules:
            if isinstance(m, nn.BatchNorm2d):
                m.train(train)

    def train_step(self, inputs, step):
        """Run one optimisation step and return the detached loss terms.

        Every enabled loss is scaled by the shared ``GradScaler`` and
        back-propagated separately; gradients accumulate until the single
        ``scaler.step`` at the end.

        Args:
            inputs: Batch dict from the training loader; tensors are moved to
                ``self.device`` in place.
            step: Current iteration, forwarded to the unlabeled branch and the
                EMA update.

        Returns:
            Dict of detached loss tensors (``segmentation_loss``, ``mono_loss``,
            ``pseudo_depth_loss``, ``feat_dist_loss``, the two ``*_total_loss``
            sums and ``total_loss``).
        """
        self.model.train()
        if self.ema_model is not None:
            self.ema_model.train()

        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device, non_blocking=True)
        if self.enable_unlabled_segmentation:
            unlabeled_inputs = next(self.unlabeled_data_loader)
            for k in unlabeled_inputs.keys():
                if "color_aug" in k or "K" in k or "inv_K" in k or "color" in k or k in ["onehot_lbl", "pseudo_depth"]:
                    # print(f"Move {k} to gpu.")
                    unlabeled_inputs[k] = unlabeled_inputs[k].to(self.device, non_blocking=True)

        self.optimizer.zero_grad()
        # Placeholders for loss terms whose branch is disabled.
        segmentation_loss = torch.tensor(0)
        segmentation_total_loss = torch.tensor(0)
        mono_loss = torch.tensor(0)
        feat_dist_loss = torch.tensor(0)
        mono_total_loss = torch.tensor(0)

        if self.cfg["model"].get("freeze_backbone_bn", False):
            self._train_batchnorm(self.model, False, only_encoder=True)

        with autocast(enabled=self.cfg["training"]["amp"]):
            outputs = self.model(inputs)

        # Train monodepth
        if self.cfg["training"]["monodepth_lambda"] > 0:
            for k, v in outputs.items():
                if "depth" in k or "cam_T_cam" in k:
                    outputs[k] = v.to(torch.float32)
            self.monodepth_loss_calculator_train.generate_images_pred(inputs, outputs)
            mono_losses = self.monodepth_loss_calculator_train.compute_losses(inputs, outputs)
            mono_lambda = self.cfg["training"]["monodepth_lambda"]
            mono_loss = mono_lambda * mono_losses["loss"]
            feat_dist_lambda = self.cfg["training"]["feat_dist_lambda"]
            if feat_dist_lambda > 0:
                feat_dist = torch.dist(outputs["encoder_features"], outputs["imnet_features"], p=2)
                feat_dist_loss = feat_dist_lambda * feat_dist
            mono_total_loss = mono_loss + feat_dist_loss

            # The graph is reused by the following backward passes.
            self.scaler.scale(mono_total_loss).backward(retain_graph=True)

        # Train depth on pseudo-labels
        if self.cfg["training"].get("pseudo_depth_lambda", 0) > 0:
            # Crop away bottom of image with own car
            with torch.no_grad():
                depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
            pseudo_depth_loss = berhu(outputs["disp", 0], inputs["pseudo_depth"], depth_loss_mask)
            pseudo_depth_loss *= self.cfg["training"]["pseudo_depth_lambda"]
            self.scaler.scale(pseudo_depth_loss).backward(retain_graph=True)
        else:
            pseudo_depth_loss = torch.tensor(0)

        # Train segmentation
        if self.cfg["training"]["segmentation_lambda"] > 0:
            with autocast(enabled=self.cfg["training"]["amp"]):
                segmentation_loss = self.loss_fn(input=outputs["semantics"], target=inputs["lbl"])
                if "intermediate_semantics" in outputs:
                    segmentation_loss += self.loss_fn(input=outputs["intermediate_semantics"],
                                                      target=inputs["lbl"])
                    segmentation_loss /= 2
                segmentation_loss *= self.cfg["training"]["segmentation_lambda"]
            segmentation_total_loss = segmentation_loss
            self.scaler.scale(segmentation_total_loss).backward()
        if self.enable_unlabled_segmentation:
            unlabeled_loss, unlabeled_mono_loss = self.train_step_segmentation_unlabeled(unlabeled_inputs, step)
            # Out-of-place adds on purpose: `+=` would mutate the tensor that
            # `segmentation_loss` also refers to (so the reported supervised
            # loss would include the unlabeled loss) and fails with a dtype
            # error when the accumulator is still the integer placeholder.
            segmentation_total_loss = segmentation_total_loss + unlabeled_loss
            mono_total_loss = mono_total_loss + unlabeled_mono_loss

        if self.cfg["training"].get("clip_grad_norm") is not None:
            # Unscales the gradients of optimizer's assigned params in-place
            self.scaler.unscale_(self.optimizer)
            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            if self.cfg["training"].get("disable_depth_grad_clip", False):
                torch.nn.utils.clip_grad_norm_(get_params(self.model, ["encoder", "segmentation"]),
                                               self.cfg["training"]["clip_grad_norm"])
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["training"]["clip_grad_norm"])
        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(metrics=self.mIoU)
        else:
            self.scheduler.step()

        # update Mean teacher network
        if self.ema_model is not None:
            self.ema_model = self.update_ema_variables(ema_model=self.ema_model, model=self.model,
                                                       alpha_teacher=0.99, iteration=step)

        total_loss = segmentation_total_loss + mono_total_loss + pseudo_depth_loss

        return {
            'segmentation_loss': segmentation_loss.detach(),
            'mono_loss': mono_loss.detach(),
            'pseudo_depth_loss': pseudo_depth_loss.detach(),
            'feat_dist_loss': feat_dist_loss.detach(),
            'segmentation_total_loss': segmentation_total_loss.detach(),
            'mono_total_loss': mono_total_loss.detach(),
            'total_loss': total_loss.detach()
        }

    def setup_segmentation_unlabeled(self):
        """Read ``cfg["training"]["unlabeled_segmentation"]`` into attributes.

        When the section is missing or None only
        ``self.enable_unlabled_segmentation = False`` is set; none of the other
        attributes are created.
        """
        if self.cfg["training"].get("unlabeled_segmentation", None) is None:
            self.enable_unlabled_segmentation = False
            return
        unlabeled_cfg = self.cfg["training"]["unlabeled_segmentation"]
        self.enable_unlabled_segmentation = True
        self.consistency_weight = unlabeled_cfg["consistency_weight"]
        self.mix_mask = unlabeled_cfg.get("mix_mask", None)
        self.unlabeled_color_jitter = unlabeled_cfg.get("color_jitter")
        self.unlabeled_blur = unlabeled_cfg.get("blur")
        self.only_unlabeled = unlabeled_cfg.get("only_unlabeled", True)
        self.only_labeled = unlabeled_cfg.get("only_labeled", False)
        self.mix_video = unlabeled_cfg.get("mix_video", False)
        assert not (self.only_unlabeled and self.only_labeled)
        self.mix_use_gt = unlabeled_cfg.get("mix_use_gt", False)
        self.unlabeled_debug_imgs = unlabeled_cfg.get("debug_images", False)
        self.depthcomp_margin = unlabeled_cfg["depthcomp_margin"]
        self.depthcomp_foreground_threshold = unlabeled_cfg["depthcomp_foreground_threshold"]
        self.unlabeled_backward_first_pseudo_label = unlabeled_cfg["backward_first_pseudo_label"]
        self.depthmix_online_depth = unlabeled_cfg.get("depthmix_online_depth", False)

    def generate_mix_mask(self, mode, argmax_u_w, unlabeled_imgs, depths):
        """Build the per-image mixing mask for the requested ``mode``.

        Args:
            mode: One of ``"class"``, ``"depthcomp"``, ``"depth"``,
                ``"depthhist"`` or ``None`` (all-ones mask).
            argmax_u_w: Per-image pseudo-label maps; only read for ``"class"``.
            unlabeled_imgs: Image batch; only its shape is read (``None`` mode).
            depths: Per-image depth/disparity; read by the depth-based modes.

        Returns:
            The masks of all images concatenated along dimension 0.

        Raises:
            NotImplementedError: If ``mode`` is not one of the supported values.
        """
        batch_size = self.cfg["training"]["batch_size"]
        if mode == "class":
            masks = []
            for image_i in range(batch_size):
                classes = torch.unique(argmax_u_w[image_i])
                classes = classes[classes != 250]
                nclasses = classes.shape[0]
                # Randomly pick half of the classes present in the image.
                classes = (classes[torch.Tensor(
                    np.random.choice(nclasses, int((nclasses - nclasses % 2) / 2), replace=False)).long()]).cuda()
                masks.append(
                    transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda())
            MixMask = torch.cat(masks)
        elif mode == "depthcomp":
            # Dispatch on `mode` like every other branch (previously this
            # tested self.mix_mask and shadowed the branches below).
            assert batch_size == 2
            masks = []
            for image_i, other_image_i in [(0, 1), (1, 0)]:
                own_disp = depths[image_i]
                other_disp = depths[other_image_i]
                # Margin avoids too much of mixing road with same depth
                foreground_mask = torch.ge(own_disp, other_disp - self.depthcomp_margin).long()
                # Avoid hiding the real background of the other image with own a bit closer background
                if isinstance(self.depthcomp_foreground_threshold, (tuple, list)):
                    ft_l, ft_u = self.depthcomp_foreground_threshold
                    assert ft_u > ft_l
                    ft = torch.rand(1, device=own_disp.device) * (ft_u - ft_l) + ft_l
                else:
                    ft = self.depthcomp_foreground_threshold
                foreground_mask *= torch.ge(own_disp, ft).long()
                masks.append(foreground_mask)
            MixMask = torch.cat(masks)
        elif mode == "depth":
            min_depth = 0.1
            max_depth = 0.4
            masks = []
            for image_i in range(batch_size):
                generated_depth = depths[image_i]
                depth_threshold = torch.rand(1, device=depths.device) * (max_depth - min_depth) + min_depth
                masks.append(transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda())
            MixMask = torch.cat(masks)
        elif mode == "depthhist":
            masks = []
            for image_i in range(batch_size):
                generated_depth = depths[image_i]
                hist, bin_edges = np.histogram(torch.log(1 + generated_depth).flatten(), bins=100, density=True)
                # NOTE(review): if no bin passes the thresholds below,
                # max_depth/min_depth keep the value of the previous image (or
                # are unbound for the first one).
                # Exclude the first bin as it sometimes has a meaningless peak
                for v, e in zip(np.flip(hist)[1:], np.flip(bin_edges)[1:]):
                    if v > 1.5:
                        max_depth = torch.tensor([e])
                        break
                hist = np.cumsum(hist) / np.sum(hist)
                for v, e in zip(hist, bin_edges):
                    if v > 0.4:
                        min_depth = torch.tensor([e])
                        break
                depth_threshold = torch.rand(1) * (max_depth - min_depth) + min_depth
                masks.append(transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda())
            MixMask = torch.cat(masks)
        elif mode is None:
            MixMask = torch.ones((unlabeled_imgs.shape[0], *unlabeled_imgs.shape[2:]), device=self.device)
        else:
            raise NotImplementedError(f"Unknown mix_mask {mode}")

        return MixMask

    def calc_pseudo_label_loss(self, teacher_softmax, student_logits):
        """Compute the confidence-weighted pseudo-label loss.

        Args:
            teacher_softmax: Teacher class probabilities (class dimension 1).
            student_logits: Student logits scored against the pseudo labels.

        Returns:
            Tuple ``(loss, pseudo_label)``.  The loss is weighted uniformly by
            the fraction of pixels whose max teacher probability is >= 0.968
            and multiplied by ``self.consistency_weight``.
        """
        max_probs, pseudo_label = torch.max(teacher_softmax, dim=1)
        pseudo_label[max_probs == 0] = self.unlabeled_loader.ignore_index
        unlabeled_weight = torch.sum(max_probs.ge(0.968).long() == 1).item() / np.prod(pseudo_label.shape)
        pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape, device=self.device)
        L_u = self.consistency_weight * cross_entropy2d(input=student_logits, target=pseudo_label,
                                                        pixel_weights=pixelWiseWeight)
        return L_u, pseudo_label

    def train_step_segmentation_unlabeled(self, unlabeled_inputs, step):
        """Run the mean-teacher mix branch and back-propagate its losses.

        Args:
            unlabeled_inputs: Batch dict from the unlabeled loader; its
                ``("color_aug", 0, 0)`` entry is replaced by the mixed images.
            step: Current iteration; controls when debug images are written.

        Returns:
            Tuple ``(pseudo_label_loss, mono_loss)`` where the first element is
            ``L_2 + L_1``.  Both losses have already been back-propagated
            through the scaler.
        """

        def strongTransform(parameters, data=None, target=None):
            assert ((data is not None) or (target is not None))
            data, target = transformsgpu.mix(mask=parameters["Mix"], data=data, target=target)
            data, target = transformsgpu.color_jitter(jitter=parameters["ColorJitter"], data=data, target=target)
            data, target = transformsgpu.gaussian_blur(blur=parameters["GaussianBlur"], data=data, target=None)
            return data, target

        unlabeled_imgs = unlabeled_inputs[("color_aug", 0, 0)]

        # First Step: Run teacher to generate pseudo labels
        self.ema_model.use_pose_net = False
        logits_u_w = self.ema_model(unlabeled_inputs)["semantics"]
        softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
        if self.mix_use_gt:
            with torch.no_grad():
                for i in range(unlabeled_imgs.shape[0]):
                    # .data is necessary to access truth value of tensor
                    if unlabeled_inputs["is_labeled"][i].data:
                        softmax_u_w[i] = unlabeled_inputs["onehot_lbl"][i]
        _, argmax_u_w = torch.max(softmax_u_w, dim=1)

        # Second Step: Run student network on unaugmented data to generate depth for DepthMix, calculate monodepth loss,
        # and unaugmented segmentation pseudo label loss
        mono_loss = 0
        L_1 = 0
        if self.depthmix_online_depth:
            outputs_1 = self.model(unlabeled_inputs)
            if self.cfg["training"]["monodepth_lambda"] > 0:
                self.monodepth_loss_calculator_train.generate_images_pred(unlabeled_inputs, outputs_1)
                mono_losses = self.monodepth_loss_calculator_train.compute_losses(unlabeled_inputs, outputs_1)
                mono_lambda = self.cfg["training"]["monodepth_lambda"]
                mono_loss = mono_lambda * mono_losses["loss"]
                self.scaler.scale(mono_loss).backward(retain_graph=self.unlabeled_backward_first_pseudo_label)
                depths = outputs_1[("disp", 0)].detach()
                # Normalise each predicted map to [0, 1].
                for j in range(depths.shape[0]):
                    dmin = torch.min(depths[j])
                    dmax = torch.max(depths[j])
                    depths[j] = torch.clamp(depths[j], dmin, dmax)
                    depths[j] = (depths[j] - dmin) / (dmax - dmin)
            else:
                depths = unlabeled_inputs["pseudo_depth"]
            if self.unlabeled_backward_first_pseudo_label:
                logits_1 = outputs_1["semantics"]
                L_1, _ = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w, student_logits=logits_1)
                self.scaler.scale(L_1).backward()
        elif "pseudo_depth" in unlabeled_inputs:
            depths = unlabeled_inputs["pseudo_depth"]
        else:
            depths = [None] * unlabeled_imgs.shape[0]

        # Third Step: Run Mix
        MixMask = self.generate_mix_mask(self.mix_mask, argmax_u_w, unlabeled_imgs, depths)

        strong_parameters = {"Mix": MixMask}
        if self.unlabeled_color_jitter:
            strong_parameters["ColorJitter"] = random.uniform(0, 1)
        else:
            strong_parameters["ColorJitter"] = 0
        if self.unlabeled_blur:
            strong_parameters["GaussianBlur"] = random.uniform(0, 1)
        else:
            strong_parameters["GaussianBlur"] = 0

        inputs_u_s, _ = strongTransform(strong_parameters, data=unlabeled_imgs)
        unlabeled_inputs[("color_aug", 0, 0)] = inputs_u_s
        outputs = self.model(unlabeled_inputs)
        logits_u_s = outputs["semantics"]

        softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w)
        L_2, pseudo_label = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w_mixed,
                                                        student_logits=logits_u_s)
        self.scaler.scale(L_2).backward()

        # Debug visualisation, written only on print-interval iterations.
        for j, (f, img, ps_lab, mask, d) in enumerate(
                zip(unlabeled_inputs["filename"], inputs_u_s, pseudo_label, MixMask, depths)):
            if (step + 1) % self.cfg["training"]["print_interval"] != 0:
                continue
            fn = f"{self.cfg['training']['log_path']}/class_mix_debug/{step}_{j}_img.jpg"
            os.makedirs(os.path.dirname(fn), exist_ok=True)
            rows, cols = 2, 2
            fig, axs = plt.subplots(rows, cols, sharex='col', sharey='row',
                                    gridspec_kw={'hspace': 0, 'wspace': 0},
                                    figsize=(4 * cols, 4 * rows))
            axs[0][0].imshow(img.permute(1, 2, 0).cpu().numpy())
            axs[0][1].imshow(mask.float().cpu().numpy(), cmap="gray")
            if d is not None:
                axs[1][1].imshow(d[0].cpu().numpy(), cmap="plasma")
            axs[1][0].imshow(self.val_loader.decode_segmap_tocolor(ps_lab.cpu().numpy()))
            for ax in axs.flat:
                ax.axis("off")
            plt.savefig(fn)
            plt.close()

        return L_2 + L_1, mono_loss

    def train(self):
        """Run the training loop with periodic logging, validation and early stopping.

        Returns:
            The value of the iteration counter when training stopped.
        """
        self.start_iter = 0
        self.best_iou = -100.0
        if self.cfg["training"]["resume"] is not None:
            self.load_resume()
            # NOTE(review): presumably resets the LR restored from the
            # checkpoint; confirm this reset belongs to the resume branch only.
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cfg["training"]["optimizer"]["lr"]

        train_loss_meter = AverageMeterDict()
        time_meter = AverageMeter()

        train_iters = self.cfg["training"]["train_iters"]
        step = self.start_iter
        flag = True

        self.tensorboard_training_images()

        start_ts = time.time()
        while step <= train_iters and flag:
            for inputs in self.train_data_loader:
                # torch.cuda.empty_cache()
                step += 1
                losses = self.train_step(inputs, step)

                time_meter.update(time.time() - start_ts)
                train_loss_meter.update(losses)

                if (step + 1) % self.cfg["training"]["print_interval"] == 0:
                    fmt_str = "Iter [{}/{}] Loss: {:.4f} Time/Image: {:.4f}"
                    print_str = fmt_str.format(
                        step + 1,
                        train_iters,
                        train_loss_meter.avgs["total_loss"],
                        time_meter.avg / self.cfg["training"]["batch_size"],
                    )

                    self.logger.info(print_str)
                    for k, v in train_loss_meter.avgs.items():
                        self.writer.add_scalar("training/" + k, v, step + 1)
                    self.writer.add_scalar("training/learning_rate", get_lr(self.optimizer), step + 1)
                    self.writer.add_scalar("training/time_per_image",
                                           time_meter.avg / self.cfg["training"]["batch_size"], step + 1)
                    self.writer.add_scalar("training/amp_scale", self.scaler.get_scale(), step + 1)
                    self.writer.add_scalar("training/memory", psutil.virtual_memory().used / 1e9, step + 1)
                    time_meter.reset()
                    train_loss_meter.reset()

                if (step + 1) % current_val_interval(self.cfg, step + 1) == 0 or (step + 1) == train_iters:
                    self.validate(step)
                    if self.mIoU >= self.best_iou:
                        self.best_iou = self.mIoU
                        if self.cfg["training"]["save_model"]:
                            self.save_resume(step)

                    if self.earlyStopping is not None:
                        if not self.earlyStopping.step(self.mIoU):
                            flag = False
                            break

                if (step + 1) == train_iters:
                    flag = False
                    break
                start_ts = time.time()

        return step

    def validate(self, step):
        """Evaluate on the validation loader and log losses, metrics and images.

        Updates ``self.mIoU`` and ``self.fwAcc`` when segmentation is enabled.

        Args:
            step: Current training iteration; summaries are logged at ``step + 1``.
        """
        self.model.eval()
        val_loss_meter = AverageMeterDict()
        running_metrics_val = runningScore(self.n_classes)
        imgs_to_save = []
        with torch.no_grad():
            for inputs_val in tqdm(self.val_data_loader,
                                   total=len(self.val_data_loader)):
                if self.cfg["model"]["disable_monodepth"]:
                    required_inputs = [("color_aug", 0, 0), "lbl"]
                else:
                    required_inputs = inputs_val.keys()
                for k, v in inputs_val.items():
                    if torch.is_tensor(v) and k in required_inputs:
                        inputs_val[k] = v.to(self.device, non_blocking=True)
                images_val = inputs_val[("color_aug", 0, 0)]
                with autocast(enabled=self.cfg["training"]["amp"]):
                    outputs = self.model(inputs_val)

                if self.cfg["training"]["segmentation_lambda"] > 0:
                    labels_val = inputs_val["lbl"]
                    semantics = outputs["semantics"]
                    val_segmentation_loss = self.loss_fn(input=semantics, target=labels_val)
                    # Handle inconsistent size between input and target
                    _, _, h, w = semantics.size()
                    _, ht, wt = labels_val.size()
                    if h != ht and w != wt:  # upsample labels
                        semantics = F.interpolate(
                            semantics, size=(ht, wt),
                            mode="bilinear", align_corners=True
                        )
                    pred = semantics.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()
                    running_metrics_val.update(gt, pred)
                else:
                    pred = [None] * images_val.shape[0]
                    gt = [None] * images_val.shape[0]
                    val_segmentation_loss = torch.tensor(0)

                if not self.cfg["model"]["disable_monodepth"]:
                    if not self.cfg["model"]["disable_pose"]:
                        self.monodepth_loss_calculator_val.generate_images_pred(inputs_val, outputs)
                        mono_losses = self.monodepth_loss_calculator_val.compute_losses(inputs_val, outputs)
                        val_mono_loss = mono_losses["loss"]
                    else:
                        outputs.update(self.model.predict_test_disp(inputs_val))
                        self.monodepth_loss_calculator_val.generate_depth_test_pred(outputs)
                        val_mono_loss = torch.tensor(0)
                else:
                    outputs[("disp", 0)] = [None] * images_val.shape[0]
                    val_mono_loss = torch.tensor(0)

                if self.cfg["data"].get("depth_teacher", None) is not None:
                    # Crop away bottom of image with own car
                    with torch.no_grad():
                        depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                        depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
                    val_pseudo_depth_loss = berhu(outputs["disp", 0], inputs_val["pseudo_depth"], depth_loss_mask,
                                                  apply_log=self.cfg["training"].get("pseudo_depth_loss_log", False))
                else:
                    val_pseudo_depth_loss = torch.tensor(0)

                val_loss_meter.update({
                    "segmentation_loss": val_segmentation_loss.detach(),
                    "monodepth_loss": val_mono_loss.detach(),
                    "pseudo_depth_loss": val_pseudo_depth_loss.detach()
                })

                for img, label, output, depth in zip(images_val, gt, pred, outputs[("disp", 0)]):
                    if len(imgs_to_save) < self.cfg["training"]["n_tensorboard_imgs"]:
                        imgs_to_save.append([
                            img, label, output,
                            depth if depth is None else depth.detach()])

        for k, v in val_loss_meter.avgs.items():
            self.writer.add_scalar("validation/" + k, v, step + 1)

        if self.cfg["training"]["segmentation_lambda"] > 0:
            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                self.writer.add_scalar("val_metrics/{}".format(k), v, step + 1)

            for k, v in class_iou.items():
                self.writer.add_scalar("val_metrics/cls_{}".format(k), v, step + 1)

            self.mIoU = score["Mean IoU : \t"]
            self.fwAcc = score["FreqW Acc : \t"]

        for j, imgs in enumerate(imgs_to_save):
            # Only log the first image as they won't change -> save memory
            if (step + 1) // current_val_interval(self.cfg, step + 1) == 1:
                self.img_writer.add_image(
                    "{}/{}_0image".format(self.run_id.replace('/', '_'), j), imgs[0], global_step=step + 1)
                if imgs[1] is not None:
                    colored_image = self.val_loader.decode_segmap_tocolor(imgs[1])
                    self.img_writer.add_image(
                        "{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), j), colored_image,
                        global_step=step + 1, dataformats="HWC")
            if imgs[2] is not None:
                colored_image = self.val_loader.decode_segmap_tocolor(imgs[2])
                self.img_writer.add_image(
                    "{}/{}_2prediction".format(self.run_id.replace('/', '_'), j), colored_image,
                    global_step=step + 1, dataformats="HWC")
            if imgs[3] is not None:
                colored_image = _colorize(imgs[3], "plasma", max_percentile=100)
                self.img_writer.add_image(
                    "{}/{}_3depth".format(self.run_id.replace('/', '_'), j), colored_image,
                    global_step=step + 1, dataformats="HWC")
class BaseModule(nn.Module):
    """Base module bundling training conveniences for subclasses.

    Provides a linear warmup / linear decay learning-rate schedule, an
    AMP-aware ``backward``/``optimize`` pair, checkpoint save/load that
    temporarily strips ``DataParallel`` wrappers, and a few padding and
    dotted-attribute helpers.

    ``backward`` and ``optimize`` read ``self.loss_grad``, ``self.optimizer``
    and ``self.scheduler``; none of them is assigned in this class, so they
    have to be provided before those methods are called.

    Args:
        cuda: Whether the module is meant to run on GPUs.
        warmup_ratio: Fraction of ``num_training_steps`` used for warmup.
        num_training_steps: Total number of scheduler steps.
        device_idxs: Device ids handed to ``DataParallel``.
        mixed_precision: If true, a ``GradScaler`` is created and used.
    """

    def __init__(self, cuda=True, warmup_ratio=0.1, num_training_steps=1000,
                 device_idxs=(), mixed_precision=False):
        super().__init__()

        # Scheduler parameters
        self.num_warmup_steps = int(warmup_ratio * num_training_steps)
        self.num_training_steps = num_training_steps

        # NOTE(review): this instance attribute shadows nn.Module.cuda();
        # kept as-is because callers may read `module.cuda` as a flag.
        self.cuda = cuda
        if self.cuda:
            self.devices = device_idxs
        else:
            self.devices = ['cpu']

        # The first explicit device id wins (original behaviour). Previously
        # an empty `device_idxs` (the default) raised IndexError here; fall
        # back to the first entry of `self.devices`, or None if there is none.
        if len(device_idxs) > 0:
            self.model_device = device_idxs[0]
        elif len(self.devices) > 0:
            self.model_device = self.devices[0]
        else:
            self.model_device = None

        # Mixed precision training support
        self.mixed_precision = mixed_precision
        if self.mixed_precision:
            self.scaler = GradScaler()

    def linear_scheduler(self, optimizer, last_epoch=-1):
        """Return a ``LambdaLR`` driven by :meth:`lr_lambda`."""
        return lr_scheduler.LambdaLR(optimizer, self.lr_lambda, last_epoch)

    def lr_lambda(self, current_step):
        """Return the LR multiplier for ``current_step``.

        Rises linearly from 0 to 1 over the warmup steps, then decays
        linearly to 0 at ``num_training_steps`` (never below 0).
        """
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        remaining = float(self.num_training_steps - current_step)
        decay_span = float(
            max(1, self.num_training_steps - self.num_warmup_steps))
        return max(0.0, remaining / decay_span)

    def backward(self, r=1, l2=False):
        """Backpropagate ``self.loss_grad``.

        Args:
            r: Factor multiplied into the loss first (e.g. to normalise for
                gradient accumulation).
            l2: If true, add the L2 norm of the parameter gradients to the
                loss before the backward pass.
        """
        # Loss scaling (can be used for accumulation normalizing)
        self.loss_grad = self.loss_grad * r

        # L2 normalization
        if l2:
            if self.mixed_precision:
                # Differentiate the scaled loss, then undo the scale so the
                # penalty is computed on unscaled gradients.
                grad_params = torch.autograd.grad(
                    self.scaler.scale(self.loss_grad), self.parameters(),
                    create_graph=True)
                inv_scale = 1 / self.scaler.get_scale()
                grad_params = [p * inv_scale for p in grad_params]
            else:
                grad_params = torch.autograd.grad(
                    self.loss_grad, self.parameters(), create_graph=True)

            with autocast(self.mixed_precision):
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                grad_norm = grad_norm.sqrt()
                self.loss_grad = self.loss_grad + grad_norm

        # Backward
        if self.mixed_precision:
            self.scaler.scale(self.loss_grad).backward()
        else:
            self.loss_grad.backward()

    def optimize(self, clip=True):
        """Step the optimizer and then the scheduler.

        Args:
            clip: If true, clip the gradient norm to 1.0 first (gradients are
                unscaled beforehand under mixed precision).
        """
        if clip:
            if self.mixed_precision:
                self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.parameters(), 1.0)

        if self.mixed_precision:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()

    def save_model(self, checkpoint_name, state_dict_only=True):
        """Save to ``checkpoint/<checkpoint_name>.th``.

        Args:
            checkpoint_name: File stem of the checkpoint.
            state_dict_only: Save only ``state_dict()`` when true, otherwise
                the whole module object.
        """
        dataparallel = self.single_gpu()
        try:
            os.makedirs('checkpoint', exist_ok=True)
            save_path = os.path.join('checkpoint', checkpoint_name + '.th')
            if state_dict_only:
                torch.save(self.state_dict(), save_path)
            else:
                torch.save(self, save_path)
        finally:
            # Re-wrap even if saving failed, so the module is left usable.
            self.multi_gpus(dataparallel)
        saved_component = 'state dict' if state_dict_only else 'model'
        print(f'Saved {saved_component} to {save_path}')

    def load_model(self, path, is_state_dict=True):
        """Load weights from ``path`` into this module.

        Args:
            path: Checkpoint file.
            is_state_dict: True if the file holds a state dict; otherwise it
                is expected to hold an object exposing ``state_dict()``.
        """
        dataparallel = self.single_gpu()
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(path, map_location='cpu')
            if not is_state_dict:
                state_dict = state_dict.state_dict()
            self.load_state_dict(state_dict)
        finally:
            # Re-wrap even if loading failed, so the module is left usable.
            self.multi_gpus(dataparallel)
        loaded_component = 'state dict' if is_state_dict else 'model'
        print(f'Loaded {loaded_component} from {path}')

    @classmethod
    def tensor(cls, x):
        """Build a tensor via ``torch.tensor(x)``, else ``torch.stack(x)``."""
        try:
            return torch.tensor(x)
        except Exception:  # narrowed from bare except: let Ctrl-C through
            return torch.stack(x)

    @classmethod
    def get_pad_amount(cls, max_lens, x, first=False):
        """Return the ``pad`` argument that grows ``x`` to ``max_lens``.

        Args:
            max_lens: 1-D tensor with the target size of every dimension.
            x: Tensor to be padded.
            first: Pad at the start of each dimension instead of the end.

        Returns:
            A flat list of Python ints ordered from the last dimension to
            the first, as ``(before, after)`` pairs.
        """
        zeros = torch.zeros_like(max_lens)
        missing = max_lens - torch.tensor(x.shape)
        if first:
            idxs = torch.stack([zeros, missing])
        else:
            idxs = torch.stack([missing, zeros])
        # Plain ints rather than 0-dim tensors.
        return idxs.T.reshape(-1).flip(0).tolist()

    @classmethod
    def pad_seq(cls, x, val=0, first=False):
        """Convert a (possibly ragged) nested sequence into one tensor.

        Tensors are returned unchanged. Ragged input is converted
        recursively and every item is padded with ``val`` up to the
        per-dimension maximum.
        """
        if isinstance(x, torch.Tensor):
            return x
        try:
            return BaseModule.tensor(x)
        except Exception:  # narrowed from bare except: let Ctrl-C through
            x = [BaseModule.pad_seq(x_, val=val, first=first) for x_ in x]
            max_lens = torch.tensor(
                [max(x_.shape[i] for x_ in x) for i in range(x[0].ndim)])
            return torch.stack([
                pad(x_,
                    pad=BaseModule.get_pad_amount(max_lens, x_, first),
                    value=val)
                for x_ in x
            ])

    @classmethod
    def getattr(cls, obj, name, *args, **kwargs):
        """``getattr`` that follows dotted paths such as ``"a.b.c"``."""
        if '.' in name:
            split_index = name.index('.')
            return cls.getattr(getattr(obj, name[:split_index]),
                               name[split_index + 1:], *args, **kwargs)
        return getattr(obj, name, *args, **kwargs)

    @classmethod
    def setattr(cls, obj, name, value):
        """``setattr`` that follows dotted paths such as ``"a.b.c"``."""
        if '.' in name:
            split_index = name.index('.')
            return cls.setattr(getattr(obj, name[:split_index]),
                               name[split_index + 1:], value)
        return setattr(obj, name, value)

    def single_gpu(self):
        """Replace every ``nn.DataParallel`` submodule by its wrapped module.

        Returns:
            The set of dotted submodule names that were unwrapped, suitable
            for passing to :meth:`multi_gpus`.
        """
        dataparallel = set()
        # Collect names first; the module tree is modified afterwards.
        for name, module in self.named_modules():
            if isinstance(module, nn.DataParallel):
                dataparallel.add(name)
        for name in dataparallel:
            BaseModule.setattr(self, name,
                               BaseModule.getattr(self, name).module)
        return dataparallel

    def multi_gpus(self, modules):
        """Wrap the named submodules in ``DataParallel`` again."""
        for name in modules:
            BaseModule.setattr(
                self, name,
                DataParallel(BaseModule.getattr(self, name),
                             device_ids=self.devices,
                             output_device=self.model_device))

    @classmethod
    def ratio(cls, x, y, ndigits=3):
        """Return ``round(x / y, ndigits)``, or 0 when ``y`` is 0."""
        if y == 0:
            return 0
        return round(x / y, ndigits)

    def make_position_ids(self, attention_mask):
        """Derive position ids from an attention mask.

        Positions are the running count of mask entries along the last
        dimension minus one; entries where the mask is 0 are set to 0.
        """
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)
        return position_ids
def train(net, data_loader, criterion, optimizer, epochs=10, save_every=20,
          model_path=None, use_drive=False, resume=False, reset=False,
          track_grad_norm=False, scheduler=None, plot=False, use_amp=False):
    """Training loop with optional checkpointing, AMP and metric tracking.

    Args:
        net: Model to train; its first parameter determines the device.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepping ``net``'s parameters.
        epochs: Number of epochs to run in this call.
        save_every: Checkpoint every this many epochs (``None`` disables the
            periodic save; the final epoch is still saved).
        model_path: Passed to ``search_drive`` to obtain save/load paths.
        use_drive: Passed to ``search_drive``.
        resume: If a checkpoint was restored, keep training only when true;
            otherwise the net is put in eval mode and ``None`` is returned.
        reset: Ignore an existing checkpoint.
        track_grad_norm: Accumulate per-parameter gradient norms.
        scheduler: Optional scheduler, stepped once per epoch with the mean
            epoch loss.
        plot: Track per-epoch metrics and plot them at the end.
        use_amp: Use mixed precision (only takes effect on CUDA devices).

    Returns:
        The tracked metrics (a ``defaultdict`` of lists) when ``plot`` is
        true, otherwise ``None``.
    """
    device = next(net.parameters()).device

    save_path, load_path = search_drive(model_path, use_drive)

    init_epoch = 0

    if load_path and os.path.exists(load_path) and not reset:
        checkpoint = torch.load(load_path, map_location=device)
        # A checkpoint is either a dict with named entries or a bare
        # state dict.
        if 'net_state_dict' in checkpoint:
            net.load_state_dict(checkpoint['net_state_dict'])
        else:
            net.load_state_dict(checkpoint)
        if 'epoch' in checkpoint:
            init_epoch = checkpoint['epoch']
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Training Checkpoint restored: " + load_path)
        if not resume:
            net.eval()
            return
    else:
        if model_path:
            print("No Checkpoint found / Reset.")
        if save_path:
            print("Path: " + save_path)

    assert valid_data_loader(
        data_loader), f"invalid data_loader: {data_loader}"

    net.train()

    amp_enabled = device.type == 'cuda' and use_amp
    if amp_enabled:
        scaler = GradScaler()

    tracking = None
    if plot:
        tracking = defaultdict(list, loss=[])

    print("Beginning training.", flush=True)

    # Epochs are numbered 1-based and continue after the restored epoch.
    last_epoch = init_epoch + epochs

    with tqdmEpoch(epochs, len(data_loader)) as pbar:
        saved_epoch = 0
        for epoch in range(1 + init_epoch, 1 + last_epoch):
            total_count = 0.0
            total_loss = 0.0
            total_correct = 0.0
            grad_total = 0.0

            for inputs, labels in data_loader:
                optimizer.zero_grad()

                if amp_enabled:
                    with autocast():
                        outputs = net(inputs)
                        loss = criterion(outputs, labels)
                    scaler.scale(loss).backward()
                    # Gradients carry this factor until the scaler steps.
                    grad_scale = scaler.get_scale()
                else:
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    grad_scale = 1

                if track_grad_norm:
                    for param in net.parameters():
                        # Parameters that received no gradient are skipped.
                        if param.grad is not None:
                            grad_total += (param.grad.norm(2)
                                           / grad_scale).item()

                if amp_enabled:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                batch_size = len(inputs)
                total_count += batch_size
                total_loss += loss.item() * batch_size
                total_correct += count_correct(outputs, labels)
                pbar.set_postfix(
                    loss=total_loss / total_count,
                    acc=f"{total_correct / total_count * 100:.0f}%",
                    chkpt=saved_epoch,
                    refresh=False,
                )
                pbar.update()

            loss = total_loss / total_count
            accuracy = total_correct / total_count
            # Previously commented out although it is read below, which
            # raised NameError when both `plot` and `track_grad_norm` were
            # set.
            grad_norm = grad_total / total_count

            if scheduler is not None:
                scheduler.step(loss)

            if tracking is not None:
                tracking['loss'].append(loss)
                tracking['accuracy'].append(accuracy)
                if track_grad_norm:
                    tracking['|grad|'].append(grad_norm)

            # Save periodically and always after the final epoch. The final
            # check used to compare against `init_epoch + epochs - 1`, i.e.
            # the second-to-last epoch.
            periodic_save = save_every is not None and epoch % save_every == 0
            if save_path is not None and (periodic_save
                                          or epoch == last_epoch):
                torch.save(
                    {
                        'epoch': epoch,
                        'net_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, save_path)
                saved_epoch = epoch

            pbar.set_postfix(
                loss=total_loss / total_count,
                acc=f"{total_correct / total_count * 100:.0f}%",
                chkpt=saved_epoch,
            )

    print(flush=True, end='')

    if tracking is not None:
        plot_metrics(tracking, step_start=init_epoch)

    return tracking
def invert(
    data_loader,
    loss_fn,
    optimizer,
    steps=10,
    scheduler=None,
    use_amp=False,
    grad_norm_fn=None,
    callback_fn=None,
    plot=False,
    fig_path=None,
    track_per_batch=False,
    track_grad_norm=False,
    print_grouped=False,
):
    """Optimise the parameters held by ``optimizer`` against ``loss_fn``.

    Args:
        data_loader: Iterable of batches passed one at a time to ``loss_fn``.
        loss_fn: Callable returning either a loss tensor or a dict that has
            a ``'loss'`` tensor plus additional values to be tracked.
        optimizer: Optimizer whose parameter groups are optimised.
        steps: Number of passes over ``data_loader``.
        scheduler: Optional scheduler, stepped with the loss after each batch.
        use_amp: Use mixed precision (only takes effect on CUDA parameters).
        grad_norm_fn: Optional callable mapping the total gradient norm to a
            target norm; gradients are rescaled by the ratio.
        callback_fn: Optional ``callback_fn(epoch, metrics_row)``; called
            with ``(0, None)`` before the first pass and after every pass.
        plot: Plot the metrics at the end when ``steps > 1``.
        fig_path: Passed to ``plot_metrics``.
        track_per_batch: Record one metrics row per batch instead of one
            (batch-averaged) row per pass.
        track_grad_norm: Record the total gradient norm as ``'|grad|'``.
        print_grouped: Not used by this function.

    Returns:
        A ``pandas.DataFrame`` with a ``'step'`` column and one column per
        tracked value.
    """
    assert valid_data_loader(
        data_loader), f"invalid data_loader: {data_loader}"

    params = [p for p_group in optimizer.param_groups
              for p in p_group['params']]
    device = params[0].device

    amp_enabled = (device.type == 'cuda') and use_amp
    if amp_enabled:
        scaler = GradScaler()

    num_batches = len(data_loader)
    track_len = steps * num_batches if track_per_batch else steps
    metrics = pd.DataFrame({'step': [None] * track_len})

    def process_result(res):
        """Split a loss_fn result into (loss tensor, dict of plain values)."""
        if isinstance(res, dict):
            loss = res['loss']
            # Build a new dict instead of rewriting the caller's in place.
            info = {
                k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in res.items()
            }
        else:
            loss = res
            info = {'loss': loss.item()}
        return loss, info

    print(flush=True)

    if callback_fn:
        callback_fn(0, None)

    with tqdmEpoch(steps, num_batches) as pbar:
        for epoch in range(steps):
            for batch_i, data in enumerate(data_loader):
                optimizer.zero_grad()

                if amp_enabled:
                    with autocast():
                        res = loss_fn(data)
                        loss, info = process_result(res)
                    scaler.scale(loss).backward()
                    # Gradients carry this factor; divide it out below.
                    grad_scale = scaler.get_scale()
                else:
                    res = loss_fn(data)
                    loss, info = process_result(res)
                    loss.backward()
                    grad_scale = 1

                if amp_enabled:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                if scheduler is not None:
                    scheduler.step(loss)

                if track_grad_norm or grad_norm_fn:
                    # Norm over *all* parameters. The original zipped the
                    # flat parameter list with the per-group learning rates,
                    # which cut the norm down to the first
                    # len(param_groups) parameters.
                    grad_norms = [
                        p.grad.detach().norm() / grad_scale
                        for p in params if p.grad is not None
                    ]
                    if grad_norms:
                        total_norm = torch.norm(
                            torch.stack(grad_norms)).item()
                    else:
                        total_norm = 0.0

                    # NOTE(review): this rescale runs after the optimizer
                    # step and gradients are zeroed at the top of the next
                    # iteration, so it cannot affect the update that was
                    # just applied -- confirm the intended order.
                    if grad_norm_fn and total_norm != 0:
                        rescale_coef = grad_norm_fn(total_norm) / total_norm
                        for param in params:
                            if param.grad is not None:
                                param.grad.detach().mul_(rescale_coef)

                    info['|grad|'] = total_norm

                pbar.set_postfix(
                    **{k: v for k, v in info.items() if ']' not in k},
                    refresh=False)
                pbar.update()

                if track_per_batch:
                    batch_total = epoch * num_batches + batch_i
                    step = batch_total
                else:
                    step = epoch

                # `.at` writes into the frame itself; chained indexing
                # (`metrics[k][step] = v`) does not under pandas
                # Copy-on-Write.
                for k, v in info.items():
                    if k not in metrics:  # add new column
                        metrics[k] = None
                    current = metrics.at[step, k]
                    metrics.at[step, k] = v if current is None else current + v

                if not track_per_batch and batch_i == 0:
                    metrics.at[epoch, 'step'] = epoch + 1
                if track_per_batch:
                    metrics.at[batch_total, 'step'] = (
                        (batch_total + 1) / num_batches)
            # batch end

            if not track_per_batch:
                # Turn the per-pass sums into batch averages.
                for k in metrics.columns:
                    if k == 'step':
                        continue
                    summed = metrics.at[epoch, k]
                    if summed is not None:
                        metrics.at[epoch, k] = summed / num_batches

            if callback_fn:
                callback_fn(epoch + 1, metrics.iloc[step])
        # epoch end

    print(flush=True)

    if plot and steps > 1:
        plot_metrics(metrics, fig_path=fig_path, smoothing=0)

    return metrics
class Trainer:
    """Drives training, evaluation and checkpointing of a ``LightweightGAN``.

    The constructor only stores hyper-parameters and creates the result and
    model folders. The GAN itself is built by :meth:`init_GAN`, which
    :meth:`train` calls on first use and :meth:`load_config` calls on every
    load. Extra positional and keyword arguments are forwarded to
    ``LightweightGAN``.
    """

    def __init__(
        self,
        name="default",
        results_dir="results",
        models_dir="models",
        base_dir="./",
        optimizer="adam",
        latent_dim=256,
        image_size=128,
        fmap_max=512,
        transparent=False,
        batch_size=4,
        gp_weight=10,
        gradient_accumulate_every=1,
        attn_res_layers=[],
        sle_spatial=False,
        disc_output_size=5,
        antialias=False,
        lr=2e-4,
        lr_mlp=1.0,
        ttur_mult=1.0,
        save_every=1000,
        evaluate_every=1000,
        trunc_psi=0.6,
        aug_prob=None,
        aug_types=["translation", "cutout"],
        dataset_aug_prob=0.0,
        calculate_fid_every=None,
        is_ddp=False,
        rank=0,
        world_size=1,
        log=False,
        amp=False,
        *args,
        **kwargs,
    ):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / ".config.json"

        assert is_power_of_two(
            image_size
        ), "image size must be a power of 2 (64, 128, 256, 512, 1024)"
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), "resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)"

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent

        self.aug_prob = aug_prob
        # Copy: the default is a list shared between all calls.
        self.aug_types = list(aug_types)

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        # Copy: the default is a list shared between all calls.
        self.attn_res_layers = list(attn_res_layers)
        self.sle_spatial = sle_spatial
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        # One gradient scaler per network, created only when AMP is on.
        self.amp = amp
        self.G_scaler = None
        self.D_scaler = None

        if self.amp:
            self.G_scaler = GradScaler()
            self.D_scaler = GradScaler()

    @property
    def image_extension(self):
        """File extension for saved images: png if transparent, else jpg."""
        return "jpg" if not self.transparent else "png"

    @property
    def checkpoint_num(self):
        """Number of whole ``save_every`` intervals completed so far."""
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        """Instantiate the GAN (and its DDP wrappers when distributed)."""
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist

            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"
            dist.init_process_group("nccl", rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr=self.lr,
            latent_dim=self.latent_dim,
            attn_res_layers=self.attn_res_layers,
            sle_spatial=self.sle_spatial,
            image_size=self.image_size,
            ttur_mult=self.ttur_mult,
            fmap_max=self.fmap_max,
            disc_output_size=self.disc_output_size,
            transparent=self.transparent,
            rank=self.rank,
            *args,
            **kwargs,
        )

        if self.is_ddp:
            ddp_kwargs = {
                "device_ids": [self.rank],
                "output_device": self.rank,
                "find_unused_parameters": True,
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        """Write :meth:`config` as JSON to ``self.config_path``."""
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        """Apply the stored config (or the current one) and rebuild the GAN."""
        config = (
            self.config()
            if not self.config_path.exists()
            else json.loads(self.config_path.read_text())
        )
        self.image_size = config["image_size"]
        self.transparent = config["transparent"]
        self.syncbatchnorm = config["syncbatchnorm"]
        self.disc_output_size = config["disc_output_size"]
        self.attn_res_layers = config.pop("attn_res_layers", [])
        self.sle_spatial = config.pop("sle_spatial", False)
        self.optimizer = config.pop("optimizer", "adam")
        # Config files written before "fmap_max" was stored fall back to 512.
        self.fmap_max = config.pop("fmap_max", 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        """Return the settings needed to rebuild the GAN architecture."""
        return {
            "image_size": self.image_size,
            "transparent": self.transparent,
            "syncbatchnorm": self.syncbatchnorm,
            "disc_output_size": self.disc_output_size,
            "optimizer": self.optimizer,
            "attn_res_layers": self.attn_res_layers,
            "sle_spatial": self.sle_spatial,
            # Stored so load_config() no longer resets fmap_max to 512.
            "fmap_max": self.fmap_max,
        }

    def set_data_src(self, folder):
        """Create the dataset and an endlessly cycling loader over ``folder``."""
        self.dataset = ImageDataset(
            folder,
            self.image_size,
            transparent=self.transparent,
            aug_prob=self.dataset_aug_prob,
        )
        sampler = (
            DistributedSampler(
                self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True
            )
            if self.is_ddp
            else None
        )
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True,
        )
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f"autosetting augmentation probability to {round(self.aug_prob * 100)}%"
            )

    def train(self):
        """Run one training step: a discriminator then a generator update.

        Also handles EMA bookkeeping, NaN recovery and the periodic
        save / evaluate / FID work, then increments ``self.steps``.

        Raises:
            NanException: If either accumulated loss is NaN (after reloading
                the latest checkpoint).
        """
        assert exists(
            self.loader
        ), "You must first initialize the data source with `.set_data_src(<folder of images>)`"
        device = torch.device(f"cuda:{self.rank}")

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {"prob": aug_prob, "types": aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        # The gradient penalty is only applied on every 4th step.
        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        def backward(amp, loss, scaler):
            if amp:
                return scaler.scale(loss).backward()
            loss.backward()

        def optimizer_step(amp, optimizer, scaler):
            if amp:
                scaler.step(optimizer)
                scaler.update()
                return
            optimizer.step()

        backward = partial(backward, self.amp)
        optimizer_step = partial(optimizer_step, self.amp)

        # train discriminator

        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            # The gradient penalty differentiates w.r.t. the real images.
            image_batch.requires_grad_()

            with amp_context():
                generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images.detach(), detach=True, **aug_kwargs
                )
                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs
                )

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32, fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = (
                    list(map(self.D_scaler.scale, outputs)) if self.amp else outputs
                )

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(), device=image_batch.device),
                            outputs,
                        )
                    ),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True,
                )[0]

                # Undo the loss scale; skip the penalty when the scale is 0
                # instead of dividing by zero.
                grad_scale = self.D_scaler.get_scale() if self.amp else 1.0

                if grad_scale != 0:
                    gradients = scaled_gradients * (1.0 / grad_scale)

                    with amp_context():
                        gradients = gradients.reshape(batch_size, -1)
                        gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                        if not torch.isnan(gp):
                            disc_loss = disc_loss + gp
                            self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            backward(disc_loss, self.D_scaler)
            # Detached: the running total is only read for logging and the
            # NaN check, so it must not keep each step's graph alive.
            total_disc_loss += divergence.detach()

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.D_opt, self.D_scaler)

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs
                )
                fake_output_loss = fake_output.mean(dim=1) + fake_output_32x32.mean(
                    dim=1
                )

                # Top-k training: keep only the k smallest per-sample
                # losses; the kept fraction decays with the number of
                # epochs down to generator_top_k_frac.
                epochs = (
                    self.steps * batch_size * self.gradient_accumulate_every
                ) / len(self.dataset)
                k_frac = max(
                    self.generator_top_k_gamma ** epochs, self.generator_top_k_frac
                )
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            backward(gen_loss, self.G_scaler)
            # Detached for the same reason as the discriminator total.
            total_gen_loss += loss.detach()

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.G_opt, self.G_scaler)

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f"NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}"
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                self.steps % 100 == 0 and self.steps < 20000
            ):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if (
                exists(self.calculate_fid_every)
                and self.steps % self.calculate_fid_every == 0
                and self.steps != 0
            ):
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                    str(self.results_dir / self.name / "fid_scores.txt"), "a"
                ) as f:
                    f.write(f"{self.steps},{fid}\n")

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        """Save image grids from the regular and the EMA generator.

        Files are written to ``results_dir/name`` as ``<num>.<ext>`` and
        ``<num>-ema.<ext>``; ``trunc`` is not used.
        """
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # latents and noise

        latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}.{ext}"),
            nrow=num_rows,
        )

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}-ema.{ext}"),
            nrow=num_rows,
        )

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        """Dump real and generated images to disk and return their FID."""
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / "fid_real") + "/"
        fake_path = str(self.results_dir / self.name / "fid_fake") + "/"

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving reals"
        ):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :],
                    real_path + "{}.png".format(k + batch_num * self.batch_size),
                )

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving generated"
        ):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path)
                        / f"{str(j + batch_num * self.batch_size)}-ema.{ext}"
                    ),
                )

        return fid_score.calculate_fid_given_paths(
            [real_path, fake_path], 256, True, 2048
        )

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        """Run ``G`` on ``style`` in chunks and clamp the output to [0, 1].

        ``trunc_psi`` and ``num_image_tiles`` are accepted but not used.
        """
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0.0, 1.0)

    @torch.no_grad()
    def generate_interpolation(
        self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False
    ):
        """Save a GIF interpolating between two random latent batches.

        Uses the EMA generator. With ``save_frames`` the individual frames
        are also written to ``results_dir/name/<num>/``. ``trunc`` is not
        used.
        """
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim

        # latents and noise

        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0.0, 8.0, num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                # Composite onto white so the GIF has an opaque background.
                background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(
            str(self.results_dir / self.name / f"{str(num)}.gif"),
            save_all=True,
            append_images=frames[1:],
            duration=80,
            loop=0,
            optimize=True,
        )

        if save_frames:
            folder_path = self.results_dir / self.name / f"{str(num)}"
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f"{str(ind)}.{ext}"))

    def print_log(self):
        """Print the latest losses and FID, skipping values that are unset."""
        data = [
            ("G", self.g_loss),
            ("D", self.d_loss),
            ("GP", self.last_gp_loss),
            ("SS", self.last_recon_loss),
            ("FID", self.last_fid),
        ]

        data = [d for d in data if exists(d[1])]
        log = " | ".join(map(lambda n: f"{n[0]}: {n[1]:.2f}", data))
        print(log)

    def model_name(self, num):
        """Return the checkpoint path for checkpoint number ``num``."""
        return str(self.models_dir / self.name / f"model_{num}.pt")

    def init_folders(self):
        """Create this run's results and models folders if missing."""
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        """Delete this run's models, results and config; recreate folders."""
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        """Save the GAN (and AMP scaler states) as checkpoint ``num``."""
        save_data = {"GAN": self.GAN.state_dict(), "version": __version__}

        if self.amp:
            save_data = {
                **save_data,
                "G_scaler": self.G_scaler.state_dict(),
                "D_scaler": self.D_scaler.state_dict(),
            }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        """Load checkpoint ``num``; ``-1`` selects the newest one on disk.

        Returns without loading weights when ``num`` is ``-1`` and no
        checkpoint file exists.
        """
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob("model_*.pt")
            ]
            saved_nums = sorted(map(lambda x: int(x.stem.split("_")[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f"continuing from previous epoch - {name}")

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if "version" in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data["GAN"])
        except Exception:
            print(
                "unable to load save model. please try downgrading the package to the version specified by the saved model"
            )
            raise

        if self.amp:
            if "G_scaler" in load_data:
                self.G_scaler.load_state_dict(load_data["G_scaler"])
            if "D_scaler" in load_data:
                self.D_scaler.load_state_dict(load_data["D_scaler"])