Example #1
from mmcv.runner import OptimizerHook
from torch.cuda.amp import GradScaler


class AMPGradAccumulateOptimizerHook(OptimizerHook):
    def __init__(self, *args, **kwargs):
        self.accumulation = kwargs.pop('accumulation', 1)
        self.scaler = GradScaler()
        super().__init__(*args, **kwargs)

    def before_run(self, runner):
        assert hasattr(runner.model.module,
                       'use_amp') and runner.model.module.use_amp, 'model should support AMP when using this optimizer hook!'
        runner.model.zero_grad()
        runner.optimizer.zero_grad()

    def before_train_iter(self, runner):
        if runner.iter % self.accumulation == 0:
            runner.model.zero_grad()
            runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        scaled_loss = self.scaler.scale(runner.outputs['loss'])
        scaled_loss.backward()

        if (runner.iter + 1) % self.accumulation == 0:
            scale = self.scaler.get_scale()
            if self.grad_clip is not None:
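                # unscale gradients in-place first so clip_grads operates on
                # true gradient magnitudes rather than scaled ones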
                self.scaler.unscale_(runner.optimizer)
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.log_buffer.update({'grad_scale': float(scale)},
                                     runner.outputs['num_samples'])
            self.scaler.step(runner.optimizer)
            self.scaler.update()
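
For reference, the accumulate-unscale-clip-step cycle this hook implements, written as a self-contained plain-PyTorch sketch (requires CUDA; the toy model, data, and hyperparameters are illustrative stand-ins, not part of the original example):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()
accumulation = 4

for it in range(100):
    x = torch.randn(16, 8, device='cuda')
    with autocast():
        loss = model(x).pow(2).mean() / accumulation
    scaler.scale(loss).backward()  # gradients accumulate across iterations
    if (it + 1) % accumulation == 0:
        scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)  # internally skipped on inf/NaN gradients
        scaler.update()
        optimizer.zero_grad()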
Example #2
class Trainer():
    def __init__(self,
                 name='default',
                 results_dir='results',
                 models_dir='models',
                 base_dir='./',
                 optimizer="adam",
                 latent_dim=256,
                 image_size=128,
                 fmap_max=512,
                 transparent=False,
                 greyscale=False,
                 batch_size=4,
                 gp_weight=10,
                 gradient_accumulate_every=1,
                 attn_res_layers=[],
                 disc_output_size=5,
                 antialias=False,
                 lr=2e-4,
                 lr_mlp=1.,
                 ttur_mult=1.,
                 save_every=1000,
                 evaluate_every=1000,
                 trunc_psi=0.6,
                 aug_prob=None,
                 aug_types=['translation', 'cutout'],
                 dataset_aug_prob=0.,
                 calculate_fid_every=None,
                 is_ddp=False,
                 rank=0,
                 world_size=1,
                 log=False,
                 amp=False,
                 *args,
                 **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(
            image_size
        ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)
                ) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = GradScaler(enabled=self.amp)
        self.D_scaler = GradScaler(enabled=self.amp)

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return self.steps // self.save_every

    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            dist.init_process_group('nccl', rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(optimizer=self.optimizer,
                                  lr=self.lr,
                                  latent_dim=self.latent_dim,
                                  attn_res_layers=self.attn_res_layers,
                                  image_size=self.image_size,
                                  ttur_mult=self.ttur_mult,
                                  fmap_max=self.fmap_max,
                                  disc_output_size=self.disc_output_size,
                                  transparent=self.transparent,
                                  greyscale=self.greyscale,
                                  rank=self.rank,
                                  *args,
                                  **kwargs)

        if self.is_ddp:
            ddp_kwargs = {
                'device_ids': [self.rank],
                'output_device': self.rank,
                'find_unused_parameters': True
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = json.loads(self.config_path.read_text()) \
            if self.config_path.exists() else self.config()
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers
        }

    def set_data_src(self, folder):
        self.dataset = ImageDataset(folder,
                                    self.image_size,
                                    transparent=self.transparent,
                                    greyscale=self.greyscale,
                                    aug_prob=self.dataset_aug_prob)
        sampler = DistributedSampler(self.dataset,
                                     rank=self.rank,
                                     num_replicas=self.world_size,
                                     shuffle=True) if self.is_ddp else None
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True)
        self.loader = cycle(dataloader)

        # auto set augmentation prob for user if dataset is detected to be low
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f'autosetting augmentation probability to {round(self.aug_prob * 100)}%'
            )

    def train(self):
        assert exists(
            self.loader
        ), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[D_aug, G]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                with torch.no_grad():
                    generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images,
                                                          detach=True,
                                                          **aug_kwargs)

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs)

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32,
                                              fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = list(map(self.D_scaler.scale,
                                   outputs)) if self.amp else outputs

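                # autograd.grad on the scaled outputs yields scaled gradients;
                # scaler.unscale_ only handles optimizer params, so these are
                # unscaled manually below by multiplying with 1 / scale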
                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(),
                                                 device=image_batch.device),
                            outputs)),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]

                inv_scale = (1. /
                             self.D_scaler.get_scale()) if self.amp else 1.
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * (
                        (gradients.norm(2, dim=1) - 1)**2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.D_scaler.scale(disc_loss).backward()
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() /
                            self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[G, D_aug]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs)
                fake_output_loss = fake_output.mean(
                    dim=1) + fake_output_32x32.mean(dim=1)

                epochs = (self.steps * batch_size *
                          self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma**epochs,
                             self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k,
                                                                largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.G_scaler.scale(gen_loss).backward()
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() /
                            self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}'
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                    self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(
                    self.calculate_fid_every
            ) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                        str(self.results_dir / self.name / 'fid_scores.txt'),
                        'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows**2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}.{ext}'),
                                     nrow=num_rows)

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}-ema.{ext}'),
                                     nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / 'fid_real') + '/'
        fake_path = str(self.results_dir / self.name / 'fid_fake') + '/'

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving reals'):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :], real_path +
                    '{}.png'.format(k + batch_num * self.batch_size))

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving generated'):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path) /
                        f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([real_path, fake_path], 256,
                                                   True, 2048)

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self,
                               num=0,
                               num_image_tiles=8,
                               trunc=1.0,
                               num_steps=100,
                               save_frames=False):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows**2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows**2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE,
                                                       interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images,
                                                      nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'),
                       save_all=True,
                       append_images=frames[1:],
                       duration=80,
                       loop=0,
                       optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        data = [('G', self.g_loss), ('D', self.d_loss),
                ('GP', self.last_gp_loss), ('SS', self.last_recon_loss),
                ('FID', self.last_fid)]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob('model_*.pt')
            ]
            saved_nums = sorted(
                map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print(
                'unable to load the saved model. please try downgrading the package to the version specified by the saved model'
            )
            raise e

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])
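
Two helpers used in the training loop above, null_context and gradient_accumulate_contexts, are not shown in this snippet. A plausible reconstruction based only on how they are called here (the no_sync() batching of DDP gradient all-reduce is an assumption about the intended behavior):

from contextlib import ExitStack, contextmanager

@contextmanager
def null_context():
    # stand-in for contextlib.nullcontext on older Python versions
    yield

def gradient_accumulate_contexts(num_steps, is_ddp, ddps):
    # yields once per accumulation step; on all but the last step the DDP
    # modules enter no_sync() so gradients are all-reduced only once
    for step in range(num_steps):
        with ExitStack() as stack:
            if is_ddp and step < num_steps - 1:
                for ddp in ddps:
                    stack.enter_context(ddp.no_sync())
            yield step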
Example #3
    def train_epoch(self, model: Reader, optimizer: torch.optim.Optimizer,
                    scaler: GradScaler, train: DataLoader, val: DataLoader,
                    scheduler: torch.optim.lr_scheduler.LambdaLR) -> float:
        """
        Performs one training epoch.

        :param model: The model you are training.
        :type model: Reader
        :param optimizer: Use this optimizer for training.
        :type optimizer: torch.optim.Optimizer
        :param scaler: Scaler for gradients when the mixed precision is used.
        :type scaler: GradScaler
        :param train: The train dataset loader.
        :type train: DataLoader
        :param val: The validation dataset loader.
        :type val: DataLoader
        :param scheduler: Learning rate scheduler.
        :type scheduler: torch.optim.lr_scheduler.LambdaLR
        :return: Best achieved exact match among validations.
        :rtype: float
        """

        model.train()
        loss_sum = 0
        samples = 0
        startTime = time.time()

        total_tokens = 0
        optimizer.zero_grad()

        initStep = 0
        if self.resumeSkip is not None:
            initStep = self.resumeSkip
            self.resumeSkip = None

        iterator = tqdm(enumerate(train), total=len(train), initial=initStep)

        bestExactMatch = 0.0

        for current_it, batch in iterator:
            batch: ReaderBatch
            lastScale = scaler.get_scale()
            self.n_iter += 1

            batchOnDevice = batch.to(self.device)
            samples += 1

            try:
                with torch.cuda.amp.autocast(
                        enabled=self.config["mixed_precision"]):
                    startScores, endScores, jointScore, selectionScore = self._useModel(
                        model, batchOnDevice)

                    # according to the config we can get the following loss combinations:
                    # join components
                    # independent components
                    # join components with HardEM
                    # independent components with HardEM

                    logSpanProb = None
                    if not self.config["independent_components_in_loss"]:
                        # joined components in loss
                        logSpanProb = Reader.scores2logSpanProb(
                            startScores, endScores, jointScore, selectionScore)

                    # The user may want to use the hard-EM loss with a certain probability.
                    # The original article does not state this clearly (it reads as if it
                    # were the other way around); after consulting the authors, the intended
                    # behavior became clear.

                    if self.config["hard_em_steps"] > 0 and \
                            random.random() <= min(self.update_it/self.config["hard_em_steps"], self.config["max_hard_em_prob"]):
                        # loss is calculated for the max answer span with max probability
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.hardEMIndependentComponentsLoss(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.hardEMLoss(logSpanProb,
                                                     batchOnDevice.answersMask)
                    else:
                        # loss is calculated for all answer spans
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.marginalCompoundLossWithIndependentComponents(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.marginalCompoundLoss(
                                logSpanProb, batchOnDevice.answersMask)

                    if self.config[
                            "use_auxiliary_loss"] and batch.isGroundTruth:
                        # applied only when the user enables it and the selected passage is ground truth
                        loss += Reader.auxiliarySelectedLoss(selectionScore)
                    loss_sum += loss.item()

                scaler.scale(loss).backward()

            # Catch out-of-memory errors
            except RuntimeError as e:
                if "CUDA out of memory." in str(e):
                    torch.cuda.empty_cache()
                    logging.error(e)
                    tb = traceback.format_exc()
                    logging.error(tb)
                    continue
                else:
                    raise e

            # update parameters

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                self.config["max_grad_norm"])

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            self.update_it += 1

            if scaler.get_scale() >= lastScale and scheduler is not None:
                # we should not perform a scheduler step when the optimizer step was
                # skipped due to inf/NaN gradients (a skipped step strictly shrinks
                # the scale factor, so an unchanged or larger scale means it ran)
                scheduler.step()

            if self.update_it % self.config["validate_after_steps"] == 0:
                valLoss, exactMatch, passageMatch, samplesWithLoss = self.validate(
                    model, val)

                logging.info(
                    f"Steps:{self.update_it}, Training loss: {loss_sum / samples:.5f}, Validation loss: {valLoss} (samples with loss {samplesWithLoss} [{samplesWithLoss / len(val):.1%}]), Exact match: {exactMatch:.5f}, Passage match: {passageMatch:.5f}"
                )

                bestExactMatch = max(exactMatch, bestExactMatch)
                if self.update_it > self.config["first_save_after_updates_K"]:
                    checkpoint = Checkpoint(
                        model.module if isinstance(model, DataParallel) else
                        model, optimizer, scheduler, train.sampler.actPerm,
                        current_it + 1, self.config, self.update_it)
                    checkpoint.save(f"{self.config['save_dir']}/Reader_train"
                                    f"_{get_timestamp()}"
                                    f"_{socket.gethostname()}"
                                    f"_{valLoss}"
                                    f"_S_{self.update_it}"
                                    f"_E_{current_it}.pt")

                model.train()

            # statistics & logging
            total_tokens += batch.inputSequences.numel()
            if (self.n_iter + 1) % 50 == 0 or current_it == len(iterator) - 1:
                iterator.set_description(
                    f"Steps: {self.update_it} Tokens/s: {total_tokens / (time.time() - startTime)}, Training loss: {loss_sum / samples}"
                )

            if self.config["max_steps"] <= self.update_it:
                break

        logging.info(
            f"End of epoch training loss: {loss_sum / samples:.5f}, best validation exact match: {bestExactMatch}"
        )

        return bestExactMatch
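
The key AMP idiom in the epoch loop above is detecting a skipped optimizer step by comparing the scale before and after scaler.update(), so the LR scheduler only advances on real steps. A minimal self-contained sketch of just that idiom (requires CUDA; the model, data, and scheduler are placeholders):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(4, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
scaler = GradScaler()

for _ in range(10):
    optimizer.zero_grad()
    with autocast():
        loss = model(torch.randn(8, 4, device='cuda')).mean()
    scaler.scale(loss).backward()
    prev_scale = scaler.get_scale()
    scaler.step(optimizer)  # silently skipped when gradients contain inf/NaN
    scaler.update()         # a skipped step shrinks the scale
    if scaler.get_scale() >= prev_scale:
        scheduler.step()    # advance the schedule only on real steps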
Example #4
class Solver(object):
    def __init__(self, config):
        self.model = None
        self.args = config
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None
        self.infer_loader = None
        self.es = EarlyStopping(patience=self.args.es_patience)
        self.scaler = GradScaler(enabled=self.args.half)

        if not self.args.save_dir:
            self.writer = SummaryWriter()
        else:
            self.writer = SummaryWriter(log_dir="runs/" + self.args.save_dir)

        self.train_batch_plot_idx = 0
        self.test_batch_plot_idx = 0

    def load_data(self):
        if self.args.dataset.name not in datasets:
            print(
                f"This dataset is not implemented ({self.args.dataset.name}), go ahead and commit it"
            )
            exit()
        train_cache_index = 0
        train_data_transformations = []
        for idx, transformation in enumerate(
                self.args.transformations.train.data):
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            if hasattr(transformation, 'cache_point'):
                train_cache_index = idx + 1
            train_data_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        train_target_transformations = []
        for transformation in self.args.transformations.train.target:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            train_target_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        train_both_transformations = []
        for transformation in self.args.transformations.train.both:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            train_both_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        train_output_transformations = []
        for transformation in self.args.transformations.train.output:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            train_output_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        train_data_transform = transforms.Compose(
            train_data_transformations
        ) if len(train_data_transformations) > 0 else None
        train_target_transform = transforms.Compose(
            train_target_transformations
        ) if len(train_target_transformations) > 0 else None
        train_both_transform = transforms.Compose(
            train_both_transformations
        ) if len(train_both_transformations) > 0 else None
        self.train_output_transform = transforms.Compose(
            train_output_transformations
        ) if len(train_output_transformations) > 0 else None

        test_cache_index = 0
        test_data_transformations = []
        for idx, transformation in enumerate(
                self.args.transformations.test.data):
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            if hasattr(transformation, 'cache_point'):
                test_cache_index = idx + 1
            test_data_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        test_target_transformations = []
        for transformation in self.args.transformations.test.target:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            test_target_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        test_both_transformations = []
        for transformation in self.args.transformations.test.both:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            test_both_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        test_output_transformations = []
        for transformation in self.args.transformations.test.output:
            if transformation.name not in transformations:
                print(
                    f"This transformation is not implemented ({transformation.name}), go ahead and commit it"
                )
                exit()
            test_output_transformations.append(
                transformations[transformation.name](
                    **transformation.parameters))

        test_data_transform = transforms.Compose(
            test_data_transformations
        ) if len(test_data_transformations) > 0 else None
        test_target_transform = transforms.Compose(
            test_target_transformations
        ) if len(test_target_transformations) > 0 else None
        test_both_transform = transforms.Compose(
            test_both_transformations
        ) if len(test_both_transformations) > 0 else None
        self.test_output_transform = transforms.Compose(
            test_output_transformations
        ) if len(test_output_transformations) > 0 else None

        parameters = OmegaConf.to_container(
            self.args.dataset.train_loader_params, resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        if self.args.dataset.name in ['CIFAR-10', 'CIFAR-100', 'ImageNet2012']:
            parameters["transform"] = train_data_transform
            parameters["target_transform"] = train_target_transform
        else:
            parameters["data_transform"] = train_data_transform
            parameters["target_transform"] = train_target_transform
            parameters["both_transform"] = train_both_transform
            parameters['cache_index'] = train_cache_index
        self.train_set = datasets[self.args.dataset.name](**parameters)

        parameters = OmegaConf.to_container(
            self.args.dataset.test_loader_params, resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        if self.args.dataset.name in ['CIFAR-10', 'CIFAR-100', 'ImageNet2012']:
            parameters["transform"] = test_data_transform
            parameters["target_transform"] = test_target_transform
        else:
            parameters["data_transform"] = test_data_transform
            parameters["target_transform"] = test_target_transform
            parameters["both_transform"] = test_both_transform
            parameters['cache_index'] = test_cache_index
        self.test_set = datasets[self.args.dataset.name](**parameters)

        if hasattr(self.args.dataset,
                   'mixup_args') and self.args.dataset.mixup_args is not None:
            collate_fn = FastCollateMixup(**self.args.dataset.mixup_args)
        else:
            collate_fn = None

        self.train_loader = torch.utils.data.DataLoader(
            dataset=self.train_set,
            batch_size=self.args.dataset.train_batch_size,
            shuffle=self.args.dataset.shuffle,
            num_workers=self.args.dataset.num_workers_train,
            collate_fn=collate_fn,
            drop_last=True,
            persistent_workers=self.args.dataset.num_workers_train > 0)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=self.test_set,
            batch_size=self.args.dataset.test_batch_size,
            shuffle=False,
            num_workers=self.args.dataset.num_workers_test,
            persistent_workers=self.args.dataset.num_workers_test > 0)
        if self.args.infer_only:
            parameters = OmegaConf.to_container(
                self.args.dataset.infer_loader_params, resolve=True)
            parameters = {k: v for k, v in parameters.items() if v is not None}
            parameters["data_transform"] = test_data_transform
            parameters["target_transform"] = test_target_transform
            parameters["both_transform"] = test_both_transform
            self.infer_set = datasets[self.args.dataset.name](**parameters)
            self.infer_loader = torch.utils.data.DataLoader(
                dataset=self.infer_set,
                batch_size=self.args.dataset.test_batch_size,
                shuffle=False,
                num_workers=self.args.dataset.num_workers_test,
                persistent_workers=self.args.dataset.num_workers_test > 0)

    def init_model(self):
        if self.cuda:
            self.device = torch.device(f'cuda:{self.args.cuda_device}')
            cudnn.benchmark = True

            # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
            torch.backends.cuda.matmul.allow_tf32 = True

            # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
            torch.backends.cudnn.allow_tf32 = True
        else:
            self.device = torch.device('cpu')

        parameters = OmegaConf.to_container(self.args.model.parameters,
                                            resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        try:
            self.model = getattr(models, self.args.model.name)
        except AttributeError:
            print(
                f"This model is not implemented ({self.args.model.name}), go ahead and commit it"
            )
            exit()
        self.model = self.model(**parameters)

        self.save_dir = os.path.join(self.args.storage_dir, "model_weights",
                                     self.args.save_dir)
        if not os.path.isdir(self.save_dir):
            os.makedirs(self.save_dir)

        if self.args.initialization == 1:
            # xavier init
            for m in self.model.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.xavier_uniform_(
                        m.weight, gain=nn.init.calculate_gain('relu'))
        elif self.args.initialization == 2:
            # he initialization
            for m in self.model.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(m.weight, mode='fan_in')
        elif self.args.initialization == 3:
            # selu init
            for m in self.model.modules():
                if isinstance(m, nn.Conv2d):
                    fan_in = m.kernel_size[0] * \
                        m.kernel_size[1] * m.in_channels
                    # std must be a float (torch.sqrt on a float raises)
                    nn.init.normal_(m.weight, 0, (1. / fan_in)**0.5)
                elif isinstance(m, nn.Linear):
                    fan_in = m.in_features
                    nn.init.normal_(m.weight, 0, (1. / fan_in)**0.5)
        elif self.args.initialization == 4:
            # orthogonal initialization
            for m in self.model.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.orthogonal_(m.weight)

        if self.args.initialization_batch_norm:
            # batch norm initialization
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        if len(self.args.load_model) > 0:
            print("Loading model from " + self.args.load_model)
            self.model.load_state_dict(torch.load(self.args.load_model))

            # for param in self.model.parameters():
            #     param.requires_grad = True
            # for param in self.model.patch_embed.parameters():
            #     param.requires_grad = True
            # for param in self.model.norm.parameters():
            #     param.requires_grad = True
            # for param in self.model.avgpool.parameters():
            #     param.requires_grad = True
            # for param in self.model.head.parameters():
            #     param.requires_grad = True

        self.model = self.model.to(self.device)

    def init_optimizer(self):
        parameters = OmegaConf.to_container(self.args.optimizer.parameters,
                                            resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        parameters["params"] = self.model.parameters()

        try:
            self.optimizer = getattr(torch_optimizer, self.args.optimizer.name)
        except AttributeError:
            try:
                self.optimizer = getattr(optim, self.args.optimizer.name)
            except AttributeError:
                print(
                    f"This optimizer is not implemented ({self.args.optimizer.name}), go ahead and commit it"
                )
                exit()

        self.optimizer = self.optimizer(**parameters)

        if self.args.optimizer.use_SAM:
            self.optimizer = optimizers['SAM'](base_optimizer=self.optimizer,
                                               rho=self.args.optimizer.SAM_rho)

        if self.args.optimizer.use_lookahead:
            self.optimizer = torch_optimizer.Lookahead(
                self.optimizer,
                k=self.args.optimizer.lookahead_k,
                alpha=self.args.optimizer.lookahead_alpha)

    def init_scheduler(self):
        if self.args.scheduler.name not in schedulers:
            print(
                f"This loss is not implemented ({self.args.scheduler.name}), go ahead and commit it"
            )
            exit()

        parameters = OmegaConf.to_container(self.args.scheduler.parameters,
                                            resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}
        parameters["optimizer"] = self.optimizer
        self.scheduler = schedulers[self.args.scheduler.name](**parameters)

    def init_criterion(self):
        if self.args.loss.name not in losses:
            print(
                f"This loss is not implemented ({self.args.loss.name}), go ahead and commit it"
            )
            exit()

        parameters = OmegaConf.to_container(self.args.loss.parameters,
                                            resolve=True)
        parameters = {k: v for k, v in parameters.items() if v is not None}

        self.criterion = losses[self.args.loss.name]['constructor'](
            **parameters)

    def init_metrics(self):
        self.metrics = {
            'train': {
                'batch': [],
                'epoch': []
            },
            'test': {
                'batch': [],
                'epoch': []
            },
            'solver': {
                'batch': [],
                'epoch': []
            },
        }

        for metric in self.args.metrics.train:
            if metric.name not in metrics:
                print(
                    f"This metric is not implemented ({metric.name}), go ahead and commit it"
                )
                exit()

            metric_func = metrics[metric.name]['constructor'](
                **metric.parameters)
            metric_object = Metric(metric.name,
                                   metric_func,
                                   solver_metric=False,
                                   aggregator=metric.aggregator)
            for level in metric.levels:
                self.metrics['train'][level].append(metric_object)

        for metric in self.args.metrics.test:
            if metric.name not in metrics:
                print(
                    f"This metric is not implemented ({metric.name}), go ahead and commit it"
                )
                exit()

            metric_func = metrics[metric.name]['constructor'](
                **metric.parameters)
            metric_object = Metric(metric.name,
                                   metric_func,
                                   solver_metric=False,
                                   aggregator=metric.aggregator)
            for level in metric.levels:
                self.metrics['test'][level].append(metric_object)

        for metric in self.args.metrics.solver:
            if metric.name not in metrics:
                print(
                    f"This metric is not implemented ({metric.name}), go ahead and commit it"
                )
                exit()

            metric_func = metrics[metric.name]['constructor'](
                **metric.parameters)
            metric_object = Metric(metric.name,
                                   metric_func,
                                   solver_metric=True,
                                   aggregator=metric.aggregator)
            for level in metric.levels:
                self.metrics['solver'][level].append(metric_object)

    def disable_bn(self):
        for module in self.model.modules():
            if isinstance(module,
                          nn.modules.batchnorm._NormBase) or isinstance(
                              module, nn.LayerNorm):
                module.eval()

    def enable_bn(self):
        self.model.train()

    def train(self):
        print("train:")
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        accumulation_data = []
        accumulation_target = []

        predictions = []
        targets = []
        for batch_num, (data, target) in enumerate(self.train_loader):
            if isinstance(data, list):
                data = [i.to(self.device) for i in data]
            else:
                data = data.to(self.device)
            if isinstance(target, list):
                target = [i.to(self.device) for i in target]
            else:
                target = target.to(self.device)

            if self.args.optimizer.use_SAM:
                accumulation_data.append(data)
                accumulation_target.append(target)

            while True:
                with autocast(enabled=self.args.half):
                    output = self.model(data)
                    if self.train_output_transform is not None:
                        output = self.train_output_transform(output)
                    loss = self.criterion(output, target)
                    loss = loss / self.args.dataset.update_every

                if self.args.optimizer.grad_penalty is not None and self.args.optimizer.grad_penalty > 0.0:
                    # Creates gradients
                    scaled_grad_params = torch.autograd.grad(
                        outputs=self.scaler.scale(loss),
                        inputs=self.model.parameters(),
                        create_graph=True)

                    # Creates unscaled grad_params before computing the penalty.
                    # scaled_grad_params are not owned by any optimizer, so
                    # ordinary division is used instead of scaler.unscale_:
                    inv_scale = 1. / self.scaler.get_scale()
                    grad_params = [p * inv_scale for p in scaled_grad_params]

                    # Computes the penalty term and adds it to the loss
                    with autocast():
                        grad_norm = 0
                        for grad in grad_params:
                            grad_norm += grad.pow(2).sum()
                        grad_norm = grad_norm.sqrt()
                        loss = loss + (grad_norm *
                                       self.args.optimizer.grad_penalty)

                self.scaler.scale(loss).backward()

                def sam_closure():
                    self.disable_bn()
                    for i in range(len(accumulation_data)):
                        with autocast(enabled=self.args.half):
                            output = self.model(accumulation_data[i])
                            if self.train_output_transform is not None:
                                output = self.train_output_transform(output)
                            loss = self.criterion(output,
                                                  accumulation_target[i])
                            loss = loss / self.args.dataset.update_every

                        if self.args.optimizer.grad_penalty is not None and self.args.optimizer.grad_penalty is not False and self.args.optimizer.grad_penalty > 0.0:
                            # Creates gradients
                            scaled_grad_params = torch.autograd.grad(
                                outputs=self.scaler.scale(loss),
                                inputs=self.model.parameters(),
                                create_graph=True)

                            # Creates unscaled grad_params before computing the
                            # penalty. scaled_grad_params are not owned by any
                            # optimizer, so ordinary division is used instead
                            # of scaler.unscale_:
                            inv_scale = 1. / self.scaler.get_scale()
                            grad_params = [
                                p * inv_scale for p in scaled_grad_params
                            ]

                            # Computes the penalty term and adds it to the loss
                            with autocast():
                                grad_norm = 0
                                for grad in grad_params:
                                    grad_norm += grad.pow(2).sum()
                                grad_norm = grad_norm.sqrt()
                                loss = loss + (
                                    grad_norm *
                                    self.args.optimizer.grad_penalty)

                        self.scaler.scale(loss).backward()
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            self.args.optimizer.max_norm)
                    self.enable_bn()

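                # batch replay: if AMP produced inf/NaN gradients, discard them,
                # let scaler.update() shrink the scale, and retry the same batch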
                if self.args.optimizer.batch_replay:
                    found_inf = False
                    for _, param in self.model.named_parameters():
                        if param.grad is not None and (
                                param.grad.isnan().any()
                                or param.grad.isinf().any()):
                            found_inf = True
                            break
                    if found_inf:
                        self.scaler.update()
                        self.optimizer.zero_grad()
                        if type(self.args.optimizer.batch_replay) in (int, float):
                            self.args.optimizer.batch_replay -= 1
                    else:
                        break
                else:
                    break

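            # step the optimizer only every `update_every` batches (gradient accumulation)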
            if self.train_batch_plot_idx % self.args.dataset.update_every == 0:
                self.scaler.unscale_(self.optimizer)

                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.args.optimizer.max_norm)
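                # note: GradScaler.step() rejects a `closure` kwarg while scaling
                # is enabled, so the SAM closure path assumes half precision is off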
                self.scaler.step(self.optimizer,
                                 closure=sam_closure
                                 if self.args.optimizer.use_SAM else None)
                self.scaler.update()

                self.optimizer.zero_grad()

                accumulation_data = []
                accumulation_target = []

            predictions.extend(output)
            targets.extend(target)

            metrics_results = {}
            for metric in self.metrics['train']['batch']:
                metrics_results["Train/Batch-" +
                                metric.name] = metric.calculate(output,
                                                                target,
                                                                level='batch')

            for metric in self.metrics['solver']['batch']:
                metrics_results["Solver/Batch-" +
                                metric.name] = metric.calculate(solver=self,
                                                                level='batch')

            print_metrics(self.writer, metrics_results,
                          self.get_train_batch_plot_idx())

            if self.args.progress_bar:
                progress_bar(batch_num, len(self.train_loader))
            if self.args.scheduler.name == "OneCycleLR":
                self.scheduler.step()

        return torch.stack(predictions), torch.stack(targets)

    def test(self):
        print("test:")
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        predictions = []
        targets = []
        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.test_loader):
                if isinstance(data, list):
                    data = [i.to(self.device) for i in data]
                else:
                    data = data.to(self.device)
                if isinstance(target, list):
                    target = [i.to(self.device) for i in target]
                else:
                    target = target.to(self.device)

                with autocast(enabled=self.args.half):
                    output = self.model(data)
                    if self.test_output_transform is not None:
                        output = self.test_output_transform(output)
                    loss = self.criterion(output, target)

                predictions.extend(output)
                targets.extend(target)

                metrics_results = {}
                for metric in self.metrics['test']['batch']:
                    metrics_results["Test/Batch-" +
                                    metric.name] = metric.calculate(
                                        output, target, level='batch')

                print_metrics(self.writer, metrics_results,
                              self.get_test_batch_plot_idx())

                if self.args.progress_bar:
                    progress_bar(batch_num, len(self.test_loader))

        return torch.stack(predictions), torch.stack(targets)

    def infer(self):
        print("infer:")
        self.model.eval()

        predictions = []
        filenames = []
        with torch.no_grad():
            for batch_num, (filename, data) in enumerate(self.infer_loader):
                if isinstance(data, list):
                    data = [i.to(self.device) for i in data]
                else:
                    data = data.to(self.device)

                with autocast(enabled=self.args.half):
                    output = self.model(data)
                    if self.test_output_transform is not None:
                        output = self.test_output_transform(output)

                predictions.extend(output)
                filenames.extend(filename)

        return filenames, torch.stack(predictions)

    def save(self, epoch, metric, tag=None):
        if tag is not None:
            tag = "_" + tag
        else:
            tag = ""
        model_out_path = os.path.join(
            self.save_dir, "model_{}_{}{}.pth".format(epoch, metric, tag))
        torch.save(self.model.state_dict(), model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        if self.args.seed is not None:
            reset_seed(self.args.seed)
        self.load_data()
        self.init_model()
        self.init_optimizer()
        self.init_scheduler()
        self.init_criterion()
        self.init_metrics()

        try:
            if self.args.infer_only:
                # If it's the "separated" dataset, we need to average the
                # scores of the 2/3 different projections
                filenames, predictions = self.infer()
                predictions = predictions.argmax(-1) + 1
                save_path = os.path.join(self.save_dir, "predictions.csv")
                pd.DataFrame({
                    'Patient': filenames,
                    'Class': predictions.cpu().numpy()
                }).to_csv(save_path, header=False, index=False)
                exit()

            best_metrics = {}
            metric_key = self.args.optimized_metric.split('/')[-1]
            higher_is_better = metrics[metric_key]['higher_is_better']
            for epoch in range(1, self.args.epochs + 1):
                print("\n===> epoch: %d/%d" % (epoch, self.args.epochs))
                self.epoch = epoch

                metrics_results = {}

                predictions, targets = self.train()
                for metric in self.metrics['train']['epoch']:
                    metric_name = "Train/" + metric.name
                    metrics_results[metric_name] = metric.calculate(
                        predictions, targets, level='epoch')

                if self.epoch % self.args.test_every == 0:
                    predictions, targets = self.test()
                    for metric in self.metrics['test']['epoch']:
                        metric_name = "Test/" + metric.name
                        metrics_results[metric_name] = metric.calculate(
                            predictions, targets, level='epoch')

                for metric in self.metrics['solver']['epoch']:
                    metric_name = "Solver/" + metric.name
                    metrics_results[metric_name] = metric.calculate(
                        solver=self, level='epoch')

                print_metrics(self.writer, metrics_results, self.epoch)

                if self.epoch % self.args.test_every == 0:
                    current = metrics_results[self.args.optimized_metric]
                    best = best_metrics.get(self.args.optimized_metric)
                    if (best is None
                            or (higher_is_better and current > best)
                            or (not higher_is_better and current < best)):
                        best_metrics[self.args.optimized_metric] = current
                        self.save(epoch, current)
                        print("===> BEST " + self.args.optimized_metric +
                              " PERFORMANCE: %.5f" % current)

                if self.args.save_model and epoch % self.args.save_interval == 0:
                    self.save(epoch, 0)

                if self.args.scheduler.name == "MultiStepLR":
                    self.scheduler.step()
                elif self.args.scheduler.name == "ReduceLROnPlateau":
                    self.scheduler.step(
                        metrics_results[self.args.scheduler_metric])
                elif self.args.scheduler.name == "OneCycleLR":
                    pass
                else:
                    self.scheduler.step()

                if self.es.step(metrics_results[self.args.es_metric]):
                    print("Early stopping")
                    raise KeyboardInterrupt
        except KeyboardInterrupt:
            pass

        print("===> BEST " + self.args.optimized_metric +
              " PERFORMANCE: %.5f" % best_metrics[self.args.optimized_metric])
        files = os.listdir(self.save_dir)
        paths = [
            os.path.join(self.save_dir, basename) for basename in files
            if "_0" not in basename
        ]
        if len(paths) > 0:
            src = max(paths, key=os.path.getctime)
            copyfile(
                src,
                os.path.join("runs", self.args.save_dir,
                             os.path.basename(src)))

        with open("runs/" + self.args.save_dir + "/README.md", 'a+') as f:
            f.write("\n## " + self.args.optimized_metric + "\n %.5f" %
                    (best_metrics[self.args.optimized_metric]))
        tensorboard_export_dump(self.writer)
        print("Saved best accuracy checkpoint")

        return best_metrics[self.args.optimized_metric]

    def get_train_batch_plot_idx(self):
        self.train_batch_plot_idx += 1
        return self.train_batch_plot_idx - 1

    def get_test_batch_plot_idx(self):
        self.test_batch_plot_idx += 1
        return self.test_batch_plot_idx - 1
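
# --- Editor's sketch (not part of the original examples) --------------------
# A minimal, self-contained version of the AMP pattern the solver above uses
# (accumulate -> unscale -> clip -> step -> update); `model`, `optimizer`,
# `loader`, and `criterion` are assumed to be supplied by the caller.
import torch
from torch.cuda.amp import GradScaler, autocast

def train_amp_accumulate(model, optimizer, loader, criterion,
                         accumulation=4, max_norm=1.0, device="cuda"):
    scaler = GradScaler()
    optimizer.zero_grad()
    for i, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        with autocast():
            # normalize so the accumulated gradient matches a full batch
            loss = criterion(model(data), target) / accumulation
        scaler.scale(loss).backward()
        if (i + 1) % accumulation == 0:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)      # skips the step on inf/NaN gradients
            scaler.update()
            optimizer.zero_grad()
# -----------------------------------------------------------------------------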
Example #5
class Trainer():
    images_evaluated: int = 0
    accumulated_batches: int = 0
    accumulated_loss: float = 0.0
    accumulated_accuracy: float = 0.0

    # NOTE: mutable class attributes are shared across instances;
    # reset_epoch_stats(), called from __init__, rebinds them per instance.
    all_predictions = []
    all_labels = []

    def __init__(self, options: TrainerOptions):
        self.net = options.net
        self.dataloader = options.dataloader
        self.optimizer = options.optimizer
        self.criterion = options.criterion
        self.save_dir = options.save_dir
        self.freeze = options.freeze
        self.accumulate_over_n_batches = options.accumulate_over_n_batches
        self.distributed = options.distributed
        self.gpu_rank = options.gpu_rank
        self.n_gpus = options.n_gpus
        self.test_time_bn = options.test_time_bn
        self.dtype = options.dtype
        if self.distributed:
            self.config_distributed(self.n_gpus, self.gpu_rank)
        self.mixedprecision = options.mixedprecision
        if self.mixedprecision:
            self.grad_scaler = GradScaler(init_scale=8192, growth_interval=4)
        self.multilabel = options.multilabel
        self.regression = options.regression
        self.reset_epoch_stats()

    def config_distributed(self, n_gpus, gpu_rank=None):
        self.sync_networks_distributed_if_needed()
        self.n_gpus = torch.cuda.device_count() if n_gpus is None else n_gpus
        assert gpu_rank is not None
        self.gpu_rank = gpu_rank

    def sync_networks_distributed_if_needed(self, check=True):
        if self.distributed: self.sync_network_distributed(self.net, check)

    def sync_network_distributed(self, net, check=True):
        for _, param in net.named_parameters():
            dist.broadcast(param.data, 0)

        for mod in net.modules():
            if isinstance(mod, torch.nn.BatchNorm2d):
                dist.broadcast(mod.running_mean, 0)
                dist.broadcast(mod.running_var, 0)

    def prepare_network_for_training(self):
        torch.set_grad_enabled(True)
        self.optimizer.zero_grad()
        self.net.train()
        for mod in self.freeze:
            mod.eval()

    def prepare_network_for_evaluation(self):
        torch.set_grad_enabled(False)
        self.net.eval()
        self.prepare_batchnorm_for_evaluation(self.net)

    def prepare_batchnorm_for_evaluation(self, net):
        for mod in net.modules():
            if isinstance(mod, torch.nn.BatchNorm2d):
                if self.test_time_bn: mod.train()
                else: mod.eval()

    def reset_epoch_stats(self):
        self.accumulated_loss = 0
        self.accumulated_accuracy = 0
        self.batches_evaluated = 0
        self.images_evaluated = 0
        self.accumulated_batches = 0

        self.all_predictions = []
        self.all_labels = []

    def save_batch_stats(self, loss, accuracy, predictions, labels):
        self.accumulated_loss += float(loss) * len(labels)
        self.accumulated_accuracy += accuracy * len(labels)
        self.batches_evaluated += 1
        self.images_evaluated += len(labels)

        self.all_predictions.append(predictions)
        self.all_labels.append(
            labels.copy()
        )  # https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189 | fix RuntimeError: received 0 items of ancdata

    def stack_epoch_predictions(self):
        self.all_predictions, self.all_labels = self.epoch_predictions_and_labels(
            gather=True)

    def correct_loss_for_multigpu(self):
        self.accumulated_loss = 0.0
        self.accumulated_accuracy = 0.0
        for pred, label in zip(self.all_predictions, self.all_labels):
            self.accumulated_loss += float(
                self.criterion(pred[None], label[None]))
            self.accumulated_accuracy += self.accuracy_with_predictions(
                pred[None], label)
        self.images_evaluated = len(self.all_predictions)

    def epoch_predictions_and_labels(self, gather=False):
        preds, labels = [], []
        if len(self.all_predictions) > 0:
            preds = np.vstack(self.all_predictions)
            labels = np.vstack(self.all_labels)

            if self.distributed and gather:
                preds = list(self.gather(preds))
                labels = list(self.gather(labels))

            preds = torch.from_numpy(np.array(preds))
            labels = torch.from_numpy(
                np.array(labels).astype(self.all_labels[0].dtype))

            # reshape to correct shapes
            if len(self.all_labels[0].shape) == 1: labels = labels.flatten()
            # labels = labels.view(-1, self.all_labels[0].shape[0])
            if len(self.all_predictions[0].shape) == 1: preds = preds.flatten()
            # preds = preds.view(-1, self.all_predictions[0].shape[0])
            return preds.float(), labels
        else:
            return torch.FloatTensor(), torch.LongTensor()

    def gather(self, results):
        results = torch.tensor(results, dtype=torch.float32).cuda()
        tensor_list = [
            results.new_empty(results.shape) for _ in range(self.n_gpus)
        ]
        dist.all_gather(tensor_list, results)
        cpu_list = [tensor.cpu().numpy() for tensor in tensor_list]
        return np.concatenate(cpu_list, axis=0)

    def average_epoch_loss(self):
        if self.images_evaluated == 0: return -1
        return self.accumulated_loss / self.images_evaluated

    def average_epoch_accuracy(self):
        if self.images_evaluated == 0: return -1
        return self.accumulated_accuracy / self.images_evaluated

    def train_epoch(self, batch_callback) -> typing.Tuple[np.ndarray, np.ndarray]:
        self.sync_networks_distributed_if_needed()
        self.prepare_network_for_training()
        self.reset_epoch_stats()
        self.train_full_dataloader(batch_callback)
        self.stack_epoch_predictions()
        if self.distributed: self.correct_loss_for_multigpu()
        return self.all_predictions, self.all_labels

    def validation_epoch(self, batch_callback):
        self.prepare_network_for_evaluation()
        self.reset_epoch_stats()
        self.evaluate_full_dataloader(batch_callback)
        self.stack_epoch_predictions()
        if self.distributed: self.correct_loss_for_multigpu()
        return self.all_predictions, self.all_labels

    def train_full_dataloader(self, batch_callback):
        for x, y in self.dataloader:
            loss, accuracy, predictions = self.train_on_batch(x, y)
            self.save_batch_stats(loss, accuracy, predictions, y.cpu().numpy())
            batch_callback(self, self.batches_evaluated, loss, accuracy)

    def evaluate_full_dataloader(self, batch_callback):
        for x, y in self.dataloader:
            loss, accuracy, predictions = self.forward_batch(x, y)
            self.save_batch_stats(loss, accuracy, predictions, y.cpu().numpy())
            batch_callback(self, self.batches_evaluated, loss, accuracy)

    def forward_batch(self, x, y):
        if self.mixedprecision:
            with autocast():
                output, loss = self.forward_batch_with_loss(x, y)
        else:
            output, loss = self.forward_batch_with_loss(x, y)
        output = output.detach().cpu()
        accuracy = self.accuracy_with_predictions(output, y.cpu())
        # NOTE: removed a `del output` here, could cause memory issues
        return loss, accuracy, output.numpy()

    def forward_batch_with_loss(self, x, y):
        output = self.net.forward(x.cuda())
        label = y.cuda()
        loss = self.criterion(output, label)
        return output, loss

    def train_on_batch(self, x, y):
        loss, accuracy, predictions = self.forward_batch(x, y)
        full_loss = float(loss)
        loss = loss / self.accumulate_over_n_batches / self.n_gpus
        if self.mixedprecision: self.grad_scaler.scale(loss).backward()
        else: loss.backward()
        self.accumulated_batches += 1
        self.step_optimizer_if_needed()
        return full_loss, accuracy, predictions

    def step_optimizer_if_needed(self):
        if self.accumulated_batches == self.accumulate_over_n_batches:
            self.distribute_gradients_if_needed()
            if self.mixedprecision:
                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
                # prohibit scales larger than 65536, training crashes,
                # maybe due to gradient accumulation?
                if self.grad_scaler.get_scale() > 65536.0:
                    self.grad_scaler.update(new_scale=65536.0)
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.accumulated_batches = 0

    def distribute_gradients_if_needed(self):
        if self.distributed:
            for _, param in self.net.named_parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)

    def save_checkpoint(self, name, epoch, additional={}):
        state = {
            'checkpoint': epoch,
            'state_dict': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        state.update(additional)
        print('Saving', 'checkpoint_' + name + '_' + str(epoch) + '_network')
        try:
            torch.save(
                state,
                self.save_dir / pathlib.Path('checkpoint_' + name + '_' +
                                             str(epoch) + '_network'))
            torch.save(
                state,
                self.save_dir / pathlib.Path('checkpoint_' + name + '_last'))
        except Exception as e:
            print('WARNING: Network not stored', e)

    def checkpoint_available_for_name(self, name, epoch=-1):
        if epoch > -1:
            print(self.save_dir / pathlib.Path('checkpoint_' + name + '_' +
                                               str(epoch) + '_network'))
            print(
                os.path.isfile(self.save_dir /
                               pathlib.Path('checkpoint_' + name + '_' +
                                            str(epoch) + '_network')))
            return os.path.isfile(self.save_dir /
                                  pathlib.Path('checkpoint_' + name + '_' +
                                               str(epoch) + '_network'))
        else:
            return os.path.isfile(self.save_dir /
                                  pathlib.Path('checkpoint_' + name + '_last'))

    def load_network_checkpoint(self, name):
        state = torch.load(self.save_dir / pathlib.Path('checkpoint_' + name))
        self.load_state_dict(state)

    def load_checkpoint(self, name, epoch=-1):
        if epoch > -1:
            state = torch.load(self.save_dir /
                               pathlib.Path('checkpoint_' + name + '_' +
                                            str(epoch) + '_network'),
                               map_location=lambda storage, loc: storage)
        else:
            state = torch.load(self.save_dir /
                               pathlib.Path('checkpoint_' + name + '_last'),
                               map_location=lambda storage, loc: storage)
        return state

    def load_state_dict(self, state):
        try:
            self.optimizer.load_state_dict(state['optimizer'])
        except KeyError:
            print('WARNING: Optimizer not restored')
        self.net.load_state_dict(state['state_dict'])

    def load_checkpoint_if_available(self, name, epoch=-1):
        if self.checkpoint_available_for_name(name, epoch):
            state = self.load_checkpoint(name, epoch)
            self.load_state_dict(state)
            return True, state
        return False, None

    def accuracy_with_predictions(self, predictions, labels):
        if self.regression:
            return 0
        if self.multilabel:
            equal = np.equal(np.round(torch.sigmoid(predictions.float())),
                             labels.numpy() == 1)
            equal_c = np.sum(equal, axis=1)
            equal = (equal_c == labels.shape[1]).sum()
        elif predictions.shape[1] == 1:
            equal = np.equal(np.round(torch.sigmoid(predictions.float())),
                             labels)
        else:
            equal = np.equal(
                np.argmax(torch.softmax(predictions.float(), dim=1), axis=1),
                labels)
        return float(equal.sum()) / float(predictions.shape[0])
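
# --- Editor's sketch (not part of the original examples) --------------------
# step_optimizer_if_needed() above clamps GradScaler's scale at 65536 to avoid
# crashes it attributes to gradient accumulation. Isolated, the clamp is just:
from torch.cuda.amp import GradScaler

def clamp_scale(scaler, max_scale=65536.0):
    # call right after scaler.step(optimizer)
    scaler.update()
    if scaler.get_scale() > max_scale:
        scaler.update(new_scale=max_scale)  # update(new_scale=...) resets it
# -----------------------------------------------------------------------------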
Example #6
class TestGradientScalingAMP(unittest.TestCase):
    def setUp(self):
        self.x = torch.tensor([2.0]).cuda().half()
        weight = 3.0
        bias = 5.0
        self.error = 1.0
        self.target = torch.tensor([self.x * weight + bias + self.error
                                    ]).cuda()
        self.loss_fn = torch.nn.L1Loss()

        self.model = torch.nn.Linear(1, 1)
        self.model.weight.data = torch.tensor([[weight]])
        self.model.bias.data = torch.tensor([bias])
        self.model.cuda()
        self.params = list(self.model.parameters())

        self.namespace_dls = argparse.Namespace(
            optimizer="adam",
            lr=[0.1],
            adam_betas="(0.9, 0.999)",
            adam_eps=1e-8,
            weight_decay=0.0,
            threshold_loss_scale=1,
            min_loss_scale=1e-4,
        )
        self.scaler = GradScaler(
            init_scale=1,
            growth_interval=1,
        )

    def run_iter(self, model, params, optimizer):
        optimizer.zero_grad()
        with autocast():
            y = model(self.x)
            loss = self.loss_fn(y, self.target)
        self.scaler.scale(loss).backward()
        self.assertEqual(
            loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))

        self.scaler.unscale_(optimizer)
        grad_norm = optimizer.clip_grad_norm(0)
        self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)

        self.scaler.step(optimizer)
        self.scaler.update()
        self.assertEqual(
            model.weight,
            torch.tensor([[3.1]], device="cuda:0", requires_grad=True),
        )
        self.assertEqual(
            model.bias,
            torch.tensor([5.1], device="cuda:0", requires_grad=True),
        )
        self.assertEqual(self.scaler.get_scale(), 2.0)

    def test_automatic_mixed_precision(self):
        model = copy.deepcopy(self.model)
        params = list(model.parameters())
        optimizer = build_optimizer(self.namespace_dls, params)

        self.run_iter(model, params, optimizer)
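
# --- Editor's sketch (not part of the original examples) --------------------
# The test above leans on GradScaler's growth rule: after `growth_interval`
# consecutive steps without inf/NaN gradients, the scale is multiplied by the
# growth factor (2.0 by default). With init_scale=1 and growth_interval=1, a
# single clean step leaves get_scale() == 2.0, which is what run_iter asserts.
from torch.cuda.amp import GradScaler

scaler = GradScaler(init_scale=1, growth_interval=1)
assert scaler.get_scale() == 1.0
# after one successful scaler.step(optimizer) + scaler.update():
#   scaler.get_scale() == 2.0
# -----------------------------------------------------------------------------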
Example #7
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # delete args.output_dir if the flag is set and the directory exists
    if args.clear_output_dir and args.output_dir.exists():
        rmtree(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.checkpoint_dir = args.output_dir / 'checkpoints'
    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if args.cuda else "cpu")

    train_loader, val_loader, test_loader = get_loaders(args)

    summary = Summary(args)

    scaler = GradScaler(enabled=args.mixed_precision)
    args.output_logits = (args.loss in ['bce', 'binarycrossentropy']
                          and args.model != 'identity')

    model = get_model(args, summary)
    if args.weights_dir is not None:
        model = utils.load_weights(args, model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer,
                       step_size=5,
                       gamma=args.gamma,
                       verbose=args.verbose == 2)
    loss_function = utils.get_loss_function(name=args.loss)

    critic = None if args.critic is None else Critic(args, summary=summary)

    utils.save_args(args)

    args.global_step = 0
    for epoch in range(args.epochs):
        print(f'Epoch {epoch + 1:03d}/{args.epochs:03d}')
        start = time()
        train_results = train(args,
                              model=model,
                              data=train_loader,
                              optimizer=optimizer,
                              loss_function=loss_function,
                              scaler=scaler,
                              summary=summary,
                              epoch=epoch,
                              critic=critic)
        val_results = validate(args,
                               model=model,
                               data=val_loader,
                               loss_function=loss_function,
                               summary=summary,
                               epoch=epoch,
                               critic=critic)
        end = time()

        scheduler.step()

        summary.scalar('elapse', end - start, step=epoch, mode=0)
        summary.scalar('lr', scheduler.get_last_lr()[0], step=epoch, mode=0)
        summary.scalar('gradient_scale',
                       scaler.get_scale(),
                       step=epoch,
                       mode=0)

        print(f'Train\t\tLoss: {train_results["Loss"]:.04f}\n'
              f'Validation\tLoss: {val_results["Loss"]:.04f}\t'
              f'MAE: {val_results["MAE"]:.04f}\t'
              f'PSNR: {val_results["PSNR"]:.02f}\t'
              f'SSIM: {val_results["SSIM"]:.04f}\n')

    utils.save_model(args, model)

    test(args,
         model=model,
         data=test_loader,
         loss_function=loss_function,
         summary=summary,
         epoch=args.epochs,
         critic=critic)

    summary.close()
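
# --- Editor's sketch (not part of the original examples) --------------------
# main() above builds GradScaler(enabled=args.mixed_precision), so one loop
# serves full and mixed precision: when disabled, scale()/step()/update()
# degrade to pass-throughs around an ordinary optimizer step. A generic step:
from torch.cuda.amp import GradScaler, autocast

def train_step(model, optimizer, criterion, x, y, scaler, mixed_precision):
    optimizer.zero_grad()
    with autocast(enabled=mixed_precision):
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # no-op scaling when the scaler is disabled
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()
# -----------------------------------------------------------------------------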
Example #8
def run_epoch(model,
              optimizer,
              train_ldr,
              logger,
              debug_mode: bool,
              tbX_writer,
              iter_count: int,
              avg_loss: float,
              local_rank: int,
              loss_name: str,
              save_path: str,
              gcs_ckpt_handler,
              scaler: GradScaler = None) -> tuple:
    """
    Performs a forward and backward pass through the model
    Args:
        iter_count (int): count of iterations
        save_path (str): path to directory where model is saved
        gcs_ckpt_handler: facilitates saving files to google cloud storage
        scaler (GradScaler): gradient scaler to prevent gradient underflow when autocast
            uses float16 precision for the forward pass
    Returns:
        Tuple[int, float]: train state of # batch iterations and average loss
    """
    # booleans and constants for logging
    is_rank_0 = (torch.distributed.get_rank() == 0)
    use_log = (logger is not None and is_rank_0)
    log_modulus = 100  # limits certain logging function to report less frequently
    exp_w = 0.985  # exponential weight for exponential moving average loss
    avg_grad_norm = 0
    model_t, data_t = 0.0, 0.0
    end_t = time.time()

    # progress bar for rank_0 process
    tq = tqdm.tqdm(train_ldr) if is_rank_0 else train_ldr

    # counter for model checkpointing
    batch_counter = 0
    device = torch.device("cuda:" + str(local_rank))

    # if scaler is enabled, amp is being used
    use_amp = scaler.is_enabled()
    print(f"Amp is being used: {use_amp}")

    # training loop
    for batch in tq:
        if use_log:
            logger.info(
                f"train: ====== Iteration: {iter_count} in run_epoch =======")

        ##############  Mid-epoch checkpoint ###############
        if is_rank_0 \
        and batch_counter % (len(train_ldr) // gcs_ckpt_handler.chkpt_per_epoch) == 0 \
        and batch_counter != 0:
            preproc = train_ldr.dataset.preproc
            save(model.module, preproc, save_path, tag='ckpt')
            gcs_ckpt_handler.upload_to_gcs("ckpt_model_state_dict.pth")
            gcs_ckpt_handler.upload_to_gcs("ckpt_preproc.pyc")
            # save the run_state
            ckpt_state_path = os.path.join(save_path, "ckpt_run_state.pickle")
            write_pickle(ckpt_state_path,
                         {'run_state': (iter_count, avg_loss)})
            gcs_ckpt_handler.upload_to_gcs("ckpt_run_state.pickle")
            # checkpoint tensorboard
            gcs_ckpt_handler.upload_tensorboard_ckpt()

        batch_counter += 1
        ####################################################

        # convert the temporary generator batch to a permanent list
        batch = list(batch)

        # save the batch information
        if use_log:
            if debug_mode:
                save_batch_log_stats(batch, logger)
                log_batchnorm_mean_std(model.module.state_dict(), logger)

        start_t = time.time()
        optimizer.zero_grad(
            set_to_none=True)  # set grads to None for modest perf improvement

        # autocasts to lower precision if amp is used; otherwise it's a no-op
        with autocast(enabled=use_amp):
            # unpack the batch
            inputs, labels, input_lens, label_lens = model.module.collate(
                *batch)
            inputs = inputs.cuda()  #.to(device) #.cuda(local_rank)
            out, rnn_args = model(inputs, softmax=False)

            # use the loss function defined in `loss_name`
            if loss_name == "native":
                loss = native_loss(out, labels, input_lens, label_lens,
                                   model.module.blank)
            elif loss_name == "awni":
                loss = awni_loss(out, labels, input_lens, label_lens,
                                 model.module.blank)
            elif loss_name == "naren":
                loss = naren_loss(out, labels, input_lens, label_lens,
                                  model.module.blank)

        # backward pass
        loss = loss.cuda()  # amp needs the loss to be on cuda
        scaler.scale(loss).backward()

        if use_log:
            if debug_mode:
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))
                log_param_grad_norms(model.module.named_parameters(), logger)

        # gradient clipping and optimizer step, scaling disabled if amp is not used
        scaler.unscale_(optimizer)
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 200).item()
        scaler.step(optimizer)
        scaler.update()

        # logging in rank_0 process
        if is_rank_0:
            # calculate timers
            prev_end_t = end_t
            end_t = time.time()
            model_t += end_t - start_t
            data_t += start_t - prev_end_t

            # convert grad_norm and loss tensors to python scalars for the
            # weighted averages below
            # TODO: needed with pytorch 0.4, may not be necessary anymore
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()
            if isinstance(loss, torch.Tensor):
                loss = loss.item()

            # calculating the weighted average of loss and grad_norm
            if iter_count == 0:
                avg_loss = loss
                avg_grad_norm = grad_norm
            else:
                avg_loss = exp_w * avg_loss + (1 - exp_w) * loss
                avg_grad_norm = exp_w * avg_grad_norm + (1 - exp_w) * grad_norm

            # writing to the tensorboard log files
            tbX_writer.add_scalars('train/loss', {"loss": loss}, iter_count)
            tbX_writer.add_scalars('train/loss', {"avg_loss": avg_loss},
                                   iter_count)

            # adding this to suppress a tbX WARNING about inf values
            # TODO, this may or may not be a good idea as it masks inf in tensorboard
            if grad_norm == float('inf') or math.isnan(grad_norm):
                tbX_grad_norm = 1
            else:
                tbX_grad_norm = grad_norm
            tbX_writer.add_scalars('train/grad', {"grad_norm": tbX_grad_norm},
                                   iter_count)

            # progress bar update
            tq.set_postfix(it=iter_count,
                           grd_nrm=grad_norm,
                           lss=loss,
                           lss_av=avg_loss,
                           t_mdl=model_t,
                           t_data=data_t,
                           scl=scaler.get_scale())
            if use_log:
                logger.info(f'train: loss is inf: {loss == float("inf")}')
                logger.info(
                    f"train: iter={iter_count}, loss={round(loss,3)}, grad_norm={round(grad_norm,3)}"
                )

            if iter_count % log_modulus == 0:
                if use_log: log_cpu_mem_disk_usage(logger)

        # checks for nan gradients
        if check_nan_params_grads(model.module.parameters()):
            print("\n~~~ NaN value detected in gradients or parameters ~~~\n")
            if use_log:
                logger.error(
                    f"train: labels: {[labels]}, label_lens: {label_lens} state_dict: {model.module.state_dict()}"
                )
                log_model_grads(model.module.named_parameters(), logger)
                save_batch_log_stats(batch, logger)
                log_param_grad_norms(model.module.named_parameters(), logger)
                plot_grad_flow_bar(model.module.named_parameters(),
                                   get_logger_filename(logger))

            #debug_mode = True
            #torch.autograd.set_detect_anomaly(True)

        iter_count += 1

    return iter_count, avg_loss
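
# --- Editor's sketch (not part of the original examples) --------------------
# run_epoch() smooths its logged loss and grad norm with an exponential moving
# average, avg = w * avg + (1 - w) * x, seeded with the first observation:
def ema(values, w=0.985):
    avg = None
    for x in values:
        avg = x if avg is None else w * avg + (1 - w) * x
        yield avg

# e.g. list(ema([4.0, 2.0, 2.0])) -> [4.0, 3.97, 3.94045]
# -----------------------------------------------------------------------------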
Example #9
class Trainer():
    def __init__(self, cfg, writer, img_writer, logger, run_id):
        # Copy shared config fields
        if "monodepth_options" in cfg:
            cfg["data"].update(cfg["monodepth_options"])
            cfg["model"].update(cfg["monodepth_options"])
            cfg["training"]["monodepth_loss"].update(cfg["monodepth_options"])
        if "generated_depth_dir" in cfg["data"]:
            dataset_name = f"{cfg['data']['dataset']}_" \
                           f"{cfg['data']['width']}x{cfg['data']['height']}"
            depth_teacher = cfg["data"].get("depth_teacher", None)
            assert not (depth_teacher and cfg['model'].get('depth_estimator_weights') is not None)
            if depth_teacher is not None:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + depth_teacher + "/"
            else:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + cfg['model']['depth_estimator_weights'] + "/"

        # Setup seeds
        setup_seeds(cfg.get("seed", 1337))
        if cfg["data"]["dataset_seed"] == "same":
            cfg["data"]["dataset_seed"] = cfg["seed"]

        # Setup device
        torch.backends.cudnn.benchmark = cfg["training"].get("benchmark", True)
        self.cfg = cfg
        self.writer = writer
        self.img_writer = img_writer
        self.logger = logger
        self.run_id = run_id
        self.mIoU = 0
        self.fwAcc = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.setup_segmentation_unlabeled()

        self.unlabeled_require_depth = (self.cfg["training"]["unlabeled_segmentation"] is not None and
                                        (self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depth" or
                                         self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depthcomp" or
                                         self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] == "depthhist"))

        # Prepare depth estimates
        do_precalculate_depth = self.cfg["training"]["segmentation_lambda"] != 0 and self.unlabeled_require_depth and \
                                self.cfg['model']['segmentation_name'] != 'mtl_pad'
        use_depth_teacher = cfg["data"].get("depth_teacher", None) is not None
        if do_precalculate_depth or use_depth_teacher:
            assert not (do_precalculate_depth and use_depth_teacher)
            if not self.cfg["training"].get("disable_depth_estimator", False):
                print("Prepare depth estimates")
                depth_estimator = DepthEstimator(cfg)
                depth_estimator.prepare_depth_estimates()
                del depth_estimator
                torch.cuda.empty_cache()
        else:
            self.cfg["data"]["generated_depth_dir"] = None

        # Setup Dataloader
        load_labels, load_sequence = True, True
        if self.cfg["training"]["monodepth_lambda"] == 0:
            load_sequence = False
        if self.cfg["training"]["segmentation_lambda"] == 0:
            load_labels = False
        train_data_cfg = deepcopy(self.cfg["data"])
        if not do_precalculate_depth and not use_depth_teacher:
            train_data_cfg["generated_depth_dir"] = None
        self.train_loader = build_loader(train_data_cfg, "train", load_labels=load_labels, load_sequence=load_sequence)
        if self.cfg["training"].get("minimize_entropy_unlabeled", False) or self.enable_unlabled_segmentation:
            unlabeled_segmentation_cfg = deepcopy(self.cfg["data"])
            if not self.only_unlabeled and self.mix_use_gt:
                unlabeled_segmentation_cfg["load_onehot"] = True
            if self.only_unlabeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": False})
            elif self.only_labeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": False, "load_labeled": True})
            else:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": True})
            if self.mix_video:
                assert not self.mix_use_gt and not self.only_labeled and not self.only_unlabeled, \
                    "Video sample indices are not compatible with non-video indices."
                unlabeled_segmentation_cfg.update({"only_sequences_with_segmentation": not self.mix_video,
                                                   "restrict_to_subset": None})
            self.unlabeled_loader = build_loader(unlabeled_segmentation_cfg, "train",
                                                 load_labels=load_labels if not self.mix_video else False,
                                                 load_sequence=load_sequence)
        else:
            self.unlabeled_loader = None
        self.val_loader = build_loader(self.cfg["data"], "val", load_labels=load_labels,
                                       load_sequence=load_sequence)
        self.n_classes = self.train_loader.n_classes

        # monodepth dataloader settings use drop_last=True and shuffle=True even for val
        self.train_data_loader = data.DataLoader(
            self.train_loader,
            batch_size=self.cfg["training"]["batch_size"],
            num_workers=self.cfg["training"]["n_workers"],
            shuffle=self.cfg["data"]["shuffle_trainset"],
            pin_memory=True,
            # Setting this to False causes a crash at the end of the epoch
            drop_last=True,
        )
        if self.unlabeled_loader is not None:
            self.unlabeled_data_loader = infinite_iterator(data.DataLoader(
                self.unlabeled_loader,
                batch_size=self.cfg["training"]["batch_size"],
                num_workers=self.cfg["training"]["n_workers"],
                shuffle=self.cfg["data"]["shuffle_trainset"],
                pin_memory=True,
                # Setting this to False causes a crash at the end of the epoch
                drop_last=True,
            ))

        self.val_batch_size = self.cfg["training"]["val_batch_size"]
        self.val_data_loader = data.DataLoader(
            self.val_loader,
            batch_size=self.val_batch_size,
            num_workers=self.cfg["training"]["n_workers"],
            pin_memory=True,
            # If using a dataset with odd number of samples (CamVid), the memory consumption suddenly increases for the
            # last batch. This can be circumvented by dropping the last batch. Only do that if it is necessary for your
            # system as it will result in an incomplete validation set.
            # drop_last=True,
        )

        # Setup Model
        self.model = get_model(cfg["model"], self.n_classes).to(self.device)
        # print(self.model)
        assert not (self.enable_unlabled_segmentation and self.cfg["training"]["save_monodepth_ema"])
        if self.enable_unlabled_segmentation and not self.only_labeled:
            print("Create segmentation ema model.")
            self.ema_model = self.create_ema_model(self.model).to(self.device)
        elif self.cfg["training"]["save_monodepth_ema"]:
            print("Create depth ema model.")
            # TODO: Try to remove unnecessary components and fit into gpu for better performance
            self.ema_model = self.create_ema_model(self.model)  # .to(self.device)
        else:
            self.ema_model = None

        # Setup optimizer, lr_scheduler and loss function
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if
                            k not in ["name", "backbone_lr", "pose_lr", "depth_lr", "segmentation_lr"]}
        train_params = get_train_params(self.model, self.cfg)
        self.optimizer = optimizer_cls(train_params, **optimizer_params)

        self.scheduler = get_scheduler(self.optimizer, self.cfg["training"]["lr_schedule"])

        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler(enabled=self.cfg["training"]["amp"])

        self.loss_fn = get_segmentation_loss_function(self.cfg)
        self.monodepth_loss_calculator_train = get_monodepth_loss(self.cfg, is_train=True)
        self.monodepth_loss_calculator_val = get_monodepth_loss(self.cfg, is_train=False, batch_size=self.val_batch_size)

        if cfg["training"]["early_stopping"] is None:
            logger.info("Using No Early Stopping")
            self.earlyStopping = None
        else:
            self.earlyStopping = EarlyStopping(
                patience=round(cfg["training"]["early_stopping"]["patience"] / cfg["training"]["val_interval"]),
                min_delta=cfg["training"]["early_stopping"]["min_delta"],
                cumulative_delta=cfg["training"]["early_stopping"]["cum_delta"],
                logger=logger
            )

    def extract_monodepth_ema_params(self, model, ema_model):
        model_names = ["depth"]
        if not self.cfg["model"]["freeze_backbone"]:
            model_names.append("encoder")

        return extract_ema_params(model, ema_model, model_names)

    def extract_pad_ema_params(self, model, ema_model):
        model_names = ["depth", "encoder", "mtl_decoder"]
        return extract_ema_params(model, ema_model, model_names)

    def create_ema_model(self, model):
        ema_cfg = deepcopy(self.cfg["model"])
        ema_cfg["disable_pose"] = True
        ema_model = get_model(ema_cfg, self.n_classes)
        if self.cfg["training"]["save_monodepth_ema"]:
            mp, mcp = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            mp, mcp = self.extract_pad_ema_params(model, ema_model)
        else:
            mp, mcp = list(model.parameters()), list(ema_model.parameters())
        for param in mcp:
            param.detach_()
        assert len(mp) == len(mcp), f"len(mp)={len(mp)}; len(mcp)={len(mcp)}"
        n = len(mp)
        for i in range(0, n):
            mcp[i].data[:] = mp[i].to(mcp[i].device, non_blocking=True).data[:].clone()
        return ema_model

    def update_ema_variables(self, ema_model, model, alpha_teacher, iteration):
        if self.cfg["training"]["save_monodepth_ema"]:
            model_params, ema_params = self.extract_monodepth_ema_params(model, ema_model)
        elif self.cfg['model']['segmentation_name'] == 'mtl_pad':
            model_params, ema_params = self.extract_pad_ema_params(model, ema_model)
        else:
            model_params, ema_params = model.parameters(), ema_model.parameters()
        # Use the "true" average until the exponential average is more correct
        alpha_teacher = min(1 - 1 / (iteration + 1), alpha_teacher)
        for ema_param, param in zip(ema_params, model_params):
            ema_param.data[:] = alpha_teacher * ema_param[:].data[:] + \
                                (1 - alpha_teacher) * param.to(ema_param.device, non_blocking=True)[:].data[:]
        return ema_model
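
    # --- Editor's sketch (hypothetical helper, not in the original class) ---
    # update_ema_variables() warms the teacher in with
    # alpha = min(1 - 1/(iteration + 1), alpha_teacher): the first iterations
    # give a plain running average (alpha = 0, 1/2, 2/3, ...) before settling
    # at alpha_teacher. Per parameter, the update above is equivalent to:
    @staticmethod
    @torch.no_grad()
    def _ema_update_sketch(ema_params, params, alpha_teacher, iteration):
        alpha = min(1 - 1 / (iteration + 1), alpha_teacher)
        for ema_p, p in zip(ema_params, params):
            ema_p.mul_(alpha).add_(p.to(ema_p.device), alpha=1 - alpha)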

    def save_resume(self, step):
        if self.ema_model is not None:
            raise NotImplementedError("ema model not supported")
        state = {
            "epoch": step + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_iou": self.best_iou,
        }
        save_path = os.path.join(
            self.writer.file_writer.get_logdir(),
            "best_model.pkl"
        )
        torch.save(state, save_path)
        return save_path

    def save_monodepth_models(self):
        if self.cfg["training"]["save_monodepth_ema"]:
            print("Save ema monodepth models.")
            assert self.ema_model is not None
            model_to_save = self.ema_model
        else:
            model_to_save = self.model
        models = ["depth", "pose_encoder", "pose"]
        if not self.cfg["model"]["freeze_backbone"]:
            models.append("encoder")
        for model_name in models:
            save_path = os.path.join(self.writer.file_writer.get_logdir(), "{}.pth".format(model_name))
            to_save = model_to_save.models[model_name].state_dict()
            torch.save(to_save, save_path)

    def load_resume(self, strict=True, load_model_only=False):
        if os.path.isfile(self.cfg["training"]["resume"]):
            self.logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(self.cfg["training"]["resume"])
            )
            checkpoint = torch.load(self.cfg["training"]["resume"])
            self.model.load_state_dict(checkpoint["model_state"], strict=strict)
            if not load_model_only:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
            self.start_iter = checkpoint["epoch"]
            self.best_iou = checkpoint["best_iou"]
            self.logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    self.cfg["training"]["resume"], checkpoint["epoch"]
                )
            )
        else:
            self.logger.info("No checkpoint found at '{}'".format(self.cfg["training"]["resume"]))

    def tensorboard_training_images(self):
        num_saved = 0
        if self.cfg["training"]["n_tensorboard_trainimgs"] == 0:
            return
        for inputs in self.train_data_loader:
            images = inputs[("color_aug", 0, 0)]
            labels = inputs["lbl"]
            for img, label in zip(images.numpy(), labels.numpy()):
                if num_saved < self.cfg["training"]["n_tensorboard_trainimgs"]:
                    num_saved += 1
                    self.img_writer.add_image(
                        "trainset_{}/{}_0image".format(self.run_id.replace('/', '_'), num_saved), img,
                        global_step=0)
                    colored_image = self.val_loader.decode_segmap_tocolor(label)
                    self.img_writer.add_image(
                        "trainset_{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), num_saved),
                        colored_image,
                        global_step=0, dataformats="HWC")
            if num_saved >= self.cfg["training"]["n_tensorboard_trainimgs"]:
                break

    def _train_batchnorm(self, model, train, only_encoder=False):
        if only_encoder:
            modules = model.models["encoder"].modules()
        else:
            modules = model.modules()
        for m in modules:
            if isinstance(m, nn.BatchNorm2d):
                m.train(train)

    def train_step(self, inputs, step):
        self.model.train()
        if self.ema_model is not None:
            self.ema_model.train()

        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device, non_blocking=True)

        if self.enable_unlabled_segmentation:
            unlabeled_inputs = self.unlabeled_data_loader.__next__()
            for k in unlabeled_inputs.keys():
                if "color_aug" in k or "K" in k or "inv_K" in k or "color" in k or k in ["onehot_lbl", "pseudo_depth"]:
                    # print(f"Move {k} to gpu.")
                    unlabeled_inputs[k] = unlabeled_inputs[k].to(self.device, non_blocking=True)

        self.optimizer.zero_grad()
        segmentation_loss = torch.tensor(0)
        segmentation_total_loss = torch.tensor(0)
        mono_loss = torch.tensor(0)
        feat_dist_loss = torch.tensor(0)
        mono_total_loss = torch.tensor(0)

        if self.cfg["model"].get("freeze_backbone_bn", False):
            self._train_batchnorm(self.model, False, only_encoder=True)

        with autocast(enabled=self.cfg["training"]["amp"]):
            outputs = self.model(inputs)

        # Train monodepth
        if self.cfg["training"]["monodepth_lambda"] > 0:
            for k, v in outputs.items():
                if "depth" in k or "cam_T_cam" in k:
                    outputs[k] = v.to(torch.float32)
            self.monodepth_loss_calculator_train.generate_images_pred(inputs, outputs)
            mono_losses = self.monodepth_loss_calculator_train.compute_losses(inputs, outputs)
            mono_lambda = self.cfg["training"]["monodepth_lambda"]
            mono_loss = mono_lambda * mono_losses["loss"]
            feat_dist_lambda = self.cfg["training"]["feat_dist_lambda"]
            if feat_dist_lambda > 0:
                feat_dist = torch.dist(outputs["encoder_features"], outputs["imnet_features"], p=2)
                feat_dist_loss = feat_dist_lambda * feat_dist
            mono_total_loss = mono_loss + feat_dist_loss

            self.scaler.scale(mono_total_loss).backward(retain_graph=True)

        # Train depth on pseudo-labels
        if self.cfg["training"].get("pseudo_depth_lambda", 0) > 0:
            # Crop away bottom of image with own car
            with torch.no_grad():
                depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
            pseudo_depth_loss = berhu(outputs["disp", 0], inputs["pseudo_depth"], depth_loss_mask)
            pseudo_depth_loss *= self.cfg["training"]["pseudo_depth_lambda"]
            self.scaler.scale(pseudo_depth_loss).backward(retain_graph=True)
        else:
            pseudo_depth_loss = torch.tensor(0)

        # Train segmentation
        if self.cfg["training"]["segmentation_lambda"] > 0:
            with autocast(enabled=self.cfg["training"]["amp"]):
                segmentation_loss = self.loss_fn(input=outputs["semantics"], target=inputs["lbl"])
                if "intermediate_semantics" in outputs:
                    segmentation_loss += self.loss_fn(input=outputs["intermediate_semantics"],
                                                      target=inputs["lbl"])
                    segmentation_loss /= 2
                segmentation_loss *= self.cfg["training"]["segmentation_lambda"]
                segmentation_total_loss = segmentation_loss
            self.scaler.scale(segmentation_total_loss).backward()
            if self.enable_unlabled_segmentation:
                unlabeled_loss, unlabeled_mono_loss = self.train_step_segmentation_unlabeled(unlabeled_inputs, step)
                segmentation_total_loss += unlabeled_loss
                mono_total_loss += unlabeled_mono_loss

        if self.cfg["training"].get("clip_grad_norm") is not None:
            # Unscales the gradients of optimizer's assigned params in-place
            self.scaler.unscale_(self.optimizer)
            # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
            if self.cfg["training"].get("disable_depth_grad_clip", False):
                torch.nn.utils.clip_grad_norm_(get_params(self.model, ["encoder", "segmentation"]),
                                               self.cfg["training"]["clip_grad_norm"])
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg["training"]["clip_grad_norm"])
        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(metrics=self.mIoU)
        else:
            self.scheduler.step()

        # update Mean teacher network
        if self.ema_model is not None:
            self.ema_model = self.update_ema_variables(ema_model=self.ema_model, model=self.model,
                                                       alpha_teacher=0.99, iteration=step)

        total_loss = segmentation_total_loss + mono_total_loss + pseudo_depth_loss

        return {
            'segmentation_loss': segmentation_loss.detach(),
            'mono_loss': mono_loss.detach(),
            'pseudo_depth_loss': pseudo_depth_loss.detach(),
            'feat_dist_loss': feat_dist_loss.detach(),
            'segmentation_total_loss': segmentation_total_loss.detach(),
            'mono_total_loss': mono_total_loss.detach(),
            'total_loss': total_loss.detach()
        }

    def setup_segmentation_unlabeled(self):
        if self.cfg["training"].get("unlabeled_segmentation", None) is None:
            self.enable_unlabled_segmentation = False
            return
        unlabeled_cfg = self.cfg["training"]["unlabeled_segmentation"]
        self.enable_unlabled_segmentation = True
        self.consistency_weight = unlabeled_cfg["consistency_weight"]
        self.mix_mask = unlabeled_cfg.get("mix_mask", None)
        self.unlabeled_color_jitter = unlabeled_cfg.get("color_jitter")
        self.unlabeled_blur = unlabeled_cfg.get("blur")
        self.only_unlabeled = unlabeled_cfg.get("only_unlabeled", True)
        self.only_labeled = unlabeled_cfg.get("only_labeled", False)
        self.mix_video = unlabeled_cfg.get("mix_video", False)
        assert not (self.only_unlabeled and self.only_labeled)
        self.mix_use_gt = unlabeled_cfg.get("mix_use_gt", False)
        self.unlabeled_debug_imgs = unlabeled_cfg.get("debug_images", False)
        self.depthcomp_margin = unlabeled_cfg["depthcomp_margin"]
        self.depthcomp_foreground_threshold = unlabeled_cfg["depthcomp_foreground_threshold"]
        self.unlabeled_backward_first_pseudo_label = unlabeled_cfg["backward_first_pseudo_label"]
        self.depthmix_online_depth = unlabeled_cfg.get("depthmix_online_depth", False)

    def generate_mix_mask(self, mode, argmax_u_w, unlabeled_imgs, depths):
        if mode == "class":
            for image_i in range(self.cfg["training"]["batch_size"]):
                classes = torch.unique(argmax_u_w[image_i])
                classes = classes[classes != 250]
                nclasses = classes.shape[0]
                classes = (classes[torch.Tensor(
                    np.random.choice(nclasses, int((nclasses - nclasses % 2) / 2), replace=False)).long()]).cuda()
                if image_i == 0:
                    MixMask = transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_class_mask(argmax_u_w[image_i], classes).unsqueeze(0).cuda()))
        elif mode == "depthcomp":
            assert self.cfg["training"]["batch_size"] == 2
            for image_i, other_image_i in [(0, 1), (1, 0)]:
                own_disp = depths[image_i]
                other_disp = depths[other_image_i]
                # The margin avoids mixing too much road with the same depth
                foreground_mask = torch.ge(own_disp, other_disp - self.depthcomp_margin).long()
                # Avoid hiding the other image's real background with our own slightly closer background
                if isinstance(self.depthcomp_foreground_threshold, tuple) or isinstance(
                        self.depthcomp_foreground_threshold, list):
                    ft_l, ft_u = self.depthcomp_foreground_threshold
                    assert ft_u > ft_l
                    ft = torch.rand(1, device=own_disp.device) * (ft_u - ft_l) + ft_l
                else:
                    ft = self.depthcomp_foreground_threshold
                foreground_mask *= torch.ge(own_disp, ft).long()
                if image_i == 0:
                    MixMask = foreground_mask
                else:
                    MixMask = torch.cat((MixMask, foreground_mask))
        elif mode == "depth":
            for image_i in range(self.cfg["training"]["batch_size"]):
                generated_depth = depths[image_i]
                min_depth = 0.1
                max_depth = 0.4
                depth_threshold = torch.rand(1, device=depths.device) * (max_depth - min_depth) + min_depth
                if image_i == 0:
                    MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()))
        elif mode == "depthhist":
            for image_i in range(self.cfg["training"]["batch_size"]):
                generated_depth = depths[image_i]
                hist, bin_edges = np.histogram(torch.log(1 + generated_depth).flatten().cpu(), bins=100, density=True)
                # Fall back to the histogram edges in case no bin passes the tests below
                min_depth = torch.tensor([bin_edges[0]])
                max_depth = torch.tensor([bin_edges[-1]])
                # Exclude the first bin as it sometimes has a meaningless peak
                for v, e in zip(np.flip(hist)[1:], np.flip(bin_edges)[1:]):
                    if v > 1.5:
                        max_depth = torch.tensor([e])
                        break

                hist = np.cumsum(hist) / np.sum(hist)
                for v, e in zip(hist, bin_edges):
                    if v > 0.4:
                        min_depth = torch.tensor([e])
                        break
                depth_threshold = torch.rand(1) * (max_depth - min_depth) + min_depth
                if image_i == 0:
                    MixMask = transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()
                else:
                    MixMask = torch.cat(
                        (MixMask, transformmasks.generate_depth_mask(generated_depth, depth_threshold).cuda()))
        elif mode is None:
            MixMask = torch.ones((unlabeled_imgs.shape[0], *unlabeled_imgs.shape[2:]), device=self.device)
        else:
            raise NotImplementedError(f"Unknown mix_mask {mode}")

        return MixMask
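
    # Hedged sketch: `transformmasks.generate_class_mask` and
    # `transformmasks.generate_depth_mask` are external helpers not shown here.
    # Judging from how they are called above, plausible implementations
    # (assumptions, not the project's actual code) are:
    @staticmethod
    def _generate_class_mask_sketch(pred, classes):
        # 1 where the prediction belongs to one of the sampled classes, else 0
        return pred.eq(classes.view(-1, 1, 1)).any(dim=0).long()

    @staticmethod
    def _generate_depth_mask_sketch(depth, threshold):
        # 1 where the normalized disparity exceeds the sampled threshold (i.e. the pixel is close)
        return depth.ge(threshold).long()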

    def calc_pseudo_label_loss(self, teacher_softmax, student_logits):
        max_probs, pseudo_label = torch.max(teacher_softmax, dim=1)
        pseudo_label[max_probs == 0] = self.unlabeled_loader.ignore_index
        # Weight = fraction of pixels where the teacher is confident (p >= 0.968)
        unlabeled_weight = torch.sum(max_probs.ge(0.968)).item() / np.prod(pseudo_label.shape)
        pixelWiseWeight = unlabeled_weight * torch.ones(max_probs.shape, device=self.device)
        L_u = self.consistency_weight * cross_entropy2d(input=student_logits, target=pseudo_label,
                                                        pixel_weights=pixelWiseWeight)
        return L_u, pseudo_label
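
    # Hedged sketch: `cross_entropy2d` with `pixel_weights` is defined outside
    # this listing; a plausible (assumed) form weights the unreduced CE loss:
    @staticmethod
    def _cross_entropy2d_sketch(input, target, pixel_weights, ignore_index=250):
        # Per-pixel CE, weighted, averaged over the non-ignored pixels
        ce = F.cross_entropy(input, target, ignore_index=ignore_index, reduction="none")
        return (pixel_weights * ce).sum() / (target != ignore_index).sum().clamp_min(1)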

    def train_step_segmentation_unlabeled(self, unlabeled_inputs, step):
        def strongTransform(parameters, data=None, target=None):
            assert ((data is not None) or (target is not None))
            data, target = transformsgpu.mix(mask=parameters["Mix"], data=data, target=target)
            data, target = transformsgpu.color_jitter(jitter=parameters["ColorJitter"], data=data, target=target)
            data, target = transformsgpu.gaussian_blur(blur=parameters["GaussianBlur"], data=data, target=target)
            return data, target

        unlabeled_imgs = unlabeled_inputs[("color_aug", 0, 0)]

        # First Step: Run teacher to generate pseudo labels
        self.ema_model.use_pose_net = False
        logits_u_w = self.ema_model(unlabeled_inputs)["semantics"]
        softmax_u_w = torch.softmax(logits_u_w.detach(), dim=1)
        if self.mix_use_gt:
            with torch.no_grad():
                for i in range(unlabeled_imgs.shape[0]):
                    # .data is necessary to access truth value of tensor
                    if unlabeled_inputs["is_labeled"][i].data:
                        softmax_u_w[i] = unlabeled_inputs["onehot_lbl"][i]
        _, argmax_u_w = torch.max(softmax_u_w, dim=1)

        # Second Step: Run student network on unaugmented data to generate depth for DepthMix, calculate monodepth loss,
        # and unaugmented segmentation pseudo label loss
        mono_loss = 0
        L_1 = 0
        if self.depthmix_online_depth:
            outputs_1 = self.model(unlabeled_inputs)
            if self.cfg["training"]["monodepth_lambda"] > 0:
                self.monodepth_loss_calculator_train.generate_images_pred(unlabeled_inputs, outputs_1)
                mono_losses = self.monodepth_loss_calculator_train.compute_losses(unlabeled_inputs, outputs_1)
                mono_lambda = self.cfg["training"]["monodepth_lambda"]
                mono_loss = mono_lambda * mono_losses["loss"]
                self.scaler.scale(mono_loss).backward(retain_graph=self.unlabeled_backward_first_pseudo_label)
                depths = outputs_1[("disp", 0)].detach()
                for j in range(depths.shape[0]):
                    dmin = torch.min(depths[j])
                    dmax = torch.max(depths[j])
                    depths[j] = torch.clamp(depths[j], dmin, dmax)
                    depths[j] = (depths[j] - dmin) / (dmax - dmin)
            else:
                depths = unlabeled_inputs["pseudo_depth"]
            if self.unlabeled_backward_first_pseudo_label:
                logits_1 = outputs_1["semantics"]
                L_1, _ = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w, student_logits=logits_1)
                self.scaler.scale(L_1).backward()
        elif "pseudo_depth" in unlabeled_inputs:
            depths = unlabeled_inputs["pseudo_depth"]
        else:
            depths = [None] * unlabeled_imgs.shape[0]

        # Third Step: Run Mix
        MixMask = self.generate_mix_mask(self.mix_mask, argmax_u_w, unlabeled_imgs, depths)

        strong_parameters = {"Mix": MixMask}
        if self.unlabeled_color_jitter:
            strong_parameters["ColorJitter"] = random.uniform(0, 1)
        else:
            strong_parameters["ColorJitter"] = 0
        if self.unlabeled_blur:
            strong_parameters["GaussianBlur"] = random.uniform(0, 1)
        else:
            strong_parameters["GaussianBlur"] = 0

        inputs_u_s, _ = strongTransform(strong_parameters, data=unlabeled_imgs)
        unlabeled_inputs[("color_aug", 0, 0)] = inputs_u_s
        outputs = self.model(unlabeled_inputs)
        logits_u_s = outputs["semantics"]

        softmax_u_w_mixed, _ = strongTransform(strong_parameters, data=softmax_u_w)
        L_2, pseudo_label = self.calc_pseudo_label_loss(teacher_softmax=softmax_u_w_mixed, student_logits=logits_u_s)
        self.scaler.scale(L_2).backward()

        # Periodically dump debug images of the mixed inputs, masks and pseudo labels
        if (step + 1) % self.cfg["training"]["print_interval"] == 0:
            for j, (f, img, ps_lab, mask, d) in enumerate(
                    zip(unlabeled_inputs["filename"], inputs_u_s, pseudo_label, MixMask, depths)):
                fn = f"{self.cfg['training']['log_path']}/class_mix_debug/{step}_{j}_img.jpg"
                os.makedirs(os.path.dirname(fn), exist_ok=True)
                rows, cols = 2, 2
                fig, axs = plt.subplots(rows, cols, sharex='col', sharey='row',
                                        gridspec_kw={'hspace': 0, 'wspace': 0},
                                        figsize=(4 * cols, 4 * rows))
                axs[0][0].imshow(img.permute(1, 2, 0).cpu().numpy())
                axs[0][1].imshow(mask.float().cpu().numpy(), cmap="gray")
                if d is not None:
                    axs[1][1].imshow(d[0].cpu().numpy(), cmap="plasma")
                axs[1][0].imshow(self.val_loader.decode_segmap_tocolor(ps_lab.cpu().numpy()))
                for ax in axs.flat:
                    ax.axis("off")
                plt.savefig(fn)
                plt.close()

        return L_2 + L_1, mono_loss

    def train(self):
        self.start_iter = 0
        self.best_iou = -100.0
        if self.cfg["training"]["resume"] is not None:
            self.load_resume()
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.cfg["training"]["optimizer"]["lr"]

        train_loss_meter = AverageMeterDict()
        time_meter = AverageMeter()

        step = self.start_iter
        flag = True

        self.tensorboard_training_images()

        start_ts = time.time()
        while step <= self.cfg["training"]["train_iters"] and flag:
            for inputs in self.train_data_loader:
                # torch.cuda.empty_cache()
                step += 1
                losses = self.train_step(inputs, step)

                time_meter.update(time.time() - start_ts)
                train_loss_meter.update(losses)

                if (step + 1) % self.cfg["training"]["print_interval"] == 0:
                    fmt_str = "Iter [{}/{}]  Loss: {:.4f}  Time/Image: {:.4f}"
                    print_str = fmt_str.format(
                        step + 1,
                        self.cfg["training"]["train_iters"],
                        train_loss_meter.avgs["total_loss"],
                        time_meter.avg / self.cfg["training"]["batch_size"],
                    )

                    self.logger.info(print_str)
                    for k, v in train_loss_meter.avgs.items():
                        self.writer.add_scalar("training/" + k, v, step + 1)
                    self.writer.add_scalar("training/learning_rate", get_lr(self.optimizer), step + 1)
                    self.writer.add_scalar("training/time_per_image",
                                           time_meter.avg / self.cfg["training"]["batch_size"], step + 1)
                    self.writer.add_scalar("training/amp_scale", self.scaler.get_scale(), step + 1)
                    self.writer.add_scalar("training/memory", psutil.virtual_memory().used / 1e9, step + 1)
                    time_meter.reset()
                    train_loss_meter.reset()

                if (step + 1) % current_val_interval(self.cfg, step + 1) == 0 \
                        or (step + 1) == self.cfg["training"]["train_iters"]:
                    self.validate(step)

                    if self.mIoU >= self.best_iou:
                        self.best_iou = self.mIoU
                        if self.cfg["training"]["save_model"]:
                            self.save_resume(step)

                    if self.earlyStopping is not None:
                        if not self.earlyStopping.step(self.mIoU):
                            flag = False
                            break

                if (step + 1) == self.cfg["training"]["train_iters"]:
                    flag = False
                    break

                start_ts = time.time()

        return step

    def validate(self, step):
        self.model.eval()
        val_loss_meter = AverageMeterDict()
        running_metrics_val = runningScore(self.n_classes)
        imgs_to_save = []
        with torch.no_grad():
            for inputs_val in tqdm(self.val_data_loader,
                                   total=len(self.val_data_loader)):
                if self.cfg["model"]["disable_monodepth"]:
                    required_inputs = [("color_aug", 0, 0), "lbl"]
                else:
                    required_inputs = inputs_val.keys()
                for k, v in inputs_val.items():
                    if torch.is_tensor(v) and k in required_inputs:
                        inputs_val[k] = v.to(self.device, non_blocking=True)
                images_val = inputs_val[("color_aug", 0, 0)]
                with autocast(enabled=self.cfg["training"]["amp"]):
                    outputs = self.model(inputs_val)

                if self.cfg["training"]["segmentation_lambda"] > 0:
                    labels_val = inputs_val["lbl"]
                    semantics = outputs["semantics"]
                    val_segmentation_loss = self.loss_fn(input=semantics, target=labels_val)
                    # Handle inconsistent size between input and target
                    n, c, h, w = semantics.size()
                    nt, ht, wt = labels_val.size()
                    if h != ht or w != wt:  # upsample predictions to label size
                        semantics = F.interpolate(
                            semantics, size=(ht, wt),
                            mode="bilinear", align_corners=True
                        )
                    pred = semantics.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                else:
                    pred = [None] * images_val.shape[0]
                    gt = [None] * images_val.shape[0]
                    val_segmentation_loss = torch.tensor(0)

                if not self.cfg["model"]["disable_monodepth"]:
                    if not self.cfg["model"]["disable_pose"]:
                        self.monodepth_loss_calculator_val.generate_images_pred(inputs_val, outputs)
                        mono_losses = self.monodepth_loss_calculator_val.compute_losses(inputs_val, outputs)
                        val_mono_loss = mono_losses["loss"]
                    else:
                        outputs.update(self.model.predict_test_disp(inputs_val))
                        self.monodepth_loss_calculator_val.generate_depth_test_pred(outputs)
                        val_mono_loss = torch.tensor(0)
                else:
                    outputs[("disp", 0)] = [None] * images_val.shape[0]
                    val_mono_loss = torch.tensor(0)

                if self.cfg["data"].get("depth_teacher", None) is not None:
                    # Crop away bottom of image with own car
                    with torch.no_grad():
                        depth_loss_mask = torch.ones(outputs["disp", 0].shape, device=self.device)
                        depth_loss_mask[:, :, int(outputs["disp", 0].shape[2] * 0.9):, :] = 0
                    val_pseudo_depth_loss = berhu(outputs["disp", 0], inputs_val["pseudo_depth"], depth_loss_mask,
                                                  apply_log=self.cfg["training"].get("pseudo_depth_loss_log", False))
                else:
                    val_pseudo_depth_loss = torch.tensor(0)

                val_loss_meter.update({
                    "segmentation_loss": val_segmentation_loss.detach(),
                    "monodepth_loss": val_mono_loss.detach(),
                    "pseudo_depth_loss": val_pseudo_depth_loss.detach()
                })

                for img, label, output, depth in zip(images_val, gt, pred, outputs[("disp", 0)]):
                    if len(imgs_to_save) < self.cfg["training"]["n_tensorboard_imgs"]:
                        imgs_to_save.append([
                            img, label, output,
                            depth if depth is None else depth.detach()])

        for k, v in val_loss_meter.avgs.items():
            self.writer.add_scalar("validation/" + k, v, step + 1)
        if self.cfg["training"]["segmentation_lambda"] > 0:
            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                self.writer.add_scalar("val_metrics/{}".format(k), v, step + 1)
            for k, v in class_iou.items():
                self.writer.add_scalar("val_metrics/cls_{}".format(k), v, step + 1)
            self.mIoU = score["Mean IoU : \t"]
            self.fwAcc = score["FreqW Acc : \t"]

        for j, imgs in enumerate(imgs_to_save):
            # Only log inputs and ground truth on the first validation, as they don't change -> saves memory
            if (step + 1) // current_val_interval(self.cfg, step + 1) == 1:
                self.img_writer.add_image(
                    "{}/{}_0image".format(self.run_id.replace('/', '_'), j), imgs[0], global_step=step + 1)
                if imgs[1] is not None:
                    colored_image = self.val_loader.decode_segmap_tocolor(imgs[1])
                    self.img_writer.add_image(
                        "{}/{}_1ground_truth".format(self.run_id.replace('/', '_'), j), colored_image,
                        global_step=step + 1, dataformats="HWC")
            if imgs[2] is not None:
                colored_image = self.val_loader.decode_segmap_tocolor(imgs[2])
                self.img_writer.add_image(
                    "{}/{}_2prediction".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1,
                    dataformats="HWC")
            if imgs[3] is not None:
                colored_image = _colorize(imgs[3], "plasma", max_percentile=100)
                self.img_writer.add_image(
                    "{}/{}_3depth".format(self.run_id.replace('/', '_'), j), colored_image, global_step=step + 1,
                    dataformats="HWC")
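
# Hedged sketch: `berhu` used in validate() above is presumably the reverse
# Huber loss common in depth regression; a minimal implementation under that
# assumption (the name and c_ratio are illustrative):
def berhu_sketch(pred, target, mask, apply_log=False, c_ratio=0.2):
    if apply_log:
        pred, target = torch.log(1 + pred), torch.log(1 + target)
    diff = (pred - target).abs() * mask
    # Switch from L1 to scaled L2 above the threshold c
    c = (c_ratio * diff.max()).clamp_min(1e-6)
    l1 = diff[diff <= c]
    l2 = (diff[diff > c] ** 2 + c ** 2) / (2 * c)
    return (l1.sum() + l2.sum()) / mask.sum().clamp_min(1)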
Example #10
0
class BaseModule(nn.Module):
    def __init__(self,
                 cuda=True,
                 warmup_ratio=0.1,
                 num_training_steps=1000,
                 device_idxs=(),
                 mixed_precision=False):
        super().__init__()

        # Other parameters
        self.num_warmup_steps = int(warmup_ratio * num_training_steps)
        self.num_training_steps = num_training_steps
        self.cuda = cuda
        if self.cuda:
            self.devices = device_idxs
        else:
            self.devices = ['cpu']
        self.model_device = self.devices[0]  # first device hosts the model ('cpu' when cuda is off)
        self.mixed_precision = mixed_precision

        # Mixed precision training support
        if self.mixed_precision:
            self.scaler = GradScaler()

    def linear_scheduler(self, optimizer, last_epoch=-1):
        return lr_scheduler.LambdaLR(optimizer, self.lr_lambda, last_epoch)

    def lr_lambda(self, current_step):
        if current_step < self.num_warmup_steps:
            return float(current_step) / float(max(1, self.num_warmup_steps))
        return max(
            0.0, float(self.num_training_steps - current_step) / float(
                max(1, self.num_training_steps - self.num_warmup_steps))
        )

    def backward(self, r=1, l2=False):
        # Loss scaling (e.g. to normalize accumulated gradients)
        self.loss_grad = self.loss_grad * r

        # Gradient-norm penalty: add the L2 norm of the gradients to the loss
        if l2:
            if self.mixed_precision:
                grad_params = torch.autograd.grad(self.scaler.scale(self.loss_grad), self.parameters(),
                                                  create_graph=True)
                inv_scale = 1 / self.scaler.get_scale()
                grad_params = [p * inv_scale for p in grad_params]
            else:
                grad_params = torch.autograd.grad(self.loss_grad, self.parameters(), create_graph=True)
            with autocast(self.mixed_precision):
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                grad_norm = grad_norm.sqrt()
                self.loss_grad = self.loss_grad + grad_norm

        # Backward
        if self.mixed_precision:
            self.scaler.scale(self.loss_grad).backward()
        else:
            self.loss_grad.backward()

    def optimize(self, clip=True):
        if clip:
            if self.mixed_precision:
                self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.parameters(), 1.0)
        if self.mixed_precision:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()

    def save_model(self, checkpoint_name, state_dict_only=True):
        dataparallel = self.single_gpu()
        if not os.path.isdir('checkpoint'):
            os.makedirs('checkpoint', exist_ok=True)
        save_path = os.path.join('checkpoint', checkpoint_name + '.th')
        if state_dict_only:
            torch.save(self.state_dict(), save_path)
        else:
            torch.save(self, save_path)
        self.multi_gpus(dataparallel)
        saved_component = 'state dict' if state_dict_only else 'model'
        print(f'Saved {saved_component} to {save_path}')

    def load_model(self, path, is_state_dict=True):
        dataparallel = self.single_gpu()
        state_dict = torch.load(path, map_location='cpu')
        if not is_state_dict:
            state_dict = state_dict.state_dict()
        self.load_state_dict(state_dict)
        self.multi_gpus(dataparallel)
        loaded_component = 'state dict' if is_state_dict else 'model'
        print(f'Loaded {loaded_component} from {path}')

    @classmethod
    def tensor(cls, x):
        try:
            return torch.tensor(x)
        except (TypeError, ValueError):
            # x is a sequence of tensors -> stack instead
            return torch.stack(x)

    @classmethod
    def get_pad_amount(cls, max_lens, x, first=False):
        zeros = torch.zeros_like(max_lens)
        if first:
            idxs = torch.stack([zeros, max_lens - torch.tensor(x.shape)])
        else:
            idxs = torch.stack([max_lens - torch.tensor(x.shape), zeros])
        return list(idxs.T.reshape(-1).flip(0))

    @classmethod
    def pad_seq(cls, x, val=0, first=False):
        if isinstance(x, torch.Tensor):
            return x
        try:
            return BaseModule.tensor(x)
        except (TypeError, ValueError, RuntimeError):
            x = [BaseModule.pad_seq(x_, val=val, first=first) for x_ in x]
            max_lens = torch.tensor([max(x_.shape[i] for x_ in x) for i in range(x[0].ndim)])
            return torch.stack([pad(x_, pad=BaseModule.get_pad_amount(max_lens, x_, first), value=val) for x_ in x])

    @classmethod
    def getattr(cls, obj, name, *args, **kwargs):
        if '.' in name:
            split_index = name.index('.')
            return cls.getattr(getattr(obj, name[:split_index]), name[split_index + 1:], *args, **kwargs)
        return getattr(obj, name, *args, **kwargs)

    @classmethod
    def setattr(cls, obj, name, value):
        if '.' in name:
            split_index = name.index('.')
            return cls.setattr(getattr(obj, name[:split_index]), name[split_index + 1:], value)
        return setattr(obj, name, value)

    def single_gpu(self):
        dataparallel = set()
        for name, module in self.named_modules():
            if isinstance(module, nn.DataParallel):
                dataparallel.add(name)
        for name in dataparallel:
            BaseModule.setattr(self, name, BaseModule.getattr(self, name).module)
        return dataparallel

    def multi_gpus(self, modules):
        for name in modules:
            BaseModule.setattr(self, name,
                               DataParallel(BaseModule.getattr(self, name),
                                            device_ids=self.devices,
                                            output_device=self.model_device))

    @classmethod
    def ratio(cls, x, y, ndigits=3):
        if y == 0:
            return 0
        return round(x / y, ndigits)

    def make_position_ids(self, attention_mask):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)
        return position_ids
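
# Hedged usage sketch for BaseModule's AMP helpers. BaseModule never creates
# `self.optimizer`/`self.scheduler` itself, so a subclass is assumed to attach
# them; `criterion` and `batches` are illustrative placeholders:
def run_amp_training(module, criterion, batches, accumulation_steps=1):
    module.optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(batches):
        # Store the loss where backward() expects it
        module.loss_grad = criterion(module(inputs), labels)
        # r rescales the loss so accumulated gradients average out
        module.backward(r=1 / accumulation_steps)
        if (step + 1) % accumulation_steps == 0:
            module.optimize(clip=True)  # unscale -> clip -> step -> update
            module.optimizer.zero_grad()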
Example #11
0
def train(net,
          data_loader,
          criterion,
          optimizer,
          epochs=10,
          save_every=20,
          model_path=None,
          use_drive=False,
          resume=False,
          reset=False,
          track_grad_norm=False,
          scheduler=None,
          plot=False,
          use_amp=False):
    "Training Loop"

    device = next(net.parameters()).device

    save_path, load_path = search_drive(model_path, use_drive)

    init_epoch = 0

    if load_path and os.path.exists(load_path) and not reset:
        checkpoint = torch.load(load_path, map_location=device)
        if 'net_state_dict' in checkpoint:
            net.load_state_dict(checkpoint['net_state_dict'])
        else:
            net.load_state_dict(checkpoint)
        if 'epoch' in checkpoint:
            init_epoch = checkpoint['epoch']
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Training Checkpoint restored: " + load_path)
        if not resume:
            net.eval()
            return
    else:
        if model_path:
            print("No Checkpoint found / Reset.")
        if save_path:
            print("Path: " + save_path)

    assert valid_data_loader(
        data_loader), f"invalid data_loader: {data_loader}"

    net.train()

    USE_AMP = device.type == 'cuda' and use_amp
    if USE_AMP:
        scaler = GradScaler()

    TRACKING = None
    if plot:
        TRACKING = defaultdict(list, loss=[])

    print("Beginning training.", flush=True)

    with tqdmEpoch(epochs, len(data_loader)) as pbar:
        saved_epoch = 0
        for epoch in range(1 + init_epoch, 1 + init_epoch + epochs):
            total_count = 0.0
            total_loss = 0.0
            total_correct = 0.0
            grad_total = 0.0

            for inputs, labels in data_loader:
                optimizer.zero_grad()

                if USE_AMP:
                    with autocast():
                        outputs = net(inputs)
                        loss = criterion(outputs, labels)
                    scaler.scale(loss).backward()
                    grad_scale = scaler.get_scale()
                else:
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    grad_scale = 1

                if track_grad_norm:
                    for param in net.parameters():
                        if param.grad is not None:  # params untouched this step have no grad
                            grad_total += (param.grad.norm(2) / grad_scale).item()

                if USE_AMP:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                batch_size = len(inputs)
                total_count += batch_size
                total_loss += loss.item() * batch_size
                total_correct += count_correct(outputs, labels)
                pbar.set_postfix(
                    loss=total_loss / total_count,
                    acc=f"{total_correct / total_count * 100:.0f}%",
                    chkpt=saved_epoch,
                    refresh=False,
                )
                pbar.update()

            loss = total_loss / total_count
            accuracy = total_correct / total_count
            grad_norm = grad_total / total_count

            if scheduler is not None:
                scheduler.step(loss)

            if TRACKING:
                TRACKING['loss'].append(loss)
                TRACKING['accuracy'].append(accuracy)
                if track_grad_norm:
                    TRACKING['|grad|'].append(grad_norm)

            if save_path is not None \
                and (save_every is not None
                     and epoch % save_every == 0
                     or epoch == init_epoch + epochs):  # also save on the final epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'net_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }, save_path)
                saved_epoch = epoch

            pbar.set_postfix(
                loss=total_loss / total_count,
                acc=f"{total_correct / total_count * 100:.0f}%",
                chkpt=saved_epoch,
            )

    print(flush=True, end='')
    # net.eval()

    if TRACKING:
        plot_metrics(TRACKING, step_start=init_epoch)
        # plt.xlabel('epochs')
        # plt.show()
        return TRACKING
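
# Hedged note on the grad-norm tracking above: dividing each gradient norm by
# `scaler.get_scale()` approximates the unscaled norm without an explicit
# `scaler.unscale_()` call. An equivalent standalone helper (assumption):
def unscaled_grad_norm(net, grad_scale=1.0):
    total = 0.0
    for param in net.parameters():
        if param.grad is not None:
            total += (param.grad.norm(2) / grad_scale).item()
    return total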
Example #12
0
def invert(
    data_loader,
    loss_fn,
    optimizer,
    steps=10,
    scheduler=None,
    use_amp=False,
    grad_norm_fn=None,
    callback_fn=None,
    plot=False,
    fig_path=None,
    track_per_batch=False,
    track_grad_norm=False,
    print_grouped=False,
):

    assert valid_data_loader(
        data_loader), f"invalid data_loader: {data_loader}"

    params = sum((p_group['params'] for p_group in optimizer.param_groups), [])
    lrs = [p_group['lr'] for p_group in optimizer.param_groups]
    device = params[0].device
    USE_AMP = (device.type == 'cuda') and use_amp
    if USE_AMP:
        scaler = GradScaler()

    num_batches = len(data_loader)
    track_len = steps * num_batches if track_per_batch else steps
    metrics = pd.DataFrame({'step': [None] * track_len})

    def process_result(res):
        if isinstance(res, dict):
            loss = res['loss']
            info = res
            for k, v in info.items():
                info[k] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            loss = res
            info = {'loss': loss.item()}
        return loss, info

    print(flush=True)

    if callback_fn:
        callback_fn(0, None)

    with tqdmEpoch(steps, num_batches) as pbar:
        for epoch in range(steps):
            for batch_i, data in enumerate(data_loader):

                optimizer.zero_grad()

                if USE_AMP:
                    with autocast():
                        res = loss_fn(data)
                    loss, info = process_result(res)
                    scaler.scale(loss).backward()
                    grad_scale = scaler.get_scale()
                else:
                    res = loss_fn(data)
                    loss, info = process_result(res)
                    loss.backward()
                    grad_scale = 1

                if USE_AMP:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                if scheduler is not None:
                    scheduler.step(loss)

                if track_grad_norm or grad_norm_fn:
                    # Norm over *all* params; zipping with the per-group lrs
                    # would silently truncate to the number of param groups
                    total_norm = torch.norm(
                        torch.stack([
                            p.grad.detach().norm() / grad_scale
                            for p in params
                        ])).item()

                    if grad_norm_fn:
                        rescale_coef = grad_norm_fn(total_norm) / total_norm
                        for param in params:
                            param.grad.detach().mul_(rescale_coef)

                    info['|grad|'] = total_norm

                pbar.set_postfix(**{
                    k: v
                    for k, v in info.items() if ']' not in k
                },
                                 refresh=False)
                pbar.update()

                if track_per_batch:
                    batch_total = epoch * num_batches + batch_i
                    step = batch_total
                    # step = epoch + (batch_i + 1) / num_batches
                else:
                    step = epoch
                    # step = epoch + 1 + batch_i / num_batches

                for k, v in info.items():
                    if k not in metrics:  # add new column
                        metrics[k] = None
                    # use .at to avoid pandas chained assignment
                    if metrics.at[step, k] is None:
                        metrics.at[step, k] = v
                    else:
                        metrics.at[step, k] += v

                if not track_per_batch and batch_i == 0:
                    metrics.at[epoch, 'step'] = epoch + 1
                if track_per_batch:
                    metrics.at[batch_total, 'step'] = (batch_total + 1) / num_batches
                # batch end

            if not track_per_batch:
                for k in metrics.columns:
                    if k != 'step':
                        metrics.at[epoch, k] /= num_batches

            if callback_fn:
                callback_fn(epoch + 1, metrics.iloc[step])
            # epoch end

    print(flush=True)

    if plot and steps > 1:
        plot_metrics(metrics, fig_path=fig_path, smoothing=0)

    return metrics
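
# Hedged usage sketch: `grad_norm_fn` receives the current global gradient
# norm and returns the desired norm; invert() then rescales all gradients by
# the ratio. E.g. clipping the global norm at 1.0 (illustrative name):
def clip_global_norm(total_norm, max_norm=1.0):
    return min(total_norm, max_norm)
# metrics = invert(data_loader, loss_fn, optimizer, grad_norm_fn=clip_global_norm)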
Example #13
0
class Trainer:
    def __init__(
        self,
        name="default",
        results_dir="results",
        models_dir="models",
        base_dir="./",
        optimizer="adam",
        latent_dim=256,
        image_size=128,
        fmap_max=512,
        transparent=False,
        batch_size=4,
        gp_weight=10,
        gradient_accumulate_every=1,
        attn_res_layers=[],
        sle_spatial=False,
        disc_output_size=5,
        antialias=False,
        lr=2e-4,
        lr_mlp=1.0,
        ttur_mult=1.0,
        save_every=1000,
        evaluate_every=1000,
        trunc_psi=0.6,
        aug_prob=None,
        aug_types=["translation", "cutout"],
        dataset_aug_prob=0.0,
        calculate_fid_every=None,
        is_ddp=False,
        rank=0,
        world_size=1,
        log=False,
        amp=False,
        *args,
        **kwargs,
    ):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / ".config.json"

        assert is_power_of_two(
            image_size
        ), "image size must be a power of 2 (64, 128, 256, 512, 1024)"
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), "resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)"

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.sle_spatial = sle_spatial
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = None
        self.D_scaler = None
        if self.amp:
            self.G_scaler = GradScaler()
            self.D_scaler = GradScaler()

    @property
    def image_extension(self):
        return "jpg" if not self.transparent else "png"

    @property
    def checkpoint_num(self):
        return self.steps // self.save_every

    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist

            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"
            dist.init_process_group("nccl", rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(
            optimizer=self.optimizer,
            lr=self.lr,
            latent_dim=self.latent_dim,
            attn_res_layers=self.attn_res_layers,
            sle_spatial=self.sle_spatial,
            image_size=self.image_size,
            ttur_mult=self.ttur_mult,
            fmap_max=self.fmap_max,
            disc_output_size=self.disc_output_size,
            transparent=self.transparent,
            rank=self.rank,
            *args,
            **kwargs,
        )

        if self.is_ddp:
            ddp_kwargs = {
                "device_ids": [self.rank],
                "output_device": self.rank,
                "find_unused_parameters": True,
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = (
            self.config()
            if not self.config_path.exists()
            else json.loads(self.config_path.read_text())
        )
        self.image_size = config["image_size"]
        self.transparent = config["transparent"]
        self.syncbatchnorm = config["syncbatchnorm"]
        self.disc_output_size = config["disc_output_size"]
        self.attn_res_layers = config.pop("attn_res_layers", [])
        self.sle_spatial = config.pop("sle_spatial", False)
        self.optimizer = config.pop("optimizer", "adam")
        self.fmap_max = config.pop("fmap_max", 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            "image_size": self.image_size,
            "transparent": self.transparent,
            "syncbatchnorm": self.syncbatchnorm,
            "disc_output_size": self.disc_output_size,
            "optimizer": self.optimizer,
            "attn_res_layers": self.attn_res_layers,
            "sle_spatial": self.sle_spatial,
        }

    def set_data_src(self, folder):
        self.dataset = ImageDataset(
            folder,
            self.image_size,
            transparent=self.transparent,
            aug_prob=self.dataset_aug_prob,
        )
        sampler = (
            DistributedSampler(
                self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True
            )
            if self.is_ddp
            else None
        )
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True,
        )
        self.loader = cycle(dataloader)

        # automatically set the augmentation probability if the dataset is detected to be small
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f"autosetting augmentation probability to {round(self.aug_prob * 100)}%"
            )

    def train(self):
        assert exists(
            self.loader
        ), "You must first initialize the data source with `.set_data_src(<folder of images>)`"
        device = torch.device(f"cuda:{self.rank}")

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        # image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {"prob": aug_prob, "types": aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        # D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        def backward(amp, loss, scaler):
            if amp:
                return scaler.scale(loss).backward()
            loss.backward()

        def optimizer_step(amp, optimizer, scaler):
            if amp:
                scaler.step(optimizer)
                scaler.update()
                return
            optimizer.step()

        backward = partial(backward, self.amp)
        optimizer_step = partial(optimizer_step, self.amp)

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, G]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images.detach(), detach=True, **aug_kwargs
                )

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs
                )

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32, fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = (
                    list(map(self.D_scaler.scale, outputs)) if self.amp else outputs
                )

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(), device=image_batch.device),
                            outputs,
                        )
                    ),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True,
                )[0]

                inv_scale = (1.0 / self.D_scaler.get_scale()) if self.amp else 1.0
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            backward(disc_loss, self.D_scaler)
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.D_opt, self.D_scaler)

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(
            self.gradient_accumulate_every, self.is_ddp, ddps=[G, D_aug]
        ):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs
                )
                fake_output_loss = fake_output.mean(dim=1) + fake_output_32x32.mean(
                    dim=1
                )

                epochs = (
                    self.steps * batch_size * self.gradient_accumulate_every
                ) / len(self.dataset)
                k_frac = max(
                    self.generator_top_k_gamma ** epochs, self.generator_top_k_frac
                )
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every
            gen_loss.register_hook(raise_if_nan)
            backward(gen_loss, self.G_scaler)
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() / self.gradient_accumulate_every)
        optimizer_step(self.GAN.G_opt, self.G_scaler)

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f"NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}"
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                self.steps % 100 == 0 and self.steps < 20000
            ):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if (
                exists(self.calculate_fid_every)
                and self.steps % self.calculate_fid_every == 0
                and self.steps != 0
            ):
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                    str(self.results_dir / self.name / "fid_scores.txt"), "a"
                ) as f:
                    f.write(f"{self.steps},{fid}\n")

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows ** 2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}.{ext}"),
            nrow=num_rows,
        )

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(
            generated_images,
            str(self.results_dir / self.name / f"{str(num)}-ema.{ext}"),
            nrow=num_rows,
        )

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / "fid_real") + "/"
        fake_path = str(self.results_dir / self.name / "fid_fake") + "/"

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving reals"
        ):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :],
                    real_path + "{}.png".format(k + batch_num * self.batch_size),
                )

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        for batch_num in tqdm(
            range(num_batches), desc="calculating FID - saving generated"
        ):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path)
                        / f"{str(j + batch_num * self.batch_size)}-ema.{ext}"
                    ),
                )

        return fid_score.calculate_fid_given_paths(
            [real_path, fake_path], 256, True, 2048
        )

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0.0, 1.0)

    @torch.no_grad()
    def generate_interpolation(
        self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False
    ):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        # image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows ** 2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0.0, 8.0, num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE, interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(
            str(self.results_dir / self.name / f"{str(num)}.gif"),
            save_all=True,
            append_images=frames[1:],
            duration=80,
            loop=0,
            optimize=True,
        )

        if save_frames:
            folder_path = self.results_dir / self.name / f"{str(num)}"
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f"{str(ind)}.{ext}"))

    def print_log(self):
        data = [
            ("G", self.g_loss),
            ("D", self.d_loss),
            ("GP", self.last_gp_loss),
            ("SS", self.last_recon_loss),
            ("FID", self.last_fid),
        ]

        data = [d for d in data if exists(d[1])]
        log = " | ".join(map(lambda n: f"{n[0]}: {n[1]:.2f}", data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f"model_{num}.pt")

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {"GAN": self.GAN.state_dict(), "version": __version__}

        if self.amp:
            save_data = {
                **save_data,
                "G_scaler": self.G_scaler.state_dict(),
                "D_scaler": self.D_scaler.state_dict(),
            }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob("model_*.pt")
            ]
            saved_nums = sorted(map(lambda x: int(x.stem.split("_")[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f"continuing from previous epoch - {name}")

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if "version" in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data["GAN"])
        except Exception as e:
            print(
                "unable to load save model. please try downgrading the package to the version specified by the saved model"
            )
            raise e

        if self.amp:
            if "G_scaler" in load_data:
                self.G_scaler.load_state_dict(load_data["G_scaler"])
            if "D_scaler" in load_data:
                self.D_scaler.load_state_dict(load_data["D_scaler"])