Example #1
    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output
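
As a quick, hypothetical usage sketch (the wrapper instance, dataloader, and batch format here are assumptions, not part of the example), the val_step() API above would typically be driven under torch.no_grad() during evaluation:

import torch

# Hypothetical sketch: driving the val_step() API during evaluation.
# `ddp_model` is assumed to be a module wrapped by the DDP subclass above,
# and `val_loader` an ordinary DataLoader yielding one batch per step.
def run_validation(ddp_model, val_loader):
    ddp_model.eval()
    outputs = []
    with torch.no_grad():
        for data_batch in val_loader:
            # val_step() mirrors DistributedDataParallel.forward(), but
            # dispatches to self.module.val_step() instead of forward().
            outputs.append(ddp_model.val_step(data_batch))
    return outputs
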
Example #2
    def forward(self, *inputs, **kwargs):
        """Modified to support not scattering inputs when it's only on one device"""
        if self.require_forward_param_sync:
            self._sync_params()

        if self.device_ids:
            if len(self.device_ids) == 1:
                output = self.module(*inputs, **kwargs)
            else:
                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module(*inputs, **kwargs)

        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.require_forward_param_sync = True
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            self.require_forward_param_sync = False

        return output
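
The single-device branch above is hit when the wrapper holds exactly one entry in device_ids, which is the usual one-process-per-GPU setup. A minimal sketch of that setup, shown with the stock torch.nn.parallel.DistributedDataParallel for illustration (the module and rank handling are placeholders, and the process group is assumed to be initialized already):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def build_ddp_model(local_rank):
    """Sketch of the one-device-per-process setup that exercises the
    len(self.device_ids) == 1 branch above. Assumes the default process
    group has already been initialized (e.g. via init_process_group)."""
    torch.cuda.set_device(local_rank)
    model = nn.Linear(16, 16).cuda(local_rank)  # placeholder module
    # With a single device id, the modified forward() skips scatter() and
    # calls self.module(*inputs, **kwargs) directly on this device.
    return DistributedDataParallel(model, device_ids=[local_rank])
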
Example #3
    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """

        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
        # end of backward to the beginning of forward.
        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) >= digit_version('1.7')
                and self.reducer._rebuild_buckets()):
            print_log(
                'Reducer buckets have been rebuilt in this iteration.',
                logger='mmcv')

        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
            if self._check_sync_bufs_pre_fwd():
                self._sync_buffers()
        else:
            if (getattr(self, 'require_forward_param_sync', False)
                    and self.require_forward_param_sync):
                self._sync_params()

        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if ('parrots' not in TORCH_VERSION
                and digit_version(TORCH_VERSION) >= digit_version('1.11.0')):
            if self._check_sync_bufs_post_fwd():
                self._sync_buffers()

        if (torch.is_grad_enabled()
                and getattr(self, 'require_backward_grad_sync', False)
                and self.require_backward_grad_sync):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if ('parrots' not in TORCH_VERSION
                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
                self.require_forward_param_sync = False
        return output
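
The version gates above compare parsed version tuples rather than raw strings, which matters because plain string comparison ranks '1.10' below '1.2'. A simplified stand-in for what a digit_version-style helper does (the real mmcv helper also handles suffixes such as release candidates, which this sketch ignores):

def digit_version_sketch(version_str):
    """Simplified stand-in for a digit_version-style helper: parse the
    leading numeric fields of a version string into a comparable tuple."""
    digits = []
    for part in version_str.split('.'):
        num = ''
        for ch in part:
            if not ch.isdigit():
                break
            num += ch
        if num:
            digits.append(int(num))
    return tuple(digits)

# Tuple comparison orders versions correctly; raw string comparison does not.
assert digit_version_sketch('1.10.0') >= digit_version_sketch('1.7')
assert not ('1.10.0' >= '1.7')
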
Example #4
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update the discriminator and then compute the loss for the generator
        with the newly updated discriminator.

        For distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary parameters in the current computational
        graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is the local batch size on each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        if dist.is_initialized():
            # randomly sample a scale for current training iteration
            chosen_scale = np.random.choice(self.multi_input_scales, 1,
                                            p=self.multi_scale_probability)[0]

            chosen_scale = torch.tensor(chosen_scale, dtype=torch.int).cuda()
            dist.broadcast(chosen_scale, 0)
            chosen_scale = int(chosen_scale.item())

        else:
            mmcv.print_log(
                'Distributed training has not been initialized. Degrading to '
                'the standard stylegan2.',
                logger='mmgen',
                level=logging.WARN)
            chosen_scale = 0

        curr_size = (4 + chosen_scale) * (2**self.num_upblocks)
        # adjust the shape of images
        if real_imgs.shape[-2:] != (curr_size, curr_size):
            real_imgs = F.interpolate(real_imgs,
                                      size=(curr_size, curr_size),
                                      mode='bilinear',
                                      align_corners=True)

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(None,
                                       num_batches=batch_size,
                                       chosen_scale=chosen_scale)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(gen=self.generator,
                          disc=self.discriminator,
                          disc_pred_fake=disc_pred_fake,
                          disc_pred_real=disc_pred_real,
                          fake_imgs=fake_imgs,
                          real_imgs=real_imgs,
                          iteration=curr_iter,
                          batch_size=batch_size,
                          gen_partial=partial(self.generator,
                                              chosen_scale=chosen_scale))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(fake_imgs=fake_imgs.cpu(),
                           real_imgs=real_imgs.cpu())
            log_vars_disc['curr_size'] = curr_size
            outputs = dict(log_vars=log_vars_disc,
                           num_samples=batch_size,
                           results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(None,
                                   num_batches=batch_size,
                                   chosen_scale=chosen_scale)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(gen=self.generator,
                          disc=self.discriminator,
                          fake_imgs=fake_imgs,
                          disc_pred_fake_g=disc_pred_fake_g,
                          iteration=curr_iter,
                          batch_size=batch_size,
                          gen_partial=partial(self.generator,
                                              chosen_scale=chosen_scale))

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars['curr_size'] = curr_size

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(log_vars=log_vars,
                       num_samples=batch_size,
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
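
In these mmgen-style train_step functions, backward() happens inside the method, so the caller has to hand over the DDP Reducer for unused-parameter tracking. A hedged sketch of how a caller might thread it through (the wrapper layout and the caller-side helper shown here are assumptions):

import torch

def run_train_step(model, data_batch, optimizer):
    """Hedged sketch: pass the DDP reducer into train_step() so that
    prepare_for_backward() can be called right before loss.backward()."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        # Unwrap the module and hand over the reducer so DDP can still
        # detect which parameters were used, even though backward()
        # happens inside train_step() rather than after forward().
        return model.module.train_step(data_batch, optimizer,
                                       ddp_reducer=model.reducer)
    return model.train_step(data_batch, optimizer, ddp_reducer=None)
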
Example #5
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update the discriminator and then compute the loss for the generator
        with the newly updated discriminator.

        For distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary parameters in the current computational
        graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use ``apex.amp``. Defaults to
                ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch[self.real_img_key]
        # If you adopt ddp, this batch size is the local batch size on each GPU.
        # If you adopt dp, this batch size is the global batch size as usual.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(None, num_batches=batch_size)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)
        # get data dict to compute losses for disc
        data_dict_ = dict(gen=self.generator,
                          disc=self.discriminator,
                          disc_pred_fake=disc_pred_fake,
                          disc_pred_real=disc_pred_real,
                          fake_imgs=fake_imgs,
                          real_imgs=real_imgs,
                          iteration=curr_iter,
                          batch_size=batch_size,
                          loss_scaler=loss_scaler)

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss_disc).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(loss_disc,
                                optimizer['discriminator'],
                                loss_id=0) as scaled_loss_disc:
                scaled_loss_disc.backward()
        else:
            loss_disc.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['discriminator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['discriminator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['discriminator'].step()

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(fake_imgs=fake_imgs.cpu(),
                           real_imgs=real_imgs.cpu())
            outputs = dict(log_vars=log_vars_disc,
                           num_samples=batch_size,
                           results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(None, num_batches=batch_size)
        disc_pred_fake_g = self.discriminator(fake_imgs)

        data_dict_ = dict(gen=self.generator,
                          disc=self.discriminator,
                          fake_imgs=fake_imgs,
                          disc_pred_fake_g=disc_pred_fake_g,
                          iteration=curr_iter,
                          batch_size=batch_size,
                          loss_scaler=loss_scaler)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        if loss_scaler:
            loss_scaler.scale(loss_gen).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(loss_gen, optimizer['generator'],
                                loss_id=1) as scaled_loss_gen:
                scaled_loss_gen.backward()
        else:
            loss_gen.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['generator'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['generator'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(log_vars=log_vars,
                       num_samples=batch_size,
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1
        return outputs
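
The loss_scaler branch above follows the standard torch.cuda.amp.GradScaler pattern, except that update() is left to the runner. For reference, a self-contained sketch of the complete scale/backward/unscale_/step/update cycle (the model, optimizer, and data here are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(8, 1).cuda()                       # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(4, 8, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.unscale_(optimizer)      # optional: unscale, e.g. before grad clipping
    scaler.step(optimizer)          # skips the step if inf/nan gradients appear
    scaler.update()                 # in the examples above, the runner calls this
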
Example #6
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update the discriminator and then compute the loss for the generator
        with the newly updated discriminator.

        For distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary parameters in the current computational
        graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # init each scale
        if curr_iter % self.train_cfg['iters_per_scale'] == 0:
            self.curr_stage += 1
            # load weights from prev scale
            self.get_module(self.generator,
                            'check_and_load_prev_weight')(self.curr_stage)
            self.get_module(self.discriminator,
                            'check_and_load_prev_weight')(self.curr_stage)
            # build optimizer for each scale
            g_module = self.get_module(self.generator, 'blocks')
            param_list = g_module[self.curr_stage].parameters()

            self.g_optim = torch.optim.Adam(param_list,
                                            lr=self.train_cfg['lr_g'],
                                            betas=(0.5, 0.999))
            d_module = self.get_module(self.discriminator, 'blocks')
            self.d_optim = torch.optim.Adam(
                d_module[self.curr_stage].parameters(),
                lr=self.train_cfg['lr_d'],
                betas=(0.5, 0.999))

            self.optimizer = dict(generator=self.g_optim,
                                  discriminator=self.d_optim)

            self.g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.g_optim, **self.train_cfg['lr_scheduler_args'])
            self.d_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.d_optim, **self.train_cfg['lr_scheduler_args'])

        optimizer = self.optimizer

        # setup fixed noises and reals pyramid
        if curr_iter == 0 or len(self.reals) == 0:
            keys = [k for k in data_batch.keys() if 'real_scale' in k]
            scales = len(keys)
            self.reals = [data_batch[f'real_scale{s}'] for s in range(scales)]

            # here we do not pad the fixed noises
            self.construct_fixed_noises()

        # disc training
        set_requires_grad(self.discriminator, True)
        for _ in range(self.train_cfg['disc_steps']):
            optimizer['discriminator'].zero_grad()
            # TODO: add noise sampler to customize noise sampling
            with torch.no_grad():
                fake_imgs = self.generator(data_batch['input_sample'],
                                           self.fixed_noises,
                                           self.noise_weights,
                                           rand_mode='rand',
                                           curr_scale=self.curr_stage)

            # disc pred for fake imgs and real_imgs
            disc_pred_fake = self.discriminator(fake_imgs.detach(),
                                                self.curr_stage)
            disc_pred_real = self.discriminator(self.reals[self.curr_stage],
                                                self.curr_stage)
            # get data dict to compute losses for disc
            data_dict_ = dict(iteration=curr_iter,
                              gen=self.generator,
                              disc=self.discriminator,
                              disc_pred_fake=disc_pred_fake,
                              disc_pred_real=disc_pred_real,
                              fake_imgs=fake_imgs,
                              real_imgs=self.reals[self.curr_stage],
                              disc_partial=partial(self.discriminator,
                                                   curr_scale=self.curr_stage))

            loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
            loss_disc.backward()
            optimizer['discriminator'].step()

        log_vars_disc.update(dict(curr_stage=self.curr_stage))

        # generator training
        set_requires_grad(self.discriminator, False)
        for _ in range(self.train_cfg['generator_steps']):
            optimizer['generator'].zero_grad()

            # TODO: add noise sampler to customize noise sampling
            fake_imgs = self.generator(data_batch['input_sample'],
                                       self.fixed_noises,
                                       self.noise_weights,
                                       rand_mode='rand',
                                       curr_scale=self.curr_stage)
            disc_pred_fake_g = self.discriminator(fake_imgs,
                                                  curr_scale=self.curr_stage)

            recon_imgs = self.generator(data_batch['input_sample'],
                                        self.fixed_noises,
                                        self.noise_weights,
                                        rand_mode='recon',
                                        curr_scale=self.curr_stage)

            data_dict_ = dict(iteration=curr_iter,
                              gen=self.generator,
                              disc=self.discriminator,
                              fake_imgs=fake_imgs,
                              recon_imgs=recon_imgs,
                              real_imgs=self.reals[self.curr_stage],
                              disc_pred_fake_g=disc_pred_fake_g)

            loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

            loss_gen.backward()
            optimizer['generator'].step()

        # end of each scale
        # calculate noise weight for next scale
        if (curr_iter % self.train_cfg['iters_per_scale']
                == 0) and (self.curr_stage < len(self.reals) - 1):

            with torch.no_grad():
                g_recon = self.generator(data_batch['input_sample'],
                                         self.fixed_noises,
                                         self.noise_weights,
                                         rand_mode='recon',
                                         curr_scale=self.curr_stage)
                if isinstance(g_recon, dict):
                    g_recon = g_recon['fake_img']
                g_recon = F.interpolate(
                    g_recon, self.reals[self.curr_stage + 1].shape[-2:])

            mse = F.mse_loss(g_recon.detach(), self.reals[self.curr_stage + 1])
            rmse = torch.sqrt(mse)
            self.noise_weights.append(
                self.train_cfg.get('noise_weight_init', 0.1) * rmse.item())

            # try to release GPU memory.
            torch.cuda.empty_cache()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)

        results = dict(fake_imgs=fake_imgs.cpu(),
                       real_imgs=self.reals[self.curr_stage].cpu(),
                       recon_imgs=recon_imgs.cpu(),
                       curr_stage=self.curr_stage,
                       fixed_noises=self.fixed_noises,
                       noise_weights=self.noise_weights)
        outputs = dict(log_vars=log_vars, num_samples=1, results=results)

        # update lr scheduler
        self.d_scheduler.step()
        self.g_scheduler.step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        return outputs
Example #7
    def train_step(self,
                   data,
                   optimizer,
                   ddp_reducer=None,
                   loss_scaler=None,
                   use_apex_amp=False,
                   running_status=None):
        """The iteration step during training.

        This method defines an iteration step during training. Different from
        other repos in the **MM** series, we allow the back propagation and
        optimizer update to directly follow the iterative training schedule of
        DDPMs.
        You can also move the back propagation outside of this method and
        optimize the parameters in an optimizer hook, but this causes extra
        GPU memory cost because the computational graph must be retained.
        Otherwise, the training schedule should be modified in the detailed
        implementation.


        Args:
            data (dict): Input data from the dataloader.
            optimizer (dict): Dict containing the optimizer for the denoising
                network.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            loss_scaler (:obj:`torch.cuda.amp.GradScaler` | None, optional):
                The loss/gradient scaler used for auto mixed-precision
                training. Defaults to ``None``.
            use_apex_amp (bool, optional): Whether to use ``apex.amp``.
                Defaults to ``False``.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        real_imgs = data[self.real_img_key]
        # denoising training
        optimizer['denoising'].zero_grad()
        denoising_dict_ = self.reconstruction_step(data,
                                                   timesteps=self.sampler,
                                                   sample_model='orig',
                                                   return_noise=True)
        denoising_dict_['iteration'] = curr_iter
        denoising_dict_['real_imgs'] = real_imgs
        denoising_dict_['loss_scaler'] = loss_scaler

        loss, log_vars = self._get_loss(denoising_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss))

        if loss_scaler:
            # add support for fp16
            loss_scaler.scale(loss).backward()
        elif use_apex_amp:
            from apex import amp
            with amp.scale_loss(loss, optimizer['denoising'],
                                loss_id=0) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if loss_scaler:
            loss_scaler.unscale_(optimizer['denoising'])
            # note that we do not contain clip_grad procedure
            loss_scaler.step(optimizer['denoising'])
            # loss_scaler.update will be called in runner.train()
        else:
            optimizer['denoising'].step()

        # images used for visualization
        results = dict(real_imgs=real_imgs,
                       x_0_pred=denoising_dict_['x_0_pred'],
                       x_t=denoising_dict_['diffusion_batches'],
                       x_t_1=denoising_dict_['fake_img'])
        outputs = dict(log_vars=log_vars,
                       num_samples=real_imgs.shape[0],
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        return outputs
Example #8
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generator and discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # data
        target_domain = self._default_domain
        source_domain = self.get_other_domains(self._default_domain)[0]
        source_image = data_batch[f'img_{source_domain}']
        target_image = data_batch[f'img_{target_domain}']

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generator
        outputs = dict()
        results = self(source_image,
                       target_domain=self._default_domain,
                       test_mode=False)
        outputs[f'real_{source_domain}'] = results['source']
        outputs[f'fake_{target_domain}'] = results['target']
        outputs[f'real_{target_domain}'] = target_image
        log_vars = dict()

        # discriminator
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generator, no updates to discriminator parameters.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            # prepare for backward in ddp. If you do not call this function
            # before back propagation, the ddp will not dynamically find the
            # used params in current computation.
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        image_results[f'real_{source_domain}'] = outputs[
            f'real_{source_domain}'].cpu()
        image_results[f'fake_{target_domain}'] = outputs[
            f'fake_{target_domain}'].cpu()
        image_results[f'real_{target_domain}'] = outputs[
            f'real_{target_domain}'].cpu()

        results = dict(log_vars=log_vars,
                       num_samples=len(outputs[f'real_{source_domain}']),
                       results=image_results)

        return results
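
set_requires_grad is used throughout these examples to freeze the discriminator while the generator is updated and vice versa. A minimal sketch of what such a helper typically does (the real utility may accept extra argument forms):

import torch.nn as nn

def set_requires_grad_sketch(nets, requires_grad=False):
    """Minimal sketch of a set_requires_grad-style helper: toggle
    requires_grad on every parameter of one or more networks."""
    if isinstance(nets, nn.Module):
        nets = [nets]
    for net in nets:
        for param in net.parameters():
            param.requires_grad = requires_grad
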
Example #9
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Train step function.

        This function implements the standard training iteration for
        asynchronous adversarial training. Namely, in each iteration, we first
        update the discriminator and then compute the loss for the generator
        with the newly updated discriminator.

        For distributed training, we use the ``reducer`` from ddp to
        synchronize the necessary parameters in the current computational
        graph.

        Args:
            data_batch (dict): Input data from dataloader.
            optimizer (dict): Dict contains optimizer for generator and
                discriminator.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Contains 'log_vars', 'num_samples', and 'results'.
        """
        # get data from data_batch
        real_imgs = data_batch['real_img']
        # If you adopt ddp, this batch size is the local batch size on each GPU.
        batch_size = real_imgs.shape[0]

        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # check if optimizer from model
        if hasattr(self, 'optimizer'):
            optimizer = self.optimizer

        # update current stage
        self.curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        self.curr_scale = self.scales[self.curr_stage]
        self._curr_scale_int = self._next_scale_int.clone()
        # add new scale and update training status
        if self.curr_stage != self.prev_stage:
            self.prev_stage = self.curr_stage
            self._actual_nkimgs.append(self.shown_nkimg.item())
            # reset optimizer
            if self.reset_optim_for_new_scale:
                optim_cfg = deepcopy(self.train_cfg['optimizer_cfg'])
                optim_cfg['generator']['lr'] = self.g_lr_schedule.get(
                    str(self.curr_scale[0]), self.g_lr_base)
                optim_cfg['discriminator']['lr'] = self.d_lr_schedule.get(
                    str(self.curr_scale[0]), self.d_lr_base)
                self.optimizer = build_optimizers(self, optim_cfg)
                optimizer = self.optimizer
                mmcv.print_log('Reset optimizer for new scale', logger='mmgen')

        # update training configs, like transition weight for torgb layers.
        # get current transition weight for interpolating two torgb layers
        if self.curr_stage == 0:
            transition_weight = 1.
        else:
            transition_weight = (
                self.shown_nkimg.item() -
                self._actual_nkimgs[-1]) / self.transition_kimgs
            # clip to [0, 1]
            transition_weight = min(max(transition_weight, 0.), 1.)
        self._curr_transition_weight = torch.tensor(transition_weight).to(
            self._curr_transition_weight)

        # resize real image to target scale
        if real_imgs.shape[2:] == self.curr_scale:
            pass
        elif real_imgs.shape[2] >= self.curr_scale[0] and real_imgs.shape[
                3] >= self.curr_scale[1]:
            real_imgs = self.interp_real_to(real_imgs, size=self.curr_scale)
        else:
            raise RuntimeError(
                f'The scale of real image {real_imgs.shape[2:]} is smaller '
                f'than current scale {self.curr_scale}.')

        # disc training
        set_requires_grad(self.discriminator, True)
        optimizer['discriminator'].zero_grad()
        # TODO: add noise sampler to customize noise sampling
        with torch.no_grad():
            fake_imgs = self.generator(None,
                                       num_batches=batch_size,
                                       curr_scale=self.curr_scale[0],
                                       transition_weight=transition_weight)

        # disc pred for fake imgs and real_imgs
        disc_pred_fake = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        disc_pred_real = self.discriminator(
            real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)
        # get data dict to compute losses for disc
        data_dict_ = dict(
            iteration=curr_iter,
            gen=self.generator,
            disc=self.discriminator,
            disc_pred_fake=disc_pred_fake,
            disc_pred_real=disc_pred_real,
            fake_imgs=fake_imgs,
            real_imgs=real_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight,
            gen_partial=partial(self.generator,
                                curr_scale=self.curr_scale[0],
                                transition_weight=transition_weight),
            disc_partial=partial(self.discriminator,
                                 curr_scale=self.curr_scale[0],
                                 transition_weight=transition_weight))

        loss_disc, log_vars_disc = self._get_disc_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_disc))
        loss_disc.backward()
        optimizer['discriminator'].step()

        # update training log status
        if dist.is_initialized():
            _batch_size = batch_size * dist.get_world_size()
        else:
            if 'batch_size' not in running_status:
                raise RuntimeError(
                    'You should offer "batch_size" in running status for PGGAN'
                )
            _batch_size = running_status['batch_size']
        self.shown_nkimg += (_batch_size / 1000.)
        log_vars_disc.update(
            dict(shown_nkimg=self.shown_nkimg.item(),
                 curr_scale=self.curr_scale[0],
                 transition_weight=transition_weight))

        # skip generator training if only train discriminator for current
        # iteration
        if (curr_iter + 1) % self.disc_steps != 0:
            results = dict(fake_imgs=fake_imgs.cpu(),
                           real_imgs=real_imgs.cpu())
            outputs = dict(log_vars=log_vars_disc,
                           num_samples=batch_size,
                           results=results)
            if hasattr(self, 'iteration'):
                self.iteration += 1
            return outputs

        # generator training
        set_requires_grad(self.discriminator, False)
        optimizer['generator'].zero_grad()

        # TODO: add noise sampler to customize noise sampling
        fake_imgs = self.generator(None,
                                   num_batches=batch_size,
                                   curr_scale=self.curr_scale[0],
                                   transition_weight=transition_weight)
        disc_pred_fake_g = self.discriminator(
            fake_imgs,
            curr_scale=self.curr_scale[0],
            transition_weight=transition_weight)

        data_dict_ = dict(iteration=curr_iter,
                          gen=self.generator,
                          disc=self.discriminator,
                          fake_imgs=fake_imgs,
                          disc_pred_fake_g=disc_pred_fake_g)

        loss_gen, log_vars_g = self._get_gen_loss(data_dict_)

        # prepare for backward in ddp. If you do not call this function before
        # back propagation, the ddp will not dynamically find the used params
        # in current computation.
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_gen))

        loss_gen.backward()
        optimizer['generator'].step()

        log_vars = {}
        log_vars.update(log_vars_g)
        log_vars.update(log_vars_disc)
        log_vars.update({'batch_size': batch_size})

        results = dict(fake_imgs=fake_imgs.cpu(), real_imgs=real_imgs.cpu())
        outputs = dict(log_vars=log_vars,
                       num_samples=batch_size,
                       results=results)

        if hasattr(self, 'iteration'):
            self.iteration += 1

        # check if a new scale will be added in the next iteration
        _curr_stage = int(
            min(sum(self.cum_nkimgs <= self.shown_nkimg.item()),
                len(self.scales) - 1))
        # in the next iteration, we will switch to a new scale
        if _curr_stage != self.curr_stage:
            # `self._next_scale_int` is updated at the end of `train_step`
            self._next_scale_int = self._next_scale_int * 2
        return outputs
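
_find_tensors, called before every prepare_for_backward() in these examples, walks an arbitrary output structure and collects the tensors it contains. A rough equivalent for reference (the real helper lives in torch.nn.parallel.distributed and may differ in detail):

import torch

def find_tensors_sketch(obj):
    """Rough equivalent of _find_tensors: recursively gather all tensors
    from a nested output (tensor, list/tuple, or dict)."""
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in find_tensors_sketch(item)]
    if isinstance(obj, dict):
        return [t for value in obj.values() for t in find_tensors_sketch(value)]
    return []
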
Example #10
    def train_step(self,
                   data_batch,
                   optimizer,
                   ddp_reducer=None,
                   running_status=None):
        """Training step function.

        Args:
            data_batch (dict): Dict of the input data batch.
            optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for
                the generators and discriminators.
            ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp.
                It is used to prepare for ``backward()`` in ddp. Defaults to
                None.
            running_status (dict | None, optional): Contains necessary basic
                information for training, e.g., iteration number. Defaults to
                None.

        Returns:
            dict: Dict of loss, information for logger, the number of samples\
                and results for visualization.
        """
        # get running status
        if running_status is not None:
            curr_iter = running_status['iteration']
        else:
            # dirty workaround used when running_status is not provided
            if not hasattr(self, 'iteration'):
                self.iteration = 0
            curr_iter = self.iteration

        # forward generators
        outputs = dict()
        for target_domain in self._reachable_domains:
            # fetch data by domain
            source_domain = self.get_other_domains(target_domain)[0]
            img = data_batch[f'img_{source_domain}']
            # translation process
            results = self(img, test_mode=False, target_domain=target_domain)
            outputs[f'real_{source_domain}'] = results['source']
            outputs[f'fake_{target_domain}'] = results['target']
            # cycle process
            results = self(results['target'],
                           test_mode=False,
                           target_domain=source_domain)
            outputs[f'cycle_{source_domain}'] = results['target']

        log_vars = dict()

        # discriminators
        set_requires_grad(self.discriminators, True)
        # optimize
        optimizer['discriminators'].zero_grad()
        loss_d, log_vars_d = self._get_disc_loss(outputs)
        log_vars.update(log_vars_d)
        if ddp_reducer is not None:
            ddp_reducer.prepare_for_backward(_find_tensors(loss_d))
        loss_d.backward()
        optimizer['discriminators'].step()

        # generators, no updates to discriminator parameters.
        if (curr_iter % self.disc_steps == 0
                and curr_iter >= self.disc_init_steps):
            set_requires_grad(self.discriminators, False)
            # optimize
            optimizer['generators'].zero_grad()
            loss_g, log_vars_g = self._get_gen_loss(outputs)
            log_vars.update(log_vars_g)
            if ddp_reducer is not None:
                ddp_reducer.prepare_for_backward(_find_tensors(loss_g))
            loss_g.backward()
            optimizer['generators'].step()

        if hasattr(self, 'iteration'):
            self.iteration += 1

        image_results = dict()
        for domain in self._reachable_domains:
            image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu()
            image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu()
        results = dict(log_vars=log_vars,
                       num_samples=len(outputs[f'real_{domain}']),
                       results=image_results)

        return results