Example #1
0
def load_vgg_from_local(arch='vgg19',
                        cfg='E',
                        batch_norm=False,
                        pretrained=True,
                        vgg_dir=None,
                        parallel=True,
                        requires_grad=False,
                        **kwargs):
    """Build a VGG network, optionally load weights from ``VGG_URL``, put it on CUDA.

    Args:
        arch: Architecture name. Not used by this function; kept for
            interface compatibility. NOTE(review): the weights always come
            from ``VGG_URL`` regardless of ``arch`` -- confirm they match ``cfg``.
        cfg: Key into ``cfgs`` selecting the layer configuration passed to
            ``vgglib.make_layers``.
        batch_norm: Forwarded to ``vgglib.make_layers``.
        pretrained: If True, load the state dict obtained via
            ``model_zoo.load_url(VGG_URL)``; if False, keep the freshly
            constructed weights.
        vgg_dir: ``model_dir`` handed to ``model_zoo.load_url``. When None,
            falls back to ``'/gpub/temp/imagenet2012/hdf5'`` (the previously
            hard-coded location).
        parallel: If True, wrap the model in ``torch.nn.DataParallel``.
        requires_grad: Passed to ``toggle_grad`` to enable/disable gradients
            on the returned model's parameters.
        **kwargs: Forwarded to ``vgglib.VGG``.

    Returns:
        The VGG model in eval mode on CUDA, wrapped in ``DataParallel`` when
        ``parallel`` is True.
    """
    vgg = vgglib.VGG(vgglib.make_layers(cfgs[cfg], batch_norm=batch_norm),
                     **kwargs)
    if pretrained:
        # Previously the weights were loaded unconditionally from a fixed
        # directory, silently ignoring both `pretrained` and `vgg_dir`.
        model_dir = (vgg_dir if vgg_dir is not None
                     else '/gpub/temp/imagenet2012/hdf5')
        vgg.load_state_dict(
            model_zoo.load_url(url=VGG_URL, model_dir=model_dir))
    vgg = (vgg.eval()).cuda()
    if parallel:
        print("Parallel VGG model...")
        vgg = torch.nn.DataParallel(vgg)
    toggle_grad(vgg, requires_grad)
    return vgg
