def run(config):
    # Get loader
    config['drop_last'] = False
    loaders = utils.get_data_loaders(**config)

    # Load inception net
    net = inception_utils.load_inception_net(parallel=config['parallel'])
    pool, logits, labels = [], [], []
    device = 'cuda'
    for i, (x, y) in enumerate(tqdm(loaders[0])):
        x = x.to(device)
        with torch.no_grad():
            pool_val, logits_val = net(x)
            pool += [np.asarray(pool_val.cpu())]
            logits += [np.asarray(F.softmax(logits_val, 1).cpu())]
            labels += [np.asarray(y.cpu())]

    pool, logits, labels = [
        np.concatenate(item, 0) for item in [pool, logits, labels]
    ]
    # uncomment to save pool, logits, and labels to disk
    # print('Saving pool, logits, and labels to disk...')
    # np.savez(config['dataset'] + '_inception_activations.npz',
    #          pool=pool, logits=logits, labels=labels)
    # Calculate inception metrics and report them
    print('Calculating inception metrics...')
    IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
    print('Training data from dataset %s has IS of %5.5f +/- %5.5f' %
          (config['dataset'], IS_mean, IS_std))
    # Prepare mu and sigma and save them to disk, dropping any '_hdf5' suffix
    # (the FID code also knows to strip 'hdf5')
    print('Calculating means and covariances...')
    mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
    print('Saving calculated means and covariances to disk...')
    # Note: str.strip() removes a set of characters, not a suffix, so use
    # replace() to drop '_hdf5' reliably.
    np.savez(config['dataset'].replace('_hdf5', '') + '_inception_moments.npz',
             mu=mu, sigma=sigma)
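The saved moments are exactly what an FID computation consumes later. Below is a minimal sketch of that use, assuming the standard Fréchet distance formula; sample_mu and sample_sigma are hypothetical stand-ins for moments of generated samples, and the .npz filename just follows the pattern produced above.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID between two Gaussians:
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)

moments = np.load('I128_inception_moments.npz')  # written by run() above
mu_data, sigma_data = moments['mu'], moments['sigma']
# sample_mu / sample_sigma would come from pooled Inception features of
# generated images:
# fid = frechet_distance(sample_mu, sample_sigma, mu_data, sigma_data)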
Example #2
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = import_module('Network.' + config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    L = model.LatentBinder(**config).to(device)
    I = Invert.Invert(**config).to(device)
    E = Encoder.Encoder(**config).to(device)
    Decoder = model.Decoder(I, E, G, D, L).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(name='G_ema',
                                **{
                                    **config, 'skip_init': True,
                                    'no_optim': True
                                }).to(device)
        gema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
        print('Preparing EMA for Invert with decay of {}'.format(
            config['ema_decay']))
        I_ema = Invert.Invert(name='Invert_ema',
                              **{
                                  **config, 'skip_init': True,
                                  'no_optim': True
                              }).to(device)
        iema = utils.ema(I, I_ema, config['ema_decay'], config['ema_start'])
        print('Preparing EMA for Encoder with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder.Encoder(name='Encoder_ema',
                                **{
                                    **config, 'skip_init': True,
                                    'no_optim': True
                                }).to(device)
        eema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, gema, I_ema, iema, E_ema, eema = None, None, None, None, None, None

    # FP16? We should also halve the other components of Decoder, but since
    # we will not use FP16, we simply don't implement that.
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    print(G)
    print(D)
    print(I)
    print(E)
    print(L)
    print(
        'Number of params in G: {} D: {} Invert: {} Encoder: {} LatentBinder: {}'
        .format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, I, E, L]
        ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [G, D, I, E, L], state_dict, config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [G_ema, I_ema, E_ema] if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        # Decoder = nn.DataParallel(Decoder)
        # Using custom dataparallel to save GPU memory
        Decoder = parallel_utils.DataParallelModel(Decoder)
        if config['cross_replica']:
            patch_replication_callback(Decoder)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['data_root'],
        config['no_fid'])
    # Prepare VGG for recon_loss. The loss is computed in parallel, so there
    # is no need to parallelize VGG itself.
    # VGG is pretrained on ImageNet, so we cannot use it here.
    # vgg = load_vgg_from_local(parallel=False)
    # Prepare KNN for evaluating encoder.
    KNN = vae_utils.KNN(loaders[0], anchor_num=10, K=4)
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare fake labels for encoder.
    _, ey_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_x, _ = vae_utils.prepare_fixed_x(loaders[0], G_batch_size, config,
                                           experiment_name, device)
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        # train = train_vae_fns.VAE_training_function(G, D, E, I, L, Decoder, z_, y_, ey_,
        #                                             [gema, iema, eema], state_dict, vgg, config)
        train = train_vae_fns.parallel_training_function(
            G, D, E, I, L, Decoder, z_, y_, ey_, [gema, iema, eema],
            state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_vae_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        vae_utils.sample,
        Invert=(I_ema if config['ema'] and config['use_ema'] else I),
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            I.train()
            E.train()
            L.train()
            if config['ema']:
                G_ema.train()
                I_ema.train()
                E_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D'),
                                  **utils.get_SVs(I, 'Invert'),
                                  **utils.get_SVs(E, 'Encoder'),
                                  **utils.get_SVs(L, 'LatentBinder')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    I.eval()
                    E.eval()
                    if config['ema']:
                        G_ema.eval()
                        I_ema.eval()
                        E_ema.eval()
                train_vae_fns.save_and_sample(G, D, E, I, L, G_ema, I_ema,
                                              E_ema, z_, y_, fixed_z, fixed_y,
                                              fixed_x, state_dict, config,
                                              experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    I.eval()
                    E.eval()
                train_vae_fns.test(G, D, E, I, L, KNN, G_ema, I_ema, E_ema, z_,
                                   y_, state_dict, config, sample,
                                   get_inception_metrics, experiment_name,
                                   test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
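utils.ema is not part of this listing; here is a minimal sketch of how such a helper typically works given the way it is called above (copy source weights verbatim until ema_start, then exponentially average), assuming parameter names match between the two networks.

import torch

class ema(object):
    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr
        # Start the target as an exact copy of the source
        self.target.load_state_dict(self.source.state_dict())

    def update(self, itr=None):
        # Before start_itr, mirror the source exactly (decay 0);
        # afterwards, blend with the configured decay.
        decay = 0.0 if (itr is not None and itr < self.start_itr) else self.decay
        with torch.no_grad():
            src = dict(self.source.named_parameters())
            for name, tgt in self.target.named_parameters():
                tgt.copy_(tgt * decay + src[name] * (1.0 - decay))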
Example #3
def run(config):
    if 'hdf5' in config['dataset']:
        raise ValueError(
            'Reading from an HDF5 file which you will probably be '
            'about to overwrite! Override this error only if you know '
            'what you\'re doing!')
    # Get image size
    config['image_size'] = utils.imsize_dict[config['dataset']]

    # Update compression entry: use 'lzf' if requested, else no compression
    config['compression'] = 'lzf' if config['compression'] else None

    # Get dataset
    kwargs = {
        'num_workers': config['num_workers'],
        'pin_memory': False,
        'drop_last': False
    }
    train_loader = utils.get_data_loaders(dataset=config['dataset'],
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          data_root=config['data_root'],
                                          use_multiepoch_sampler=False,
                                          **kwargs)[0]

    # HDF5 supports chunking and compression. You may want to experiment
    # with different chunk sizes to see how it runs on your machines.
    # Chunk size / compression   Read @256x256      Read @128x128   Filesize @128x128   Write time @128x128
    # 1 / None                   20/s
    # 500 / None                 ramps up to 77/s   102/s           61GB                23min
    # 500 / LZF                                     8/s             56GB                23min
    # 1000 / None                78/s
    # 5000 / None                81/s
    # auto:(125,1,16,32) / None                     11/s            61GB

    print(
        'Starting to load %s into an HDF5 file with chunk size %i and compression %s...'
        % (config['dataset'], config['chunk_size'], config['compression']))
    # Loop over train loader
    for i, (x, y) in enumerate(tqdm(train_loader)):
        # Stick X into the range [0, 255] since it's coming from the train loader
        x = (255 * ((x + 1) / 2.0)).byte().numpy()
        # Numpyify y
        y = y.numpy()
        # If we're on the first batch, prepare the hdf5
        if i == 0:
            with h5.File(
                    config['data_root'] +
                    '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f:
                print('Producing dataset of len %d' %
                      len(train_loader.dataset))
                imgs_dset = f.create_dataset(
                    'imgs',
                    x.shape,
                    dtype='uint8',
                    maxshape=(len(train_loader.dataset), 3,
                              config['image_size'], config['image_size']),
                    chunks=(config['chunk_size'], 3, config['image_size'],
                            config['image_size']),
                    compression=config['compression'])
                print('Image chunks chosen as ' + str(imgs_dset.chunks))
                imgs_dset[...] = x
                labels_dset = f.create_dataset(
                    'labels',
                    y.shape,
                    dtype='int64',
                    maxshape=(len(train_loader.dataset), ),
                    chunks=(config['chunk_size'], ),
                    compression=config['compression'])
                print('Label chunks chosen as ' + str(labels_dset.chunks))
                labels_dset[...] = y
        # Else append to the hdf5
        else:
            with h5.File(
                    config['data_root'] +
                    '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f:
                f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
                f['imgs'][-x.shape[0]:] = x
                f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
                f['labels'][-y.shape[0]:] = y
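A quick sketch of reading the file written above back with h5py; the path assumes image_size 128 and whatever data_root was used at write time.

import h5py as h5
import numpy as np

with h5.File('/path/to/data_root/ILSVRC128.hdf5', 'r') as f:
    print(f['imgs'].shape, f['imgs'].dtype)      # (N, 3, 128, 128), uint8
    print(f['labels'].shape, f['labels'].dtype)  # (N,), int64
    # Reads along the first axis line up with the chunking chosen above
    batch = f['imgs'][:500]
    # Undo the [0, 255] quantization back to the loader's [-1, 1] range
    imgs = batch.astype(np.float32) / 255.0 * 2.0 - 1.0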
Example #4
def run(config):
    timer = vae_utils.Timer()

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Derive the experiment name (no dynamic model import is needed here).
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    E = Encoder(**{**config, 'arch': 'default'}).to(device)
    Out = Encoder(**{**config, 'arch': 'out'}).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for E and Out with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True,
            'arch': 'default'
        }).to(device)
        O_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True,
            'arch': 'out'
        }).to(device)
        eema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
        oema = utils.ema(Out, O_ema, config['ema_decay'], config['ema_start'])
    else:
        E_ema, eema, O_ema, oema = None, None, None, None

    print(E)
    print(Out)
    print('Number of params in E: {} Out: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()]) for net in [E, Out]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config,
        'best_precise': 0.0
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [E, Out], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [E_ema, O_ema] if config['ema'] else [None])

    class Wrapper(nn.Module):
        def __init__(self):
            super(Wrapper, self).__init__()
            self.E = E
            self.O = Out

        def forward(self, x):
            x = self.E(x)
            x = self.O(x)
            return x

    W = Wrapper()

    # If parallel, parallelize the GD module
    if config['parallel']:
        W = nn.DataParallel(W)
        if config['cross_replica']:
            patch_replication_callback(W)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Batch size for the dataloader: one full D iteration's worth of data
    batch_size = config['batch_size'] * config['num_D_steps'] * config[
        'num_D_accumulations']

    eval_loader = utils.get_data_loaders(**{
        **config, 'load_in_mem': False,
        'use_multiepoch_sampler': False
    })[0]
    dense_eval = vae_utils.dense_eval(2048, config['n_classes'],
                                      steps=5).to(device)
    eval_fn = functools.partial(vae_utils.eval_encoder,
                                sample_batch=10,
                                config=config,
                                loader=eval_loader,
                                dense_eval=dense_eval,
                                device=device)

    E_scheduler = torch.optim.lr_scheduler.StepLR(E.optim,
                                                  step_size=2,
                                                  gamma=0.1)
    O_scheduler = torch.optim.lr_scheduler.StepLR(Out.optim,
                                                  step_size=2,
                                                  gamma=0.1)

    def train(w, img):
        E.optim.zero_grad()
        Out.optim.zero_grad()
        w_ = W(img)
        loss = F.mse_loss(w_, w, reduction='mean')
        loss.backward()
        if config['E_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in E.
            print('using modified ortho reg in E')
            utils.ortho(E, config['E_ortho'])
            utils.ortho(Out, config['E_ortho'])
        E.optim.step()
        Out.optim.step()
        out = {'loss': float(loss.item())}
        if config['ema']:
            for ema in [eema, oema]:
                ema.update(state_dict['itr'])
        del w_, loss
        return out

    start, end = sampled_ssgan.make_dset_range(config['ssgan_sample_root'],
                                               config['ssgan_piece'],
                                               batch_size)
    timer.update()
    # ('runing_time' is the attribute name as spelled by vae_utils.Timer)
    print(
        'Beginning training at epoch %d (running time %02d day %02d h %02d min %02d sec)...'
        % ((state_dict['epoch'], ) + timer.runing_time))
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        for piece in range(config['ssgan_piece']):
            timer.update()
            print(
                'Load %d-th piece of ssgan sample into memory (running time %02d day %02d h %02d min %02d sec)...'
                % ((piece, ) + timer.runing_time))
            loader = sampled_ssgan.get_SSGAN_sample_loader(
                **{
                    **config, 'batch_size': batch_size,
                    'start_itr': state_dict['itr'],
                    'start': start[piece],
                    'end': end[piece]
                })
            for _ in range(200):
                for i, (img, z, w) in enumerate(loader):
                    # Increment the iteration counter
                    state_dict['itr'] += 1
                    # Make sure E and Out are in training mode, just in case
                    # they got set to eval.
                    E.train()
                    Out.train()
                    if config['ema']:
                        E_ema.train()
                        O_ema.train()

                    img, w = img.to(device), w.to(device)
                    counter = 0
                    img = torch.split(img, config['batch_size'])
                    w = torch.split(w, config['batch_size'])
                    metrics = train(w[counter], img[counter])
                    counter += 1
                    del img, w

                    train_log.log(itr=int(state_dict['itr']), **metrics)

                    if not (state_dict['itr'] % 100):
                        timer.update()
                        print(
                            "Running time %02d day %02d h %02d min %02d sec," %
                            timer.runing_time +
                            ', '.join(['itr: %d' % state_dict['itr']] + [
                                '%s : %+4.3f' % (key, metrics[key])
                                for key in metrics
                            ]))

                    # Save weights and copies as configured at specified interval
                    if not (state_dict['itr'] % config['save_every']):
                        if config['G_eval_mode']:
                            print('Switching E to eval mode...')
                            E.eval()
                            if config['ema']:
                                E_ema.eval()
                        sampled_ssgan.save_and_eavl(E, Out, E_ema, O_ema,
                                                    state_dict, config,
                                                    experiment_name, eval_fn,
                                                    test_log)
            E_scheduler.step()
            O_scheduler.step()
            del loader
        #  Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
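Note that E_scheduler and O_scheduler are stepped once per piece rather than once per epoch. A small self-contained sketch of what StepLR(step_size=2, gamma=0.1) does to the learning rate under that convention (the optimizer and base lr here are illustrative, not the repo's values):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)
for piece in range(6):
    opt.step()    # training for one piece would happen here
    sched.step()  # stepped per piece, as in the loop above
    print(piece, opt.param_groups[0]['lr'])
# Prints 1e-3, 1e-4, 1e-4, 1e-5, 1e-5, 1e-6: a 10x drop every two steps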
Example #5
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = BigGAN.Generator(**{
        **config, 'skip_init': True,
        'no_optim': True
    }).to(device)
    D = BigGAN.Discriminator(**{
        **config, 'skip_init': True,
        'no_optim': True
    }).to(device)
    E = Encoder(**config).to(device)
    vgg_alter = Encoder(**{
        **config, 'skip_init': True,
        'no_optim': True,
        'name': 'Vgg_alter'
    }).to(device)
    load_pretrained(G, config['pretrained_G_dir'])
    load_pretrained(D, config['pretrained_D_dir'])
    load_pretrained(vgg_alter, config['pretrained_vgg_alter_dir'])

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for E with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        E_ema, ema = None, None

    class TrainWrapper(nn.Module):
        def __init__(self):
            super(TrainWrapper, self).__init__()
            self.G = G
            self.D = D
            self.E = E
            self.vgg_alter = vgg_alter

        def forward(self, img, label):
            en_w = self.E(img)
            # No torch.no_grad() here: the loss terms must backprop through
            # the frozen G, D, and vgg_alter (built with no_optim, so never
            # updated) for gradients to reach E.
            fake = self.G(en_w, self.G.shared(label))
            logits = self.D(fake, label)
            vgg_logits = F.l1_loss(self.vgg_alter(img),
                                   self.vgg_alter(fake))
            return fake, logits, vgg_logits

    Wrapper = TrainWrapper()
    print(G)
    print(D)
    print(E)
    print(vgg_alter)
    print('Number of params in G: {} D: {} E: {} Vgg_alter: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E, vgg_alter]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [E], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [E_ema if config['ema'] else None])

    # If parallel, parallelize the GD module
    if config['parallel']:
        Wrapper = nn.DataParallel(Wrapper)
        if config['cross_replica']:
            patch_replication_callback(Wrapper)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])

    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    fixed_x, fixed_y = vae_utils.prepare_fixed_x(loaders[0], G_batch_size,
                                                 config, experiment_name,
                                                 device)

    # Prepare noise and randomly sampled label arrays

    def train(img, label):
        E.optim.zero_grad()
        img = torch.split(img, config['batch_size'])
        label = torch.split(label, config['batch_size'])
        counter = 0

        for step_index in range(config['num_D_steps']):
            E.optim.zero_grad()
            fake, logits, vgg_loss = Wrapper(img[counter], label[counter])
            vgg_loss = vgg_loss * config['vgg_loss_scale']
            d_loss = losses.generator_loss(logits) * config['adv_loss_scale']
            recon_loss = losses.recon_loss(
                fakes=fake, reals=img[counter]) * config['recon_loss_scale']
            loss = d_loss + recon_loss + vgg_loss
            loss.backward()
            counter += 1
            if config['E_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in E.
                print('using modified ortho reg in E')
                utils.ortho(E, config['E_ortho'])
            E.optim.step()

        out = {
            'Vgg_loss': float(vgg_loss.item()),
            'D_loss': float(d_loss.item()),
            'pixel_loss': float(recon_loss.item())
        }
        return out

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            E.train()
            vgg_alter.train()
            if config['ema']:
                E_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(E, 'E')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    E.eval()
                    if config['ema']:
                        E_ema.eval()
                save_and_sample(G, E, E_ema, fixed_x, fixed_y, state_dict,
                                config, experiment_name)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
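The losses module is not shown in this listing. Under the usual BigGAN-style conventions the two terms used in train() above would look roughly like the following sketch (an assumption, not the repo's actual definitions):

import torch
import torch.nn.functional as F

def generator_loss(logits_fake):
    # Hinge-style generator term: raise D's logits on fakes
    return -torch.mean(logits_fake)

def recon_loss(fakes, reals):
    # Pixel-space L1 reconstruction between decoded and real images
    return F.l1_loss(fakes, reals)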