Example #1
class SceneTrainer(Trainer):
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm2d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = functools.partial(
            SceneStructureBlock,
            scene_size=self.args.scene_size,
            patch_size=self.args.patch_size,
            num_patches=self.args.num_patches,
            refine_patches=self.args.refine_patches,
            patch_noise=self.args.patch_noise,
        )
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
            'elu': nn.ELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory)
        g_block_factory = functools.partial(
            ResidualGeneratorBlock,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        g_output_factory = functools.partial(
            GeneratorOutput,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_output_factory = functools.partial(
            DiscriminatorOutput,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        self.g = StructuredSceneGenerator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = StructuredSceneGenerator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(),
                                            lr=self.args.lr_g,
                                            betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(),
                                            lr=self.args.lr_d,
                                            betas=(0., 0.999))
        self.d_loss = discriminator_hinge_loss
        self.g_loss = generator_hinge_loss
        self.bce_loss = nn.BCEWithLogitsLoss()
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights

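    # LeCun-style init for SELU nets: weights ~ N(0, 1/fan_in) keeps the
    # activations self-normalizing; 1-d parameters (biases) are zeroed.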
    def init_params_selu(self, params):
        for p in params:
            d = p.data
            if len(d.shape) == 1:
                d.zero_()
            else:
                in_dims, _ = nn.init._calculate_fan_in_and_fan_out(d)
                d.normal_(std=np.sqrt(1. / in_dims))

    def train_batch(self, imgs):
        imgs = imgs.to(self.device)
        self.g.train()
        self.d.train()
        # train discriminator
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_imgs, labels = self.make_adversarial_batch(imgs)
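        # make_adversarial_batch concatenates [real | generated]: the first
        # batch_size rows are real (label 1), the rest are fake (label 0)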
        real, fake = (batch_imgs[:self.args.batch_size],
                      batch_imgs[self.args.batch_size:])
        if self.args.grad_penalty:
            real.requires_grad_()
        p_labels_real = self.d(real)
        p_labels_fake = self.d(fake.detach())
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()

        # train generator
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_imgs, labels = self.make_generator_batch(imgs)
        p_labels = self.d(batch_imgs)
        g_loss = self.bce_loss(p_labels, labels)
        g_loss.backward()
        self.optimizer_g.step()

        self.update_target_generator()

        return dict(g_loss=float(g_loss),
                    d_loss=float(d_loss),
                    gp=float(d_grad_penalty))

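    # Polyak / exponential-moving-average update of the target generator;
    # calling update_target_generator(1.) copies the weights outright.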
    @torch.no_grad()
    def update_target_generator(self, lr=None):
        if lr is None:
            lr = self.args.lr_target_g
        for g_p, target_g_p in zip(self.g.parameters(),
                                   self.target_g.parameters()):
            target_g_p.add_((g_p - target_g_p) * lr)

    @classmethod
    def add_args_to_parser(cls, p):
        super().add_args_to_parser(p)
        p.add_argument('--scene-size', type=int, default=16)
        p.add_argument('--patch-size', type=int, default=3)
        p.add_argument('--num-patches', type=int, default=20)
        p.add_argument('--refine-patches', action='store_true')
        p.add_argument('--patch-noise', action='store_true')
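
Note: the trainers in these examples call `toggle_grad` and `gradient_penalty`
helpers that are not shown. A minimal sketch of what they could look like,
matching the call sites (the signatures are an assumption, not this
repository's code):

import torch

def toggle_grad(model, requires_grad):
    # enable/disable gradients for every parameter of a module
    for p in model.parameters():
        p.requires_grad_(requires_grad)

def gradient_penalty(d_out, x_in):
    # assumed R1-style penalty: squared norm of grad D(x) w.r.t. real inputs
    grad, = torch.autograd.grad(outputs=d_out.sum(),
                                inputs=x_in,
                                create_graph=True)
    return grad.pow(2).reshape(len(x_in), -1).sum(1).mean()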
Example #2
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm2d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = functools.partial(
            SceneStructureBlock,
            scene_size=self.args.scene_size,
            patch_size=self.args.patch_size,
            num_patches=self.args.num_patches,
            refine_patches=self.args.refine_patches,
            patch_noise=self.args.patch_noise,
        )
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
            'elu': nn.ELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory)
        g_block_factory = functools.partial(
            ResidualGeneratorBlock,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        g_output_factory = functools.partial(
            GeneratorOutput,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_output_factory = functools.partial(
            DiscriminatorOutput,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        self.g = StructuredSceneGenerator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = StructuredSceneGenerator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(),
                                            lr=self.args.lr_g,
                                            betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(),
                                            lr=self.args.lr_d,
                                            betas=(0., 0.999))
        self.d_loss = discriminator_hinge_loss
        self.g_loss = generator_hinge_loss
        self.bce_loss = nn.BCEWithLogitsLoss()
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights
Example #3
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        self.gan_config = self.gan_config._replace(
            data_dims=self.args.embedding_dims)
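        # the (namedtuple-style) config is rebuilt so data_dims matches the
        # word-embedding dimensionality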
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm1d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = {
            'mlp': GeneratorInputMLP1d,
        }[self.args.g_base]
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
            'elu': nn.ELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory)
        d_input_factory = functools.partial(
            DiscriminatorInput,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
        )
        g_block_factory = functools.partial(
            ResidualGeneratorBlock,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
        )
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
            avg_pool_factory=nn.AvgPool1d,
            interpolate=functools.partial(F.interpolate,
                                          scale_factor=0.5,
                                          mode='linear',
                                          align_corners=False))
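        # 1d variants throughout: the discriminator downsamples sequences with
        # AvgPool1d and linear interpolation (scale 0.5) instead of the 2d defaults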
        g_output_factory = functools.partial(
            GeneratorOutput,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
            output_activation_factory=nn.Identity,
        )
        d_output_factory = functools.partial(
            DiscriminatorOutput,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        self.g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            input_factory=d_input_factory,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(),
                                            lr=self.args.lr_g,
                                            betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(),
                                            lr=self.args.lr_d,
                                            betas=(0., 0.999))
        self.d_loss = discriminator_hinge_loss
        self.g_loss = generator_hinge_loss
        self.bce_loss = nn.BCEWithLogitsLoss()
        print(self.g)
        print(self.d)
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights
        self.pretraining_embedding = self.args.pretrain_embedding
Example #4
class TextCNNTrainer(Trainer):
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        self.gan_config = self.gan_config._replace(
            data_dims=self.args.embedding_dims)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm1d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = {
            'mlp': GeneratorInputMLP1d,
        }[self.args.g_base]
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
            'elu': nn.ELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory)
        d_input_factory = functools.partial(
            DiscriminatorInput,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
        )
        g_block_factory = functools.partial(
            ResidualGeneratorBlock,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
        )
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
            avg_pool_factory=nn.AvgPool1d,
            interpolate=functools.partial(F.interpolate,
                                          scale_factor=0.5,
                                          mode='linear',
                                          align_corners=False))
        g_output_factory = functools.partial(
            GeneratorOutput,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory,
            conv_factory=nn.Conv1d,
            output_activation_factory=nn.Identity,
        )
        d_output_factory = functools.partial(
            DiscriminatorOutput,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        self.g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            input_factory=d_input_factory,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(),
                                            lr=self.args.lr_g,
                                            betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(),
                                            lr=self.args.lr_d,
                                            betas=(0., 0.999))
        self.d_loss = discriminator_hinge_loss
        self.g_loss = generator_hinge_loss
        self.bce_loss = nn.BCEWithLogitsLoss()
        print(self.g)
        print(self.d)
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights
        self.pretraining_embedding = self.args.pretrain_embedding

    def init_params_selu(self, params):
        for p in params:
            d = p.data
            if len(d.shape) == 1:
                d.zero_()
            else:
                in_dims, _ = nn.init._calculate_fan_in_and_fan_out(d)
                d.normal_(std=np.sqrt(1. / in_dims))

    def setup_components(self):
        self.components.add_components(
            ModelCheckpointComponent(),
            TextSamplerComponent(),
        )

        if self.args.metrics_collector:
            metrics_collector_class = {
                'katib': KatibMetricsComponent,
                'kubeflow': KubeflowMetricsComponent,
                'tensorboard': TensorboardComponent,
            }[self.args.metrics_collector]
            metrics_collector = metrics_collector_class(self.args.metrics_path)
            self.components.add_components(metrics_collector)

    def prepare_dataset(self):
        max_doc_size = self.g.max_size
        self.dataset = TextDataset.from_path(self.args.data_path,
                                             doc_len=max_doc_size)
        self.embedding = SkipGram(len(self.dataset.vocab),
                                  self.args.embedding_dims,
                                  sparse=True,
                                  padding_idx=self.dataset.vocab.stoi['<PAD>'])
        self.embedding = self.embedding.to(self.device)
        self.optimizer_embedding = torch.optim.SGD(
            self.embedding.parameters(),
            lr=self.args.lr_d,
        )
        return self.dataset

    def train_batch(self, input_indexes):
        input_indexes = input_indexes.long().to(self.device)

        # train embedding
        self.embedding.train()
        toggle_grad(self.embedding, True)
        self.optimizer_embedding.zero_grad()
        # extract random windows from the input docs
        window_size = self.args.context * 2 + 1
        offsets = np.random.randint(0, window_size, len(input_indexes))
        windows = torch.stack([
            input_indexes[i, ..., offset:offset + window_size]
            for i, offset in enumerate(offsets)
        ])
        # get a list of pivot words and their contexts from the windows
        words = windows[..., self.args.context]
        contexts = torch.cat([
            windows[..., :self.args.context],
            windows[..., self.args.context + 1:],
        ], dim=-1)
        # get the loss
        embedding_loss = self.embedding.loss(words, contexts)
        embedding_loss.backward()
        self.optimizer_embedding.step()

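        # pretraining countdown: embedding-only updates for the first
        # --pretrain-embedding batches, then adversarial training starts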
        self.pretraining_embedding = max(self.pretraining_embedding - 1, 0)
        if not self.pretraining_embedding:
            inputs = self.embedding(input_indexes).permute((0, 2, 1)).detach()
            self.g.train()
            self.d.train()
            # train discriminator
            toggle_grad(self.embedding, False)
            toggle_grad(self.g, False)
            toggle_grad(self.d, True)
            self.optimizer_d.zero_grad()
            batch_inputs, labels = self.make_adversarial_batch(inputs)
            real, fake = (batch_inputs[:self.args.batch_size],
                          batch_inputs[self.args.batch_size:])
            if self.args.grad_penalty:
                real.requires_grad_()
            p_labels_real = self.d(real)
            p_labels_fake = self.d(fake.detach())
            p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
            d_loss = self.bce_loss(p_labels, labels)
            d_grad_penalty = 0.
            if self.args.grad_penalty:
                d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                    p_labels_real, real)
                d_loss += d_grad_penalty
            d_loss.backward()
            self.optimizer_d.step()

            # train generator
            toggle_grad(self.g, True)
            toggle_grad(self.d, False)
            self.optimizer_g.zero_grad()
            batch_inputs, labels = self.make_generator_batch(inputs)
            p_labels = self.d(batch_inputs)
            g_loss = self.bce_loss(p_labels, labels)
            g_loss.backward()
            self.optimizer_g.step()
        else:
            g_loss = d_loss = d_grad_penalty = 0.

        self.update_target_generator()

        return dict(g_loss=float(g_loss),
                    d_loss=float(d_loss),
                    gp=float(d_grad_penalty),
                    embedding_loss=float(embedding_loss))

    def make_adversarial_batch(self, real_data, **g_kwargs):
        generated_data = self.sample_g(len(real_data), **g_kwargs)
        batch = torch.cat([real_data, generated_data], dim=0)
        labels = torch.zeros(len(batch), 1).to(self.device)
        labels[:len(labels) // 2] = 1  # first half is real
        return batch, labels

    @torch.no_grad()
    def update_target_generator(self, lr=None):
        if lr is None:
            lr = self.args.lr_target_g
        for g_p, target_g_p in zip(self.g.parameters(),
                                   self.target_g.parameters()):
            target_g_p.add_((g_p - target_g_p) * lr)

    @classmethod
    def add_args_to_parser(cls, p):
        super().add_args_to_parser(p)
        p.add_argument('--embedding-dims', type=int, default=64)
        p.add_argument('--context', type=int, default=3)
        p.add_argument('--pretrain-embedding', type=int, default=10000)
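
To make the skip-gram windowing in train_batch concrete, here is a small
standalone sketch with assumed shapes (the SkipGram model itself is not
reproduced):

import numpy as np
import torch

context = 3                               # matches --context
window_size = context * 2 + 1             # pivot plus context on each side
docs = torch.randint(0, 1000, (4, 32))    # (batch, doc_len) token indexes

offsets = np.random.randint(0, window_size, len(docs))
windows = torch.stack([docs[i, offset:offset + window_size]
                       for i, offset in enumerate(offsets)])
words = windows[..., context]             # (batch,) pivot words
contexts = torch.cat([windows[..., :context],
                      windows[..., context + 1:]], dim=-1)  # (batch, 2*context)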
Example #5
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm2d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = {
            'mlp': GeneratorInputMLP,
            'tiledz': TiledZGeneratorInput,
        }[self.args.g_base]
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory
        )
        g_block_factory = functools.partial(
            ResidualGeneratorBlock, norm_factory=g_norm_factory,
            activation_factory=activation_factory
        )
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock, norm_factory=d_norm_factory,
            activation_factory=activation_factory
        )
        g_output_factory = functools.partial(
            GeneratorOutput, norm_factory=g_norm_factory,
            activation_factory=activation_factory
        )
        d_output_factory = functools.partial(
            MultiModelDiscriminatorOutput,
            output_model_factories=[
                functools.partial(LinearOutput, out_dims=1),
                functools.partial(
                    LinearOutput,
                    out_dims=self.args.info_cat_dims + self.args.info_cont_dims,
                )
            ],
            norm_factory=d_norm_factory,
            activation_factory=activation_factory
        )
        self.g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(), lr=self.args.lr_g, betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(), lr=self.args.lr_d, betas=(0., 0.999))
        print(self.g)
        print(self.d)
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()
Example #6
class InfoTrainer(Trainer):
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm2d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = {
            'mlp': GeneratorInputMLP,
            'tiledz': TiledZGeneratorInput,
        }[self.args.g_base]
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory
        )
        g_block_factory = functools.partial(
            ResidualGeneratorBlock, norm_factory=g_norm_factory,
            activation_factory=activation_factory
        )
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock, norm_factory=d_norm_factory,
            activation_factory=activation_factory
        )
        g_output_factory = functools.partial(
            GeneratorOutput, norm_factory=g_norm_factory,
            activation_factory=activation_factory
        )
        d_output_factory = functools.partial(
            MultiModelDiscriminatorOutput,
            output_model_factories=[
                functools.partial(LinearOutput, out_dims=1),
                functools.partial(
                    LinearOutput,
                    out_dims=self.args.info_cat_dims + self.args.info_cont_dims,
                )
            ],
            norm_factory=d_norm_factory,
            activation_factory=activation_factory
        )
        self.g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(), lr=self.args.lr_g, betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(), lr=self.args.lr_d, betas=(0., 0.999))
        print(self.g)
        print(self.d)
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

    @classmethod
    def get_component_classes(cls, args):
        classes = super().get_component_classes(args)
        classes.append(InfoImageSamplerComponent)
        return classes

    def init_params_selu(self, params):
        for p in params:
            d = p.data
            if len(d.shape) == 1:
                d.zero_()
            else:
                in_dims, _ = nn.init._calculate_fan_in_and_fan_out(d)
                d.normal_(std=np.sqrt(1. / in_dims))

    def train_batch(self, imgs):
        imgs = imgs.to(self.device)
        self.g.train()
        self.d.train()
        # train discriminator
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_imgs, labels, z = self.make_adversarial_batch(imgs)
        real, fake = batch_imgs[:self.args.batch_size], batch_imgs[self.args.batch_size:]
        if self.args.grad_penalty:
            real.requires_grad_()
        p_labels_real, _ = self.d(real)
        p_labels_fake, p_codes = self.d(fake.detach())
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        # infogan loss
        d_code_loss = 0.
        if self.args.info_cat_dims:
            z_cat_code = self.z_categorical_code(z)
            p_z_cat_code = self.z_categorical_code(p_codes)
            d_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
            d_code_loss += d_cat_code_loss
        if self.args.info_cont_dims:
            z_cont_code = self.z_continuous_code(z)
            p_z_cont_code = self.z_continuous_code(p_codes)
            d_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
            d_code_loss += d_cont_code_loss
        d_loss += self.args.info_w * d_code_loss

        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()

        # train generator
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_imgs, labels, z = self.make_generator_batch(imgs)
        p_labels, p_codes = self.d(batch_imgs)
        g_loss = self.bce_loss(p_labels, labels)
        # infogan loss
        g_code_loss = 0.
        if self.args.info_cat_dims:
            z_cat_code = self.z_categorical_code(z)
            p_z_cat_code = self.z_categorical_code(p_codes)
            g_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
            g_code_loss += g_cat_code_loss
        if self.args.info_cont_dims:
            z_cont_code = self.z_continuous_code(z)
            p_z_cont_code = self.z_continuous_code(p_codes)
            g_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
            g_code_loss += g_cont_code_loss

        g_loss += self.args.info_w * g_code_loss

        g_loss.backward()
        self.optimizer_g.step()

        self.update_target_generator()

        return dict(
            g_loss=float(g_loss), g_code_loss=float(g_code_loss),
            d_loss=float(d_loss), d_code_loss=float(d_code_loss),
            gp=float(d_grad_penalty)
        )

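    # Gaussian log-likelihood of x under N(mu, exp(log_sigma)^2), summed over
    # the code dimensions; note that train_batch above uses plain MSE for the
    # continuous codes rather than this likelihood.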
    def log_likelihood_gaussian(self, x, mu, log_sigma, eps=1e-8):
        z = (x - mu) / (torch.exp(log_sigma) + eps)
        result = -0.5 * np.log(2. * np.pi) - log_sigma - 0.5 * z ** 2
        return result.sum(1)

    def z_categorical_code(self, z):
        return z[..., :self.args.info_cat_dims]

    def z_continuous_code(self, z):
        return z[..., self.args.info_cat_dims:self.args.info_cat_dims + self.args.info_cont_dims]

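    # latent layout: [one-hot categorical code | continuous code | noise];
    # sample_z zeroes the categorical slots, then sets one per sample to 1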
    def sample_z(self, n=None):
        if n is None:
            n = self.args.batch_size
        z = torch.randn(n, self.gan_config.latent_dims).to(self.device)
        # set up the categorical dimensions
        if self.args.info_cat_dims:
            z[..., :self.args.info_cat_dims] = 0.
            cats = np.random.randint(0, self.args.info_cat_dims, (n,))
            z[np.arange(n), ..., cats] = 1.
        return z

    def sample_g(self, n=None, target_g=False, **g_kwargs):
        z = self.sample_z(n)
        if target_g:
            imgs = self.target_g(z, **g_kwargs)
        else:
            imgs = self.g(z, **g_kwargs)
        return imgs, z

    def make_adversarial_batch(self, real_data, **g_kwargs):
        generated_data, z = self.sample_g(len(real_data), **g_kwargs)
        batch = torch.cat([real_data, generated_data], dim=0)
        labels = torch.zeros(len(batch), 1).to(self.device)
        labels[:len(labels) // 2] = 1  # first half is real
        return batch, labels, z

    def make_generator_batch(self, real_data, **g_kwargs):
        generated_data, z = self.sample_g(len(real_data), **g_kwargs)
        labels = torch.ones(len(generated_data), 1).to(self.device)
        return generated_data, labels, z

    @torch.no_grad()
    def update_target_generator(self, lr=None):
        if lr is None:
            lr = self.args.lr_target_g
        for g_p, target_g_p in zip(self.g.parameters(), self.target_g.parameters()):
            target_g_p.add_((g_p - target_g_p) * lr)

    @classmethod
    def add_args_to_parser(cls, p):
        super().add_args_to_parser(p)
        p.add_argument('--info-cat-dims', type=int, default=10)
        p.add_argument('--info-cont-dims', type=int, default=5)
        p.add_argument('--info-w', type=float, default=1.)
Example #7
class CNNTrainer(Trainer):
    def build_models(self):
        self.gan_config = GAN_CONFIGS[self.args.config]
        self.gan_config = self.gan_config.scale_model(self.args.model_scale)
        g_norm_factory = {
            'id': nn.Identity,
            'bn': nn.BatchNorm2d,
        }[self.args.norm]
        d_norm_factory = g_norm_factory
        g_input_factory = {
            'mlp': GeneratorInputMLP,
            'tiledz': TiledZGeneratorInput,
        }[self.args.g_base]
        activation_factory = {
            'relu': functools.partial(nn.LeakyReLU, 0.2),
            'selu': nn.SELU,
            'elu': nn.ELU,
        }[self.args.activation]

        g_input_factory = functools.partial(
            g_input_factory, activation_factory=activation_factory)
        g_block_factory = functools.partial(
            ResidualGeneratorBlock,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_block_factory = functools.partial(
            ResidualDiscriminatorBlock,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        g_output_factory = functools.partial(
            GeneratorOutput,
            norm_factory=g_norm_factory,
            activation_factory=activation_factory)
        d_output_factory = functools.partial(
            DiscriminatorOutput,
            norm_factory=d_norm_factory,
            activation_factory=activation_factory)
        self.g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)
        self.target_g = Generator(
            self.gan_config,
            input_factory=g_input_factory,
            block_factory=g_block_factory,
            output_factory=g_output_factory,
        ).to(self.device)

        self.d = Discriminator(
            self.gan_config,
            block_factory=d_block_factory,
            output_factory=d_output_factory,
        ).to(self.device)
        self.optimizer_g = torch.optim.Adam(self.g.parameters(),
                                            lr=self.args.lr_g,
                                            betas=(0., 0.999))
        self.optimizer_d = torch.optim.Adam(self.d.parameters(),
                                            lr=self.args.lr_d,
                                            betas=(0., 0.999))
        self.d_loss = discriminator_hinge_loss
        self.g_loss = generator_hinge_loss
        self.bce_loss = nn.BCEWithLogitsLoss()
        print(self.g)
        print(self.d)
        if self.args.activation == 'selu':
            self.init_params_selu(self.g.parameters())
            self.init_params_selu(self.d.parameters())
        self.update_target_generator(1.)  # copy weights

    def init_params_selu(self, params):
        for p in params:
            d = p.data
            if len(d.shape) == 1:
                d.zero_()
            else:
                in_dims, _ = nn.init._calculate_fan_in_and_fan_out(d)
                d.normal_(std=np.sqrt(1. / in_dims))

    def train_batch(self, imgs):
        imgs = imgs.to(self.device)
        self.g.train()
        self.d.train()
        # train discriminator
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_imgs, labels = self.make_adversarial_batch(imgs)
        real, fake = (batch_imgs[:self.args.batch_size],
                      batch_imgs[self.args.batch_size:])
        if self.args.grad_penalty:
            real.requires_grad_()
        p_labels_real = self.d(real)
        p_labels_fake = self.d(fake.detach())
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()

        # train generator
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_imgs, labels = self.make_generator_batch(imgs)
        p_labels = self.d(batch_imgs)
        g_loss = self.bce_loss(p_labels, labels)
        g_loss.backward()
        self.optimizer_g.step()

        self.update_target_generator()

        return dict(g_loss=float(g_loss),
                    d_loss=float(d_loss),
                    gp=float(d_grad_penalty))

    @torch.no_grad()
    def update_target_generator(self, lr=None):
        if lr is None:
            lr = self.args.lr_target_g
        for g_p, target_g_p in zip(self.g.parameters(),
                                   self.target_g.parameters()):
            target_g_p.add_((g_p - target_g_p) * lr)
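
A final note: discriminator_hinge_loss and generator_hinge_loss are assigned
in build_models but unused in these train_batch implementations, which use
BCE-with-logits instead. For reference, one standard formulation matching the
names (an assumption, not this repository's code):

import torch.nn.functional as F

def discriminator_hinge_loss(p_real, p_fake):
    # push real scores above +1 and fake scores below -1
    return F.relu(1. - p_real).mean(), F.relu(1. + p_fake).mean()

def generator_hinge_loss(p_fake):
    # the generator maximizes the discriminator's score on fakes
    return -p_fake.mean()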