Example #1
    def train(img, label):
        E.optim.zero_grad()
        img = torch.split(img, config['batch_size'])
        label = torch.split(label, config['batch_size'])
        counter = 0

        for step_index in range(config['num_D_steps']):
            E.optim.zero_grad()
            fake, logits, vgg_loss = Wrapper(img[counter], label[counter])
            vgg_loss = vgg_loss * config['vgg_loss_scale']
            d_loss = losses.generator_loss(logits) * config['adv_loss_scale']
            recon_loss = losses.recon_loss(
                fakes=fake, reals=img[counter]) * config['recon_loss_scale']
            loss = d_loss + recon_loss + vgg_loss
            loss.backward()
            counter += 1
            if config['E_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in E.
                print('using modified ortho reg in E')
                utils.ortho(E, config['E_ortho'])
            E.optim.step()

        out = {
            'Vgg_loss': float(vgg_loss.item()),
            'D_loss': float(d_loss.item()),
            'pixel_loss': float(recon_loss.item())
        }
        return out
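The losses module used above is not shown. Below is a minimal sketch of the two helpers Example #1 calls, assuming a hinge-style generator loss (as in BigGAN-PyTorch) and a plain L1 pixel reconstruction loss; both are assumptions, not the project's actual code:

    import torch
    import torch.nn.functional as F

    def generator_loss(dis_fake):
        # Hinge-style generator loss: raise the discriminator's logits on fakes.
        return -torch.mean(dis_fake)

    def recon_loss(fakes, reals):
        # Pixel-space reconstruction loss; L1 here, but MSE is equally plausible.
        return F.l1_loss(fakes, reals)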
Example #2
    def train():
        E.optim.zero_grad()
        z_.sample_()
        y_.sample_()

        net = GE(z_[:config['batch_size']], y_[:config['batch_size']])
        loss = F.l1_loss(z_[:config['batch_size']], net)
        loss.backward()
        if config["E_ortho"] > 0.0:
            print('using modified ortho reg in E')
            utils.ortho(E, config['E_ortho'])
        E.optim.step()
        out = {'loss': float(loss.item())}
        return out
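The z_.sample_() and y_.sample_() calls here (and in Examples #6 and #7) assume reusable noise and label tensors that can refill themselves in place, in the style of BigGAN-PyTorch's Distribution/prepare_z_y utilities. A rough, self-contained stand-in (sizes and class count are illustrative):

    import torch

    class Resampleable(torch.Tensor):
        # Tensor that can refill itself in place via sample_().
        def init_normal(self):
            self._kind = 'normal'
            return self

        def init_categorical(self, num_classes):
            self._kind = 'categorical'
            self._num_classes = num_classes
            return self

        def sample_(self):
            if self._kind == 'normal':
                self.normal_()
            else:
                self.random_(0, self._num_classes)

    # Latent noise z_ and class labels y_ for a batch of 64.
    z_ = torch.randn(64, 128).as_subclass(Resampleable).init_normal()
    y_ = torch.zeros(64, dtype=torch.int64).as_subclass(Resampleable).init_categorical(1000)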
Example #3
    def train(w, img):
        y_.sample_()
        G.optim.zero_grad()
        x = W(w, y_)
        loss = MSE(x, img)
        loss.backward()
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G.
            print('using modified ortho reg in G')
            utils.ortho(G, config['G_ortho'])
        G.optim.step()
        out = {'loss': float(loss.item())}
        if config['ema']:
            ema.update(state_dict['itr'])
        del loss, x
        return out
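Example #3 (and several examples below) updates an ema tracker after every step. Here is a minimal sketch of such an exponential-moving-average helper, modeled on the ema class in BigGAN-PyTorch; the decay value and start_itr behavior are assumptions:

    import torch

    class ema(object):
        def __init__(self, source, target, decay=0.9999, start_itr=0):
            self.source_dict = source.state_dict()
            self.target_dict = target.state_dict()
            self.decay = decay
            # Before start_itr the target is kept as a straight copy of the source.
            self.start_itr = start_itr
            with torch.no_grad():
                for key in self.source_dict:
                    self.target_dict[key].data.copy_(self.source_dict[key].data)

        def update(self, itr=None):
            decay = 0.0 if (itr is not None and itr < self.start_itr) else self.decay
            with torch.no_grad():
                for key in self.source_dict:
                    self.target_dict[key].data.copy_(
                        self.target_dict[key].data * decay
                        + self.source_dict[key].data * (1 - decay))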
Example #4
    def train(w, img):
        E.optim.zero_grad()
        Out.optim.zero_grad()
        w_ = W(img)
        loss = F.mse_loss(w_, w, reduction='mean')
        loss.backward()
        if config['E_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in E and Out.
            print('using modified ortho reg in E and Out')
            utils.ortho(E, config['E_ortho'])
            utils.ortho(Out, config['E_ortho'])
        E.optim.step()
        Out.optim.step()
        out = {'loss': float(loss.item())}
        if config['ema']:
            for ema in [eema, oema]:
                ema.update(state_dict['itr'])
        del w_, loss
        return out
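Example #4 keeps two EMA trackers, eema for E and oema for Out. Assuming the ema class sketched after Example #3 and hypothetical frozen copies E_ema and Out_ema used at test time, they could be wired up as:

    # Hypothetical setup; E_ema, Out_ema, and the config keys are assumptions.
    eema = ema(E, E_ema, decay=config['ema_decay'], start_itr=config['ema_start'])
    oema = ema(Out, Out_ema, decay=config['ema_decay'], start_itr=config['ema_start'])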
Example #5
    def train(x, y):
        Ex.optim.zero_grad()
        rot_logits, s2l_logits, y, ry = Ex_parallel(x, y)
        rot_loss, s2l_loss = extractor_loss(rot_logits, s2l_logits, y, ry)
        loss = rot_loss + 0.5 * s2l_loss
        loss.backward()
        if config['Ex_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in the Extractor.
            print('using modified ortho reg in Extractor')
            utils.ortho(Ex, config['Ex_ortho'])
        Ex.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'loss': float(loss.item()),
               'rot_loss': float(rot_loss.item()),
               's2l_loss': float(s2l_loss.item())}
        # Return the total loss and its rotation / self-supervised components.
        return out
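The extractor_loss helper in Example #5 is not shown. One plausible reading of its (rot_logits, s2l_logits, y, ry) signature is a pair of cross-entropy terms over rotation targets and class labels; this sketch is purely an assumption about the missing code:

    import torch.nn.functional as F

    def extractor_loss(rot_logits, s2l_logits, y, ry):
        # Cross-entropy on the rotation-prediction head (targets ry)
        # and on the self-supervised-to-label head (targets y).
        rot_loss = F.cross_entropy(rot_logits, ry)
        s2l_loss = F.cross_entropy(s2l_logits, y)
        return rot_loss, s2l_loss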
Example #6
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                    x[counter], y[counter], train_G=False,
                                    split_D=config['split_D'])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G')  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'G_loss': float(G_loss.item()),
               'D_loss_real': float(D_loss_real.item()),
               'D_loss_fake': float(D_loss_fake.item())}
        # Return G's loss and the components of D's loss.
        return out
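Every example guards a call to utils.ortho behind a *_ortho config value, and Examples #6 and #7 also use utils.toggle_grad. Sketches of both helpers along the lines of the BigGAN-PyTorch utilities (treat the exact formulation as an assumption):

    import torch

    def ortho(model, strength=1e-4, blacklist=[]):
        # Modified orthogonal regularization: penalize off-diagonal entries of
        # W W^T by adding the corresponding gradient directly to param.grad.
        with torch.no_grad():
            for param in model.parameters():
                # Skip 1-D parameters (biases, gains) and blacklisted tensors.
                if len(param.shape) < 2 or any(param is item for item in blacklist):
                    continue
                w = param.view(param.shape[0], -1)
                grad = 2 * torch.mm(
                    torch.mm(w, w.t()) * (1. - torch.eye(w.shape[0], device=w.device)), w)
                param.grad.data += strength * grad.view(param.shape)

    def toggle_grad(model, on_or_off):
        # Enable or disable gradients for every parameter of a model.
        for param in model.parameters():
            param.requires_grad = on_or_off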
Example #7
    def train(x):
        G.optim.zero_grad()
        D.optim.zero_grad()
        I.optim.zero_grad()
        E.optim.zero_grad()
        L.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        counter = 0

        # Optionally toggle each model's "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(L, True)
            utils.toggle_grad(G, False)
            utils.toggle_grad(I, False)
            utils.toggle_grad(E, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            L.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                ey_.sample_()
                D_fake, D_real, D_inv, D_en, _, _ = Decoder(
                    z_[:config['batch_size']],
                    y_[:config['batch_size']],
                    x[counter],
                    ey_[:config['batch_size']],
                    train_G=False,
                    split_D=config['split_D'])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                Latent_loss = losses.latent_loss_dis(D_inv, D_en)
                D_loss = (D_loss_real + D_loss_fake + Latent_loss) / float(
                    config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D and L.
                print('using modified ortho reg in D and Latent_Binder')
                utils.ortho(D, config['D_ortho'])
                utils.ortho(L, config['L_ortho'])

            D.optim.step()
            L.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(L, False)
            utils.toggle_grad(G, True)
            utils.toggle_grad(I, True)
            utils.toggle_grad(E, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        I.optim.zero_grad()
        E.optim.zero_grad()
        counter = 0

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            ey_.sample_()
            D_fake, _, D_inv, D_en, G_en, reals = Decoder(
                z_,
                y_,
                x[counter],
                ey_,
                train_G=True,
                split_D=config['split_D'])
            G_loss_fake = losses.generator_loss(
                D_fake) * config['adv_loss_scale']
            Latent_loss = losses.latent_loss_gen(D_inv, D_en)
            Recon_loss = losses.recon_loss(G_en, reals)
            G_loss = (G_loss_fake + Latent_loss + Recon_loss) / float(
                config['num_G_accumulations'])
            G_loss.backward()
            counter += 1

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G, Invert, and Encoder')
            # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G,
                        config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
            utils.ortho(E, config['E_ortho'])
            utils.ortho(I, config['I_ortho'])
        G.optim.step()
        I.optim.step()
        E.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            for ema in ema_list:
                ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
            'Latent_loss': float(Latent_loss.item()),
            'Recon_loss': float(Recon_loss.item())
        }

        # Release GPU memory:
        del G_loss, D_loss_real, D_loss_fake, Latent_loss, Recon_loss
        del D_fake, D_real, D_inv, D_en, G_en, reals
        del x

        # Return G's loss and the components of D's loss.
        return out
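All of these train functions share the same contract: consume one batch, run one optimization step, and return a dict of scalar losses for logging. An illustrative outer loop (loader, device, train, and the logging cadence are placeholders, not taken from the snippets above):

    state_dict = {'itr': 0}
    log_every = 100

    for epoch in range(config['num_epochs']):
        for x, y in loader:                    # any (image, label) DataLoader
            x, y = x.to(device), y.to(device)
            state_dict['itr'] += 1
            metrics = train(x, y)              # one of the train() functions above
            if state_dict['itr'] % log_every == 0:
                print(', '.join('%s: %.4f' % (k, v) for k, v in metrics.items()))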