Example #1
import vae_utils  # project-local checkpoint helpers


def save_and_eval(E,
                  Out,
                  E_ema,
                  O_ema,
                  state_dict,
                  config,
                  experiment_name,
                  eval_fn=None,
                  test_log=None):
    if config['num_save_copies'] > 0:
        # Save a rotating checkpoint copy; pass None for the EMA models
        # when EMA is disabled, matching the convention used elsewhere.
        vae_utils.save_weights([E, Out], state_dict, config['weights_root'],
                               experiment_name,
                               'copy%d' % state_dict['save_num'],
                               [E_ema, O_ema] if config['ema'] else None)
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']
    if eval_fn is not None:
        which_E = E_ema if config['ema'] and config['use_ema'] else E
        precise = eval_fn(which_E)
        if precise > state_dict['best_precise']:
            print('KNN precision improved over previous best, '
                  'saving checkpoint...')
            vae_utils.save_weights([E], state_dict, config['weights_root'],
                                   experiment_name,
                                   'best%d' % state_dict['save_best_num'],
                                   [E_ema if config['ema'] else None])
            state_dict['save_best_num'] = (state_dict['save_best_num'] +
                                           1) % config['num_best_copies']
        state_dict['best_precise'] = max(state_dict['best_precise'], precise)

        if test_log is not None:
            test_log.log(itr=int(state_dict['itr']), precise=float(precise))
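
The modulo update above turns num_save_copies checkpoint slots into a ring
buffer: save_num cycles through the slots, so the oldest copy is always the
one overwritten next. A minimal standalone sketch of that rotation (a
hypothetical helper, not part of the repo):

def next_save_slot(state_dict, num_copies):
    # Return the current slot suffix and advance the counter, wrapping
    # around so at most num_copies copies ever exist on disk.
    suffix = 'copy%d' % state_dict['save_num']
    state_dict['save_num'] = (state_dict['save_num'] + 1) % num_copies
    return suffix
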
Example #2
import os

import torch
import torch.nn as nn
import torchvision

import utils      # project-local training helpers
import vae_utils  # project-local checkpoint helpers


def save_and_sample(G, E, E_ema, fixed_x, fixed_y, state_dict, config,
                    experiment_name):
    vae_utils.save_weights([E], state_dict, config['weights_root'],
                           experiment_name, None,
                           [E_ema if config['ema'] else None])
    if config['num_save_copies'] > 0:
        vae_utils.save_weights([E], state_dict, config['weights_root'],
                               experiment_name,
                               'copy%d' % state_dict['save_num'],
                               [E_ema if config['ema'] else None])
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']
        G_batch_size = max(config['G_batch_size'], config['batch_size'])
        z_, y_ = utils.prepare_z_y(G_batch_size,
                                   G.dim_z,
                                   config['n_classes'],
                                   device='cuda',
                                   fp16=config['G_fp16'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
        del z_, y_
        which_E = E_ema if config['ema'] and config['use_ema'] else E
        with torch.no_grad():
            if config['parallel']:
                fixed_w = nn.parallel.data_parallel(which_E, fixed_x)
                fixed_Gz = nn.parallel.data_parallel(
                    G, (fixed_w, G.shared(fixed_y)))
            else:
                fixed_w = which_E(fixed_x)
                fixed_Gz = G(fixed_w, G.shared(fixed_y))
        if not os.path.isdir('%s/%s' %
                             (config['samples_root'], experiment_name)):
            os.mkdir('%s/%s' % (config['samples_root'], experiment_name))
        image_filename = '%s/%s/fixed_samples%d.jpg' % (
            config['samples_root'], experiment_name, state_dict['itr'])
        torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                     image_filename,
                                     nrow=int(fixed_Gz.shape[0]**0.5),
                                     normalize=True)
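
Before sampling, the code above refreshes BatchNorm "standing" statistics by
running extra forward passes with freshly drawn noise. In BigGAN-derived
utilities the accumulation helper looks roughly like the sketch below
(illustrative only; the real implementation lives in the project's
utils/vae_utils modules):

def accumulate_standing_stats(net, z, y, n_classes, num_accumulations=16):
    # Average BN statistics over several no-grad forward passes with fresh
    # z and y, then switch back to eval mode for sampling.
    net.train()
    with torch.no_grad():
        for _ in range(num_accumulations):
            z.normal_()
            y.random_(0, n_classes)
            net(z, net.shared(y))
    net.eval()
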
Example #3
import vae_utils  # project-local checkpoint helpers


def test(G, D, E, I, L, KNN, G_ema, I_ema, E_ema, z_, y_, state_dict, config,
         sample, get_inception_metrics, experiment_name, test_log):
    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        vae_utils.accumulate_standing_stats(
            [G_ema, I_ema, E_ema]
            if config['ema'] and config['use_ema'] else [G, I, E], z_, y_,
            config['n_classes'], config['num_standing_accumulations'])
    IS_mean, IS_std, FID = get_inception_metrics(
        sample, config['num_inception_images'], num_splits=10)
    print(
        'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
        % (state_dict['itr'], IS_mean, IS_std, FID))
    # If improved over previous best metric, save the appropriate copy
    if ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS']) or
        (config['which_best'] == 'FID' and FID < state_dict['best_FID'])):
        print('%s improved over previous best, saving checkpoint...' %
              config['which_best'])
        vae_utils.save_weights(
            [G, D, E, I, L], state_dict, config['weights_root'],
            experiment_name, 'best%d' % state_dict['save_best_num'],
            [G_ema, I_ema, E_ema] if config['ema'] else None)
        state_dict['save_best_num'] = (state_dict['save_best_num'] +
                                       1) % config['num_best_copies']
    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)
    if KNN is not None:
        KNN_precision = KNN(
            E_ema if config['ema'] and config['use_ema'] else E)
    else:
        KNN_precision = 0.0
    # Log results to file
    test_log.log(itr=int(state_dict['itr']),
                 IS_mean=float(IS_mean),
                 IS_std=float(IS_std),
                 FID=float(FID),
                 KNN_precision=float(KNN_precision))
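
The checkpoint decision above branches on config['which_best']: Inception
Score improves upward while FID improves downward. A hypothetical helper (not
in the repo) that isolates the comparison:

def improved(config, state_dict, IS_mean, FID):
    # Higher IS is better; lower FID is better.
    if config['which_best'] == 'IS':
        return IS_mean > state_dict['best_IS']
    return FID < state_dict['best_FID']
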
Example #4
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils      # project-local training helpers
import vae_utils  # project-local checkpoint helpers
# Generator, Encoder and patch_replication_callback are assumed to be imported
# from the project's model and sync-batchnorm modules.


def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
    print('Loading pretrained G for dir %s ...' % config['pretrained_G_dir'])
    pretrained_dict = torch.load(config['pretrained_G_dir'])
    G_dict = G.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in G_dict}
    G_dict.update(pretrained_dict)
    G.load_state_dict(G_dict)

    E = Encoder(**config).to(device)
    utils.toggle_grad(G, False)
    utils.toggle_grad(E, True)

    class G_E(nn.Module):
        def __init__(self):
            super(G_E, self).__init__()
            self.G = G
            self.E = E

        def forward(self, w, y):
            with torch.no_grad():
                net = self.G(w, self.G.shared(y))
            net = self.E(net)
            return net

    GE = G_E()

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for E with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        e_ema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        E_ema, e_ema = None, None

    print(G)
    print(E)
    print('Number of params in G: {} E: {}'.format(
        *[sum(p.data.nelement() for p in net.parameters())
          for net in [G, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [E], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [e_ema] if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GE = nn.DataParallel(GE)
        if config['cross_replica']:
            patch_replication_callback(GE)

    # Prepare a logger for training metrics.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    def train():
        E.optim.zero_grad()
        z_.sample_()
        y_.sample_()

        net = GE(z_[:config['batch_size']], y_[:config['batch_size']])
        loss = F.l1_loss(z_[:config['batch_size']], net)
        loss.backward()
        if config["E_ortho"] > 0.0:
            print('using modified ortho reg in E')
            utils.ortho(E, config['E_ortho'])
        E.optim.step()
        out = {'loss': float(loss.item())}
        return out

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        for i in range(100000):  # fixed iteration budget per epoch
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and E are in training mode, just in case they got
            # set to eval (this matters for any BatchNorm buffers).
            G.train()
            E.train()
            if config['ema']:
                E_ema.train()
            metrics = train()
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(E, 'E')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                vae_utils.save_weights([E], state_dict, config['weights_root'],
                                       experiment_name,
                                       'copy%d' % state_dict['save_num'],
                                       [E_ema if config['ema'] else None])
                state_dict['save_num'] = (state_dict['save_num'] +
                                          1) % config['num_save_copies']
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
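
In BigGAN-style training scripts, run(config) is driven from a small entry
point that turns the parsed command-line arguments into the plain dict the
function expects. A sketch under that assumption (the parser helper's exact
name may differ in this repo):

def main():
    # Parse command-line flags into the config dict expected by run().
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    run(config)


if __name__ == '__main__':
    main()
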
Example #5
import os

import torch
import torch.nn as nn
import torchvision

import utils      # project-local training helpers
import vae_utils  # project-local checkpoint and sampling helpers


def save_and_sample(G, D, E, I, L, G_ema, I_ema, E_ema, z_, y_, fixed_z,
                    fixed_y, fixed_x, state_dict, config, experiment_name):
    vae_utils.save_weights([G, D, E, I, L], state_dict, config['weights_root'],
                           experiment_name, None,
                           [G_ema, I_ema, E_ema] if config['ema'] else None)
    # Save an additional copy to mitigate accidental corruption if process
    # is killed during a save (it's happened to me before -.-)
    if config['num_save_copies'] > 0:
        vae_utils.save_weights(
            [G, D, E, I, L], state_dict, config['weights_root'],
            experiment_name, 'copy%d' % state_dict['save_num'],
            [G_ema, I_ema, E_ema] if config['ema'] else None)
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    if config['ema'] and config['use_ema']:
        which_G, which_E, which_I = G_ema, E_ema, I_ema
    else:
        which_G, which_E, which_I = G, E, I

    # Accumulate standing statistics?
    if config['accumulate_stats']:
        vae_utils.accumulate_standing_stats(
            [which_G, which_I, which_E], z_, y_, config['n_classes'],
            config['num_standing_accumulations'])

    # Save a random sample sheet with fixed z and y
    with torch.no_grad():
        if config['parallel']:
            fixed_inv = nn.parallel.data_parallel(which_I, fixed_z)
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_inv, which_G.shared(fixed_y)))
            fixed_en = nn.parallel.data_parallel(which_E, fixed_x)
            fixed_Gx = nn.parallel.data_parallel(
                which_G, (fixed_en, which_G.shared(fixed_y)))
        else:
            fixed_inv = which_I(fixed_z)
            fixed_Gz = which_G(fixed_inv, which_G.shared(fixed_y))
            fixed_en = which_E(fixed_x)
            fixed_Gx = which_G(fixed_en, which_G.shared(fixed_y))
    if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)):
        os.mkdir('%s/%s' % (config['samples_root'], experiment_name))
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    vae_image_filename = '%s/%s/fixed_vae%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)
    torchvision.utils.save_image(fixed_Gx.float().cpu(),
                                 vae_image_filename,
                                 nrow=int(fixed_Gx.shape[0]**0.5),
                                 normalize=True)
    # For now, every time we save, also save sample sheets
    vae_utils.sample_sheet(
        which_G,
        which_I,
        classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
        num_classes=config['n_classes'],
        samples_per_class=10,
        parallel=config['parallel'],
        samples_root=config['samples_root'],
        experiment_name=experiment_name,
        folder_number=state_dict['itr'],
        z_=z_)
    # Also save interp sheets
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        vae_utils.interp_sheet(which_G,
                               which_I,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=state_dict['itr'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
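
Both sample sheets above pass nrow=int(N ** 0.5) so torchvision lays the batch
out in a roughly square grid. A self-contained illustration with dummy data:

import torch
import torchvision

imgs = torch.rand(64, 3, 32, 32)  # dummy batch of 64 RGB images
torchvision.utils.save_image(imgs, 'grid.jpg',
                             nrow=int(imgs.shape[0] ** 0.5),  # 8 x 8 sheet
                             normalize=True)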