Example #2
0
def run(config):
    """Train an encoder E to invert a frozen, pretrained generator G.

    Each step samples latents ``z_`` and labels ``y_``. G renders samples from
    them without gradients (see ``G_E.forward``). E is then trained with an L1
    loss to recover the first ``config['batch_size']`` latents from those
    samples. Only E (and, when ``config['ema']`` is set, its EMA copy) is
    optimised and checkpointed.

    Args:
        config: dict of run settings. Derived entries (``resolution``,
            ``n_classes``, ``G_activation``, ``D_activation`` and, when
            resuming, ``skip_init``) are written into it before it is passed
            through ``vae_utils.update_config_roots``.
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Build G and load the pretrained weights. Checkpoint entries whose keys
    # are not in G's own state dict are dropped before loading.
    G = Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
    print('Loading pretrained G for dir %s ...' % config['pretrained_G_dir'])
    pretrained_dict = torch.load(config['pretrained_G_dir'])
    G_dict = G.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in G_dict}
    G_dict.update(pretrained_dict)
    G.load_state_dict(G_dict)

    # G stays frozen; only E receives gradients.
    E = Encoder(**config).to(device)
    utils.toggle_grad(G, False)
    utils.toggle_grad(E, True)

    class G_E(nn.Module):
        """Chain G -> E: generate from (w, y), then encode the result."""

        def __init__(self):
            super(G_E, self).__init__()
            self.G = G
            self.E = E

        def forward(self, w, y):
            # No autograd graph is recorded through G; only E's part is.
            with torch.no_grad():
                net = self.G(w, self.G.shared(y))
            net = self.E(net)
            return net

    GE = G_E()

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for E with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        e_ema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        E_ema, e_ema = None, None

    print(G)
    print(E)
    print('Number of params in G: {} E: {}'.format(
        *[sum(p.data.nelement() for p in net.parameters())
          for net in (G, E)]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        # Pass the EMA network E_ema (an Encoder), mirroring what
        # save_weights stores below. e_ema is the updater returned by
        # utils.ema, not the network itself.
        vae_utils.load_weights(
            [E], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [E_ema] if config['ema'] else None)

    # If parallel, parallelize the GE module
    if config['parallel']:
        GE = nn.DataParallel(GE)
        if config['cross_replica']:
            patch_replication_callback(GE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    def train():
        """Run one optimisation step of E; return ``{'loss': float}``."""
        E.optim.zero_grad()
        z_.sample_()
        y_.sample_()

        net = GE(z_[:config['batch_size']], y_[:config['batch_size']])
        # E is trained to recover the latent that produced its input.
        loss = F.l1_loss(z_[:config['batch_size']], net)
        loss.backward()
        if config["E_ortho"] > 0.0:
            print('using modified ortho reg in E')
            utils.ortho(E, config['E_ortho'])
        E.optim.step()

        # Advance the EMA copy of E. Previously e_ema was created but nothing
        # ever called update(), so E_ema was not refreshed by training.
        if config['ema']:
            e_ema.update(state_dict['itr'])

        out = {'loss': float(loss.item())}
        return out

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Fixed number of iterations per "epoch" (no dataloader is involved).
        for _ in range(100000):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Put G and E (and E_ema) in training mode, just in case they got
            # set to eval.
            # NOTE(review): G is frozen, but train mode may still let any
            # batch-norm layers in G update running statistics -- confirm
            # this is intended for a pretrained G.
            G.train()
            E.train()
            if config['ema']:
                E_ema.train()
            metrics = train()
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(E, 'E')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                vae_utils.save_weights([E], state_dict, config['weights_root'],
                                       experiment_name,
                                       'copy%d' % state_dict['save_num'],
                                       [E_ema if config['ema'] else None])
                # Rotate through num_save_copies checkpoint slots.
                state_dict['save_num'] = (state_dict['save_num'] +
                                          1) % config['num_save_copies']
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Example #3
0
    def train(x, y):
        """Run one GAN training step: ``num_D_steps`` D updates, then one G update.

        Closure over names from the enclosing scope (G, D, GD, z_, y_, ema,
        config, losses, utils, state_dict).

        Args:
            x: real input batch, split into chunks of ``config['batch_size']``;
                each D accumulation consumes the next chunk.
            y: labels for ``x``, split the same way.

        Returns:
            dict of floats: 'G_loss' (last G accumulation's loss, already
            divided by ``num_G_accumulations``), 'D_loss_real' and
            'D_loss_fake' (from the last D accumulation).
        """
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                # Resample both latents and fake labels. Previously y_ was
                # never resampled, so fakes were always generated for the
                # same labels.
                z_.sample_()
                y_.sample_()
                D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                                    x[counter], y[counter], train_G=False,
                                    split_D=config['split_D'])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G, config['G_ortho'],
                        blacklist=list(G.shared.parameters()))
        G.optim.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config['ema']:
            ema.update(state_dict['itr'])

        out = {'G_loss': float(G_loss.item()),
               'D_loss_real': float(D_loss_real.item()),
               'D_loss_fake': float(D_loss_fake.item())}
        # Return G's loss and the components of D's loss.
        return out
Example #4
0
    def train(x):
        """One training step: D/L update(s), then a joint G/I/E update.

        Closure over names from the enclosing scope (G, D, I, E, L, Decoder,
        z_, y_, ey_, config, losses, utils, ema_list, state_dict).

        Args:
            x: real input batch, split into chunks of ``config['batch_size']``.
                Each D accumulation consumes the next chunk, and the G phase
                restarts from chunk 0. x must therefore yield at least
                ``num_D_steps * num_D_accumulations`` chunks and at least
                ``num_G_accumulations`` chunks; otherwise indexing raises
                IndexError.

        Returns:
            dict of floats: 'G_loss', 'D_loss_real', 'D_loss_fake',
            'Latent_loss' (the generator-side value from the last G
            accumulation, since the D-side value is overwritten) and
            'Recon_loss'. num_D_steps, num_D_accumulations and
            num_G_accumulations must each be >= 1; otherwise the losses read
            when building the result are never assigned.
        """
        G.optim.zero_grad()
        D.optim.zero_grad()
        I.optim.zero_grad()
        E.optim.zero_grad()
        L.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        counter = 0

        # Optionally toggle "requires_grad": enable D and L, freeze G, I, E
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(L, True)
            utils.toggle_grad(G, False)
            utils.toggle_grad(I, False)
            utils.toggle_grad(E, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            L.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                # Fresh latents / labels / ey_ samples for every accumulation
                z_.sample_()
                y_.sample_()
                ey_.sample_()
                # Fakes are built from the first batch_size samples only
                D_fake, D_real, D_inv, D_en, _, _ = Decoder(
                    z_[:config['batch_size']],
                    y_[:config['batch_size']],
                    x[counter],
                    ey_[:config['batch_size']],
                    train_G=False,
                    split_D=config['split_D'])

                # Compute components of D's loss (real, fake and latent terms)
                # and divide by the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                Latent_loss = losses.latent_loss_dis(D_inv, D_en)
                D_loss = (D_loss_real + D_loss_fake + Latent_loss) / float(
                    config['num_D_accumulations'])
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D (and in L)
            # NOTE(review): L's ortho reg is gated on D_ortho rather than
            # L_ortho -- confirm this coupling is intended.
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D and Latent_Binder')
                utils.ortho(D, config['D_ortho'])
                utils.ortho(L, config['L_ortho'])

            D.optim.step()
            L.optim.step()

        # Optionally toggle "requires_grad": freeze D and L, enable G, I, E
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(L, False)
            utils.toggle_grad(G, True)
            utils.toggle_grad(I, True)
            utils.toggle_grad(E, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()
        I.optim.zero_grad()
        E.optim.zero_grad()
        # The G phase reuses real chunks starting again from the first one
        counter = 0

        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            ey_.sample_()
            # Unlike the D phase, the full z_/y_/ey_ buffers are passed here
            D_fake, _, D_inv, D_en, G_en, reals = Decoder(
                z_,
                y_,
                x[counter],
                ey_,
                train_G=True,
                split_D=config['split_D'])
            # Adversarial term is scaled by adv_loss_scale; latent and
            # reconstruction terms are added unscaled
            G_loss_fake = losses.generator_loss(
                D_fake) * config['adv_loss_scale']
            Latent_loss = losses.latent_loss_gen(D_inv, D_en)
            Recon_loss = losses.recon_loss(G_en, reals)
            G_loss = (G_loss_fake + Latent_loss + Recon_loss) / float(
                config['num_G_accumulations'])
            G_loss.backward()
            counter += 1

        # Optionally apply modified ortho reg in G (and in E and I)
        # NOTE(review): E's and I's ortho reg are gated on G_ortho rather than
        # E_ortho / I_ortho -- confirm this coupling is intended.
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G, Invert, and Encoder')
            # Debug print above indicates we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(G,
                        config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
            utils.ortho(E, config['E_ortho'])
            utils.ortho(I, config['I_ortho'])
        G.optim.step()
        I.optim.step()
        E.optim.step()

        # If we have an ema, update every tracked EMA, regardless of if we test with it or not
        if config['ema']:
            for ema in ema_list:
                ema.update(state_dict['itr'])

        out = {
            'G_loss': float(G_loss.item()),
            'D_loss_real': float(D_loss_real.item()),
            'D_loss_fake': float(D_loss_fake.item()),
            'Latent_loss': float(Latent_loss.item()),
            'Recon_loss': float(Recon_loss.item())
        }

        # Release GPU memory:
        # NOTE(review): these are all locals that go out of scope at the
        # return below anyway, and `x` is only the local split tuple (the
        # caller's tensor is unaffected) -- these dels are likely redundant.
        del G_loss, D_loss_real, D_loss_fake, Latent_loss, Recon_loss
        del D_fake, D_real, D_inv, D_en, G_en, reals
        del x

        # Return G's loss and the components of D's loss.
        return out