Example #1
    def model_create(self):
        from .. import models_64x64
        config = self.myargs.config.trainer.model

        if config.d_use_lp:
            self.logger.info("Use LayerNorm in D.")
            D = models_64x64.DiscriminatorWGANGPLN(3)
        else:
            self.logger.info("Use InstanceNorm in D.")
            D = models_64x64.DiscriminatorWGANGP(3)
        G = models_64x64.Generator(config.z_dim)
        G_ema = models_64x64.Generator(config.z_dim)
        # Keep an exponential moving average of G's weights in G_ema
        ema = ema_model.EMA(G, G_ema, decay=0.9999, start_itr=config.ema_start)
        D.cuda()
        G.cuda()
        G_ema.cuda()

        # Register the networks so they are saved and restored with checkpoints
        self.myargs.checkpoint_dict['D'] = D
        self.myargs.checkpoint_dict['G'] = G
        self.myargs.checkpoint_dict['G_ema'] = G_ema

        self.logger.info_msg('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))

        return G, D, G_ema, ema
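
An implementation note: this and the following examples all construct an ema_model.EMA helper that is not shown in the snippets. Below is a minimal sketch of such a wrapper, assuming it mirrors G's weights into G_ema and blends them with the given decay once start_itr is reached; the class name, constructor signature, and decay/start_itr arguments come from the calls in the examples, while the update() method and its behaviour are assumptions.

import torch

class EMA(object):
    """Sketch of an exponential-moving-average wrapper (assumed behaviour)."""

    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source = source        # the live generator G
        self.target = target        # the averaged copy G_ema
        self.decay = decay
        self.start_itr = start_itr
        # Start the target as an exact copy of the source weights.
        with torch.no_grad():
            for p_t, p_s in zip(target.parameters(), source.parameters()):
                p_t.copy_(p_s)

    def update(self, itr=None):
        # Before start_itr, keep the target identical to the source;
        # afterwards, blend the weights with the configured decay.
        decay = self.decay if (itr is None or itr >= self.start_itr) else 0.0
        with torch.no_grad():
            for p_t, p_s in zip(self.target.parameters(),
                                self.source.parameters()):
                p_t.copy_(decay * p_t + (1.0 - decay) * p_s)
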
Example #2
    def model_create(self):
        from .. import models_64x64
        config = self.myargs.config.model

        D = models_64x64.DiscriminatorWGANGP(3)
        print(D)
        G = models_64x64.Generator(config.z_dim)
        print(G)
        G_ema = models_64x64.Generator(config.z_dim)
        ema = ema_model.EMA(G, G_ema, decay=0.9999, start_itr=config.ema_start)

        # Place each network on the GPU assigned to this process (one process per rank)
        D = D.cuda(self.args.rank)
        G = G.cuda(self.args.rank)
        G_ema = G_ema.cuda(self.args.rank)

        self.myargs.checkpoint_dict['D'] = D
        self.myargs.checkpoint_dict['G'] = G
        self.myargs.checkpoint_dict['G_ema'] = G_ema

        print('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))

        return G, D, G_ema, ema
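
Example #2 pins each network to a specific GPU with .cuda(self.args.rank), the usual pattern when one training process is launched per GPU. A minimal sketch of that device-pinning idiom follows, assuming a per-process rank is available; the helper name and the stand-in module are hypothetical.

import torch
import torch.nn as nn

def place_on_rank(model: nn.Module, rank: int) -> nn.Module:
    # Make `rank` the default CUDA device for this process, then move the
    # module's parameters and buffers onto that GPU.
    torch.cuda.set_device(rank)
    return model.cuda(rank)

# Usage with a stand-in module:
# net = place_on_rank(nn.Linear(10, 10), rank=0)
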
Example #3
    def model_create(self):
        from .. import models_64x64
        config = self.myargs.config.model

        D = models_64x64.DiscriminatorWGANGP(3)
        print(D)
        G = models_64x64.Generator(config.z_dim)
        print(G)
        G_ema = models_64x64.Generator(config.z_dim)
        ema = ema_model.EMA(G, G_ema, decay=0.9999, start_itr=config.ema_start)

        D.cuda()
        G.cuda()
        G_ema.cuda()
        # Optionally wrap D and G for multi-GPU data parallelism
        if config.parallel:
            D = nn.DataParallel(D)
            G = nn.DataParallel(G)

        self.myargs.checkpoint_dict['D'] = D
        self.myargs.checkpoint_dict['G'] = G
        self.myargs.checkpoint_dict['G_ema'] = G_ema

        self.print_number_params(models=dict(G=G, D=D))

        return G, D, G_ema, ema
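
Example #3 (and Examples #4 and #5 below) delegates the parameter report to a self.print_number_params(models=...) helper whose body is not shown. A plausible sketch, using the same nelement() sum that Examples #1 and #2 compute inline; the function name and signature come from the calls, the body is an assumption.

def print_number_params(models):
    # Report the parameter count of each registered network, mirroring the
    # inline sum used in Examples #1 and #2.
    for name, net in models.items():
        n_params = sum(p.data.nelement() for p in net.parameters())
        print('Number of params in {}: {}'.format(name, n_params))
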
Example #4
    def model_create(self):
        config = self.config.model
        resolution = imsize_dict[self.config.dataset.dataset]
        n_classes = nclass_dict[self.config.dataset.dataset]
        self.n_classes = n_classes
        if hasattr(self.config.dataset, 'attr'):
            assert n_classes == 2**len(self.config.dataset.attr)
        if getattr(config, 'use_cbn', False):
            self.G = SharedGeneratorCBN(resolution=resolution,
                                        n_classes=n_classes,
                                        no_optim=False,
                                        config=config.generator).cuda()
        else:
            self.G = SharedGeneratorNoSkip(resolution=resolution,
                                           no_optim=False,
                                           config=config.generator).cuda()

        self.G_optim = self.G.optim
        self.myargs.checkpoint_dict['G'] = self.G
        self.myargs.checkpoint_dict['G_optim'] = self.G.optim
        if getattr(config, 'use_ema', False):
            self.G_ema = copy.deepcopy(self.G)
            self.myargs.checkpoint_dict['G_ema'] = self.G_ema
            self.ema = ema_model.EMA(self.G,
                                     self.G_ema,
                                     decay=config.ema_decay,
                                     start_itr=config.ema_start)

        # Create controller
        controller_c = config.controller
        self.controller = Controller(
            n_classes=n_classes,
            num_layers=(self.G.num_layers + 1
                        if self.G.output_sample_arc else self.G.num_layers),
            num_branches=len(self.G.ops),
            config=controller_c).cuda()
        self.controller_optim = self.controller.optim
        self.myargs.checkpoint_dict['controller'] = self.controller
        self.myargs.checkpoint_dict['C_optim'] = self.controller.optim

        self.G_C = G_Controller(self.G, self.controller)
        # self.G_C = torch.nn.DataParallel(self.G_C)

        if getattr(config, 'use_cdisc', False):
            disc_c = config.discriminator_cond
            D_activation = activation_dict[disc_c.D_activation]
            self.D = BigGAN.Discriminator(
                **{
                    **disc_c, 'resolution': resolution,
                    'n_classes': n_classes,
                    'D_activation': D_activation
                }, **disc_c.optimizer).cuda()
            self.D_optim = self.D.optim
        else:
            from ..models.autogan_cifar10_a import Discriminator
            from ..models import optimizer_dict
            disc_c = config.discriminator
            disc_optim_c = disc_c.optimizer
            self.D = Discriminator(args=disc_c).cuda()
            adam_eps = getattr(disc_optim_c, 'adam_eps', 1.e-8)
            self.D_optim = self.optim = optimizer_dict[disc_optim_c.type](
                params=self.D.parameters(),
                lr=disc_optim_c.D_lr,
                betas=(disc_optim_c.D_B1, disc_optim_c.D_B2),
                eps=adam_eps)

        self.myargs.checkpoint_dict['D'] = self.D
        self.myargs.checkpoint_dict['D_optim'] = self.D_optim
        # Wrap D for multi-GPU use after registering the raw module for checkpointing
        self.D = torch.nn.DataParallel(self.D)

        self.models = {'controller': self.controller, 'G': self.G, 'D': self.D}
        self.print_number_params(models=self.models)
Example #5
    def model_create(self):
        import BigGAN as model
        import utils
        config = self.config.model

        Generator = model.Generator
        Discriminator = model.Discriminator
        G_D = model.G_D

        print('Create generator: {}'.format(Generator))
        self.resolution = utils.imsize_dict[self.config.dataset.dataset]
        self.n_classes = utils.nclass_dict[self.config.dataset.dataset]
        G_activation = utils.activation_dict[config.Generator.G_activation]
        self.G = Generator(
            **{
                **config.Generator, 'resolution': self.resolution,
                'n_classes': self.n_classes,
                'G_activation': G_activation
            }, **config.optimizer).cuda()
        optim_type = getattr(config.optimizer, 'optim_type', None)
        if optim_type == 'radam':
            print('Using radam optimizer.')
            from template_lib.optimizer import radam
            self.G.optim = radam.RAdam(params=self.G.parameters(),
                                       lr=config.optimizer.G_lr,
                                       betas=(config.optimizer.G_B1,
                                              config.optimizer.G_B2),
                                       weight_decay=0,
                                       eps=config.optimizer.adam_eps)

        print('Create discriminator: {}'.format(Discriminator))
        D_activation = utils.activation_dict[config.Discriminator.D_activation]
        self.D = Discriminator(logger=self.logger,
                               **{
                                   **config.Discriminator,
                                   'resolution': self.resolution,
                                   'n_classes': self.n_classes,
                                   'D_activation': D_activation
                               },
                               **config.optimizer).cuda()
        if optim_type == 'radam':
            print('Using radam optimizer.')
            from template_lib.optimizer import radam
            self.D.optim = radam.RAdam(params=self.D.parameters(),
                                       lr=config.optimizer.D_lr,
                                       betas=(config.optimizer.D_B1,
                                              config.optimizer.D_B2),
                                       weight_decay=0,
                                       eps=config.optimizer.adam_eps)

        if getattr(self.config.train_one_epoch, 'weigh_loss', False):
            self.create_alpha()

        # G_ema shares G's architecture but skips weight init and carries no optimizer
        self.G_ema = Generator(logger=self.logger,
                               **{
                                   **config.Generator,
                                   'resolution': self.resolution,
                                   'n_classes': self.n_classes,
                                   'G_activation': G_activation,
                                   'skip_init': True,
                                   'no_optim': True
                               }).cuda()
        self.ema = ema_model.EMA(self.G,
                                 self.G_ema,
                                 decay=0.9999,
                                 start_itr=config.ema_start)

        print('Create G_D: {}'.format(G_D))
        self.GD = G_D(self.G, self.D)
        if config['parallel']:
            self.GD = nn.DataParallel(self.GD)

        self.myargs.checkpoint_dict['G'] = self.G
        self.myargs.checkpoint_dict['G_optim'] = self.G.optim
        self.myargs.checkpoint_dict['D'] = self.D
        self.myargs.checkpoint_dict['D_optim'] = self.D.optim
        self.myargs.checkpoint_dict['G_ema'] = self.G_ema

        models = {'G': self.G, 'D': self.D}
        self.print_number_params(models=models)
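
Example #5 only swaps in RAdam when optim_type == 'radam'; otherwise G and D keep the optimizers built inside Generator and Discriminator. When RAdam is not available, a plain Adam with the same hyperparameters is the natural fallback. A hedged sketch follows; the helper name is hypothetical and cfg is assumed to expose the same G_lr, G_B1, G_B2, and adam_eps fields as config.optimizer above.

import torch

def build_g_optimizer(G, cfg):
    # Standard Adam configured with the same hyperparameters the RAdam
    # branch in Example #5 reads from config.optimizer.
    return torch.optim.Adam(G.parameters(),
                            lr=cfg.G_lr,
                            betas=(cfg.G_B1, cfg.G_B2),
                            weight_decay=0,
                            eps=cfg.adam_eps)
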