Example #1
    def train_batch(self, imgs):
        imgs = imgs.to(self.device)
        self.g.train()
        self.d.train()
        # train discriminator
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_imgs, labels = self.make_adversarial_batch(imgs)
        # first half of the stacked batch is real, second half generated
        real = batch_imgs[:self.args.batch_size]
        fake = batch_imgs[self.args.batch_size:]
        if self.args.grad_penalty:
            # the penalty differentiates D's output w.r.t. its input, so
            # the real batch must track gradients
            real.requires_grad_()
        p_labels_real = self.d(real)
        # detach the fakes so the discriminator step does not backprop into G
        p_labels_fake = self.d(fake.detach())
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()

        # train generator
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_imgs, labels = self.make_generator_batch(imgs)
        p_labels = self.d(batch_imgs)
        # labels here are all "real", so minimizing BCE pushes G to fool D
        g_loss = self.bce_loss(p_labels, labels)
        g_loss.backward()
        self.optimizer_g.step()

        self.update_target_generator()

        return dict(g_loss=float(g_loss),
                    d_loss=float(d_loss),
                    gp=float(d_grad_penalty))
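
The helpers these examples call are not defined here. Below is a minimal sketch of plausible definitions, assuming an R1-style gradient penalty on the real batch and standard real=1/fake=0 labels; the `z_dim` argument and the label shape are assumptions, and the original codebase may differ.

import torch


def toggle_grad(model, on):
    # flip requires_grad on every parameter so the frozen network is
    # excluded from the other network's optimizer step
    for p in model.parameters():
        p.requires_grad_(on)


def gradient_penalty(outputs, inputs):
    # R1-style penalty: mean squared norm of dD(x)/dx over the real batch;
    # create_graph=True keeps the penalty itself differentiable
    grads = torch.autograd.grad(outputs.sum(), inputs, create_graph=True)[0]
    return grads.pow(2).reshape(len(grads), -1).sum(dim=1).mean()

The batch builders would live on the same trainer class; Example #3's variants would additionally return the sampled z so the InfoGAN terms can compare codes against it:

    def make_adversarial_batch(self, imgs):
        # stack real (label 1) and generated (label 0) samples so one
        # BCE call covers both halves
        z = torch.randn(len(imgs), self.args.z_dim, device=self.device)
        batch = torch.cat([imgs, self.g(z)], dim=0)
        labels = torch.cat([torch.ones(len(imgs), 1, device=self.device),
                            torch.zeros(len(imgs), 1, device=self.device)])
        return batch, labels

    def make_generator_batch(self, imgs):
        # generated samples labelled as real, so minimizing BCE pushes G
        # towards outputs the discriminator accepts
        z = torch.randn(len(imgs), self.args.z_dim, device=self.device)
        return self.g(z), torch.ones(len(imgs), 1, device=self.device)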
Example #2
    def train_batch(self, input_indexes):
        input_indexes = input_indexes.long().to(self.device)

        # train embedding
        self.embedding.train()
        toggle_grad(self.embedding, True)
        self.optimizer_embedding.zero_grad()
        # extract random windows from the input docs
        window_size = self.args.context * 2 + 1
        offsets = np.random.randint(0, window_size, len(input_indexes))
        windows = torch.stack([
            input_indexes[i, ..., offset:offset + window_size]
            for i, offset in enumerate(offsets)
        ])
        # get a list of pivot words and their contexts from the windows
        words = windows[..., self.args.context]
        contexts = torch.cat([windows[..., :self.args.context],
                              windows[..., self.args.context + 1:]], dim=-1)
        # get the loss
        embedding_loss = self.embedding.loss(words, contexts)
        embedding_loss.backward()
        self.optimizer_embedding.step()

        # count down the embedding-only pretraining phase
        self.pretraining_embedding = max(self.pretraining_embedding - 1, 0)
        if not self.pretraining_embedding:
            # embed the docs and move channels first for the conv nets;
            # detach so the GAN losses do not update the embedding
            inputs = self.embedding(input_indexes).permute((0, 2, 1)).detach()
            self.g.train()
            self.d.train()
            # train discriminator
            toggle_grad(self.embedding, False)
            toggle_grad(self.g, False)
            toggle_grad(self.d, True)
            self.optimizer_d.zero_grad()
            batch_inputs, labels = self.make_adversarial_batch(inputs)
            real = batch_inputs[:self.args.batch_size]
            fake = batch_inputs[self.args.batch_size:]
            if self.args.grad_penalty:
                real.requires_grad_()
            p_labels_real = self.d(real)
            p_labels_fake = self.d(fake.detach())
            p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
            d_loss = self.bce_loss(p_labels, labels)
            d_grad_penalty = 0.
            if self.args.grad_penalty:
                d_grad_penalty = self.args.grad_penalty * gradient_penalty(
                    p_labels_real, real)
                d_loss += d_grad_penalty
            d_loss.backward()
            self.optimizer_d.step()

            # train generator
            toggle_grad(self.g, True)
            toggle_grad(self.d, False)
            self.optimizer_g.zero_grad()
            batch_inputs, labels = self.make_generator_batch(inputs)
            p_labels = self.d(batch_inputs)
            g_loss = self.bce_loss(p_labels, labels)
            g_loss.backward()
            self.optimizer_g.step()
        else:
            g_loss = d_loss = d_grad_penalty = 0.

        self.update_target_generator()

        return dict(g_loss=float(g_loss),
                    d_loss=float(d_loss),
                    gp=float(d_grad_penalty),
                    embedding_loss=float(embedding_loss))
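
Example #2 treats self.embedding as a word2vec-style module whose loss(words, contexts) scores each pivot word against its surrounding words. A minimal full-softmax skip-gram sketch of that interface; the class name, two-matrix layout, and objective are assumptions (the original may use negative sampling instead):

import torch.nn as nn
import torch.nn.functional as F


class SkipGramEmbedding(nn.Module):
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, dim)
        self.out_embed = nn.Embedding(vocab_size, dim)

    def forward(self, indexes):
        # used above as self.embedding(input_indexes)
        return self.in_embed(indexes)

    def loss(self, words, contexts):
        # words: (batch,) pivot indexes; contexts: (batch, 2 * context)
        pivots = self.in_embed(words)                # (batch, dim)
        logits = pivots @ self.out_embed.weight.t()  # (batch, vocab)
        # every context position shares the pivot's logits
        logits = logits.unsqueeze(1).expand(-1, contexts.size(-1), -1)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               contexts.reshape(-1))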
Example #3
    def train_batch(self, imgs):
        imgs = imgs.to(self.device)
        self.g.train()
        self.d.train()
        # train discriminator
        toggle_grad(self.g, False)
        toggle_grad(self.d, True)
        self.optimizer_d.zero_grad()
        batch_imgs, labels, z = self.make_adversarial_batch(imgs)
        real = batch_imgs[:self.args.batch_size]
        fake = batch_imgs[self.args.batch_size:]
        if self.args.grad_penalty:
            real.requires_grad_()
        # this D returns (realness score, predicted latent code); the code
        # head is supervised only on fakes, where the sampled z is known
        p_labels_real, _ = self.d(real)
        p_labels_fake, p_codes = self.d(fake.detach())
        p_labels = torch.cat([p_labels_real, p_labels_fake], dim=0)
        d_loss = self.bce_loss(p_labels, labels)
        # InfoGAN loss: reconstruct the latent code from D's code head
        d_code_loss = 0.
        if self.args.info_cat_dims:
            z_cat_code = self.z_categorical_code(z)
            p_z_cat_code = self.z_categorical_code(p_codes)
            d_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
            d_code_loss += d_cat_code_loss
        if self.args.info_cont_dims:
            z_cont_code = self.z_continuous_code(z)
            p_z_cont_code = self.z_continuous_code(p_codes)
            d_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
            d_code_loss += d_cont_code_loss
        d_loss += self.args.info_w * d_code_loss

        d_grad_penalty = 0.
        if self.args.grad_penalty:
            d_grad_penalty = self.args.grad_penalty * gradient_penalty(p_labels_real, real)
            d_loss += d_grad_penalty
        d_loss.backward()
        self.optimizer_d.step()

        # train generator
        toggle_grad(self.g, True)
        toggle_grad(self.d, False)
        self.optimizer_g.zero_grad()
        batch_imgs, labels, z = self.make_generator_batch(imgs)
        p_labels, p_codes = self.d(batch_imgs)
        g_loss = self.bce_loss(p_labels, labels)
        # InfoGAN loss: the same code-reconstruction terms, now training G
        g_code_loss = 0.
        if self.args.info_cat_dims:
            z_cat_code = self.z_categorical_code(z)
            p_z_cat_code = self.z_categorical_code(p_codes)
            g_cat_code_loss = self.bce_loss(p_z_cat_code, z_cat_code)
            g_code_loss += g_cat_code_loss
        if self.args.info_cont_dims:
            z_cont_code = self.z_continuous_code(z)
            p_z_cont_code = self.z_continuous_code(p_codes)
            g_cont_code_loss = self.mse_loss(p_z_cont_code, z_cont_code)
            g_code_loss += g_cont_code_loss

        g_loss += self.args.info_w * g_code_loss

        g_loss.backward()
        self.optimizer_g.step()

        self.update_target_generator()

        return dict(
            g_loss=float(g_loss), g_code_loss=float(g_code_loss),
            d_loss=float(d_loss), d_code_loss=float(d_code_loss),
            gp=float(d_grad_penalty)
        )
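
Example #3 slices the InfoGAN code out of z with two helpers that are not shown. A plausible layout, assuming the categorical code occupies the first info_cat_dims entries of z and the continuous code the next info_cont_dims; D's code head would then have to emit a vector with the same layout, since p_codes passes through the same slicers:

    def z_categorical_code(self, z):
        # one-hot categorical code at the front of z, compared with BCE above
        return z[:, :self.args.info_cat_dims]

    def z_continuous_code(self, z):
        # continuous code immediately after it, compared with MSE above
        start = self.args.info_cat_dims
        return z[:, start:start + self.args.info_cont_dims]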