Пример #1
0
def save_and_sample(G, E, E_ema, fixed_x, fixed_y, state_dict, config,
                    experiment_name):
    """Checkpoint the encoder and save an image grid generated from fixed_x.

    Args:
        G: generator; called as ``G(w, G.shared(y))`` and passed to
            ``utils.accumulate_standing_stats``.
        E: encoder whose weights are saved.
        E_ema: EMA copy of the encoder. Saved alongside ``E`` when
            ``config['ema']`` is set, and used to encode ``fixed_x`` when
            ``config['use_ema']`` is set as well.
        fixed_x: fixed input batch fed to the encoder.
        fixed_y: fixed labels, embedded through ``G.shared``.
        state_dict: training state; ``'itr'`` is read for the file name and
            ``'save_num'`` is advanced when rotating copies are enabled.
        config: configuration mapping (weights/samples roots, EMA flags, ...).
        experiment_name: sub-folder name used for weights and samples.
    """
    ema_nets = [E_ema if config['ema'] else None]

    # Main checkpoint (no suffix).
    vae_utils.save_weights([E], state_dict, config['weights_root'],
                           experiment_name, None, ema_nets)

    # Optional rotating copy; the guard also protects the modulo below from a
    # division by zero when no copies are requested.
    if config['num_save_copies'] > 0:
        vae_utils.save_weights([E], state_dict, config['weights_root'],
                               experiment_name,
                               'copy%d' % state_dict['save_num'], ema_nets)
        state_dict['save_num'] = ((state_dict['save_num'] + 1) %
                                  config['num_save_copies'])

    # Sampling is independent of how many checkpoint copies are kept, so it
    # runs unconditionally (it used to be nested under the guard above and was
    # silently skipped when num_save_copies == 0).
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device='cuda',
                               fp16=config['G_fp16'])
    utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])
    del z_, y_

    # Encode with the EMA encoder only when EMA is both kept and requested.
    which_E = E_ema if config['ema'] and config['use_ema'] else E
    with torch.no_grad():
        if config['parallel']:
            fixed_w = nn.parallel.data_parallel(which_E, fixed_x)
            fixed_Gz = nn.parallel.data_parallel(
                G, (fixed_w, G.shared(fixed_y)))
        else:
            fixed_w = which_E(fixed_x)
            fixed_Gz = G(fixed_w, G.shared(fixed_y))

    # makedirs(exist_ok=True) creates missing parents and does not fail if the
    # directory appears between a check and the creation.
    os.makedirs('%s/%s' % (config['samples_root'], experiment_name),
                exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    # Lay the batch out as a roughly square grid.
    torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)
Пример #2
0
def run(config):
    """Build the G/D/Invert/Encoder/LatentBinder models and run the training loop.

    The function derives extra settings from ``config``, constructs the
    networks (plus EMA copies when ``config['ema']`` is set), optionally
    resumes from saved weights, prepares loggers, data loaders, inception
    metrics and a KNN evaluator, and then iterates over the first data loader
    for ``config['num_epochs']`` epochs, periodically saving/sampling and
    testing.

    Args:
        config: configuration mapping; it is mutated in place with derived
            entries (resolution, n_classes, activations, skip_init, ...).
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = import_module('Network.' + config['model'])
    # Fall back to a name derived from the config when none is given.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    L = model.LatentBinder(**config).to(device)
    I = Invert.Invert(**config).to(device)
    E = Encoder.Encoder(**config).to(device)
    # Decoder composes the five networks into one module; it is the object
    # that gets wrapped for data parallelism further down.
    Decoder = model.Decoder(I, E, G, D, L).to(device)

    # If using EMA, prepare it
    # Each EMA copy is built with skip_init/no_optim overridden to True and is
    # paired with a utils.ema updater (gema/iema/eema).
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(name='G_ema',
                                **{
                                    **config, 'skip_init': True,
                                    'no_optim': True
                                }).to(device)
        gema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
        print('Preparing EMA for Invert with decay of {}'.format(
            config['ema_decay']))
        I_ema = Invert.Invert(name='Invert_ema',
                              **{
                                  **config, 'skip_init': True,
                                  'no_optim': True
                              }).to(device)
        iema = utils.ema(I, I_ema, config['ema_decay'], config['ema_start'])
        print('Preparing EMA for Encoder with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder.Encoder(name='Encoder_ema',
                                **{
                                    **config, 'skip_init': True,
                                    'no_optim': True
                                }).to(device)
        eema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, gema, I_ema, iema, E_ema, eema = None, None, None, None, None, None

    # FP16? The other components of Decoder should be cast to half as well,
    # but since FP16 is not used here this is simply not implemented.
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    print(G)
    print(D)
    print(I)
    print(E)
    print(L)
    # Parameter counts, in the same order as the labels in the format string.
    print(
        'Number of params in G: {} D: {} Invert: {} Encoder: {} LatentBinder: {}'
        .format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, I, E, L]
        ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    # This happens before the DataParallel wrapping below, on the bare modules.
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [G, D, I, E, L], state_dict, config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [G_ema, I_ema, E_ema] if config['ema'] else None)

    # If parallel, parallelize the Decoder module
    if config['parallel']:
        # Decoder = nn.DataParallel(Decoder)
        # Using custom dataparallel to save GPU memory
        Decoder = parallel_utils.DataParallelModel(Decoder)
        if config['cross_replica']:
            patch_replication_callback(Decoder)

    # Prepare loggers for stats; test_log holds test metrics,
    # train_log holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    # Logs are only reinitialized on a fresh (non-resumed) run.
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['data_root'],
        config['no_fid'])
    # A VGG-based reconstruction loss was considered here, but VGG is
    # pretrained on ImageNet, so it is not used.
    # vgg = load_vgg_from_local(parallel=False)
    # Prepare KNN for evaluating encoder.
    KNN = vae_utils.KNN(loaders[0], anchor_num=10, K=4)
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare fake labels for encoder.
    _, ey_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    # Fixed input batch taken from the first loader.
    fixed_x, _ = vae_utils.prepare_fixed_x(loaders[0], G_batch_size, config,
                                           experiment_name, device)
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        # train = train_vae_fns.VAE_training_function(G, D, E, I, L, Decoder, z_, y_, ey_,
        #                                             [gema, iema, eema], state_dict, vgg, config)
        train = train_vae_fns.parallel_training_function(
            G, D, E, I, L, Decoder, z_, y_, ey_, [gema, iema, eema],
            state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_vae_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    # The EMA networks are used only when EMA is both kept and requested.
    sample = functools.partial(
        vae_utils.sample,
        Invert=(I_ema if config['ema'] and config['use_ema'] else I),
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure all networks are in training mode, just in case they
            # got set to eval by an earlier save/test step.
            G.train()
            D.train()
            I.train()
            E.train()
            L.train()
            if config['ema']:
                G_ema.train()
                I_ema.train()
                E_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            # Only the images are handed to the training step; y is moved to
            # the device but not used afterwards in this loop.
            metrics = train(x)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D'),
                                  **utils.get_SVs(I, 'Invert'),
                                  **utils.get_SVs(E, 'Encoder'),
                                  **utils.get_SVs(L, 'LatentBinder')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                    I.eval()
                    E.eval()
                    if config['ema']:
                        G_ema.eval()
                        I_ema.eval()
                        E_ema.eval()
                train_vae_fns.save_and_sample(G, D, E, I, L, G_ema, I_ema,
                                              E_ema, z_, y_, fixed_z, fixed_y,
                                              fixed_x, state_dict, config,
                                              experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    # NOTE(review): unlike the save branch above, the EMA
                    # copies are not switched to eval here -- confirm intent.
                    G.eval()
                    I.eval()
                    E.eval()
                train_vae_fns.test(G, D, E, I, L, KNN, G_ema, I_ema, E_ema, z_,
                                   y_, state_dict, config, sample,
                                   get_inception_metrics, experiment_name,
                                   test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Пример #3
0
def run(config):
    """Train the generator to regress target images from latents with an MSE loss.

    The loader yields ``(img, z, w)`` triples; each iteration generates
    ``G(w, G.shared(y_))`` for freshly sampled labels ``y_`` and minimises the
    mean squared error against ``img``. Weights/samples are saved and metrics
    computed at the intervals given in ``config``; the learning rate of
    ``G.optim`` is decayed by a StepLR schedule once per epoch.

    Args:
        config: configuration mapping; it is mutated in place with derived
            entries (resolution, n_classes, activations, skip_init, ...).
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Fall back to a name derived from the config when none is given.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = Generator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    print(G)
    # The only network here is G, so label the count accordingly.
    print('Number of params in G: {}'.format(
        sum(p.data.nelement() for p in G.parameters())))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config,
        'best_precise': 0.0
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        vae_utils.load_weights(
            [G], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [G_ema] if config['ema'] else [None])

    class Wrapper(nn.Module):
        """Wrap G so that label embedding and generation run in one forward."""

        def __init__(self):
            super(Wrapper, self).__init__()
            self.G = G

        def forward(self, w, y):
            return self.G(w, self.G.shared(y))

    W = Wrapper()

    # If parallel, parallelize the wrapped generator
    if config['parallel']:
        W = nn.DataParallel(W)
        if config['cross_replica']:
            patch_replication_callback(W)

    # Prepare loggers for stats; test_log holds test metrics,
    # train_log holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['data_root'],
        config['no_fid'])
    z_, y_ = utils.prepare_z_y(config['batch_size'],
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Fixed latent/label pair used for the periodic sample grid.
    fixed_w, fixed_y = utils.prepare_z_y(config['batch_size'],
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_w.sample_()
    fixed_y.sample_()
    # Stepped once per epoch at the bottom of the epoch loop.
    G_scheduler = torch.optim.lr_scheduler.StepLR(G.optim,
                                                  step_size=50,
                                                  gamma=0.1)
    MSE = torch.nn.MSELoss(reduction='mean')

    def train(w, img):
        """Run one optimisation step of G on (w, img); return the loss dict."""
        y_.sample_()
        G.optim.zero_grad()
        x = W(w, y_)
        loss = MSE(x, img)
        loss.backward()
        # The regulariser is applied to G with strength G_ortho, so it must be
        # gated on G_ortho as well (it used to be gated on E_ortho).
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G.
            print('using modified ortho reg in G')
            utils.ortho(G, config['G_ortho'])
        G.optim.step()
        # The leading space in the key is kept as is: it is the metric name
        # handed to the logger.
        out = {' loss': float(loss.item())}
        if config['ema']:
            ema.update(state_dict['itr'])
        del loss, x
        return out

    class Embed(nn.Module):
        """Frozen bias-free 120x120 linear map loaded from a saved matrix."""

        def __init__(self):
            super(Embed, self).__init__()
            # NOTE(review): hard-coded absolute path; consider moving it into
            # the config.
            embed = np.load('/ghome/fengrl/home/FGAN/embed_ema.npy')
            self.dense = nn.Linear(120, 120, bias=False)
            self.embed = torch.tensor(embed, requires_grad=False)
            self.dense.load_state_dict({'weight': self.embed})
            for param in self.dense.parameters():
                param.requires_grad = False

        def forward(self, z):
            return self.dense(z)

    embedding = Embed().to(device)
    fixed_w = embedding(fixed_w)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        sample_with_embed,
        embed=embedding,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    batch_size = config['batch_size'] * config['num_D_steps'] * config[
        'num_D_accumulations']
    loader = sampled_ssgan.get_SSGAN_sample_loader(
        **{
            **config, 'batch_size': batch_size,
            'start_itr': state_dict['itr'],
            'is_slice': False
        })

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loader, displaytype='eta')
        else:
            pbar = tqdm(loader)
        for img, _, w in pbar:
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G is in training mode, just in case it got set to eval
            G.train()
            if config['ema']:
                G_ema.train()

            img, w = img.to(device), w.to(device)
            # The loader batch holds num_D_steps * num_D_accumulations chunks
            # of batch_size samples, but only the first chunk is trained on.
            # NOTE(review): the remaining chunks are discarded -- confirm that
            # this is intended.
            img_chunks = torch.split(img, config['batch_size'])
            w_chunks = torch.split(w, config['batch_size'])
            metrics = train(w_chunks[0], img_chunks[0])
            del img, w, img_chunks, w_chunks
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switchin e to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, None, G_ema, z_, y_, fixed_w,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switchin G to eval mode...')
                    G.eval()
                train_fns.test(G, None, G_ema, z_, y_, state_dict, config,
                               sample, get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
        G_scheduler.step()
Пример #4
0
def run(config):
    """Train an encoder E to invert a frozen, pretrained generator G.

    Each step samples ``z_``/``y_``, generates an image with G under
    ``torch.no_grad()``, encodes it with E and minimises the L1 distance
    between the encoding and the latent that produced the image. Encoder
    weights (and the EMA copy when enabled) are saved every
    ``config['save_every']`` iterations.

    Args:
        config: configuration mapping; it is mutated in place with derived
            entries (resolution, n_classes, activations, skip_init, ...).
    """
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = vae_utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Fall back to a name derived from the config when none is given.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    # G is built without init and without an optimizer: its weights come from
    # the pretrained checkpoint and it is never updated here.
    G = Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
    print('Loading pretrained G for dir %s ...' % config['pretrained_G_dir'])
    # NOTE(review): torch.load unpickles the file; only point
    # pretrained_G_dir at trusted checkpoints.
    pretrained_dict = torch.load(config['pretrained_G_dir'])
    G_dict = G.state_dict()
    # Keep only the checkpoint entries that G actually has.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in G_dict}
    G_dict.update(pretrained_dict)
    G.load_state_dict(G_dict)

    E = Encoder(**config).to(device)
    utils.toggle_grad(G, False)
    utils.toggle_grad(E, True)

    class G_E(nn.Module):
        """Generate with G (no gradients), then encode the result with E."""

        def __init__(self):
            super(G_E, self).__init__()
            self.G = G
            self.E = E

        def forward(self, w, y):
            with torch.no_grad():
                net = self.G(w, self.G.shared(y))
            net = self.E(net)
            return net

    GE = G_E()

    # If using EMA, prepare it
    # E_ema is the averaged network, e_ema the updater that maintains it.
    if config['ema']:
        print('Preparing EMA for E with decay of {}'.format(
            config['ema_decay']))
        E_ema = Encoder(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        e_ema = utils.ema(E, E_ema, config['ema_decay'], config['ema_start'])
    else:
        E_ema, e_ema = None, None

    print(G)
    print(E)
    print('Number of params in G: {} E: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        # Pass the EMA *network* (E_ema), matching what save_weights below
        # stores; the updater object e_ema holds no checkpointed weights.
        vae_utils.load_weights(
            [E], state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            [E_ema] if config['ema'] else None)

    # If parallel, parallelize the G->E module
    if config['parallel']:
        GE = nn.DataParallel(GE)
        if config['cross_replica']:
            patch_replication_callback(GE)

    # Prepare logger for training metrics.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    def train():
        """Run one optimisation step of E; return the loss dict."""
        E.optim.zero_grad()
        z_.sample_()
        y_.sample_()

        # Only the first batch_size samples of the (possibly larger) G batch
        # are used.
        net = GE(z_[:config['batch_size']], y_[:config['batch_size']])
        loss = F.l1_loss(z_[:config['batch_size']], net)
        loss.backward()
        if config["E_ortho"] > 0.0:
            print('using modified ortho reg in E')
            utils.ortho(E, config['E_ortho'])
        E.optim.step()
        # Keep the averaged encoder in sync; without this the saved E_ema
        # would never change from its initial state.
        if config['ema']:
            e_ema.update(state_dict['itr'])
        out = {'loss': float(loss.item())}
        return out

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # There is no data loader: an "epoch" is a fixed number of sampled
        # batches.
        for _ in range(100000):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and E are in training mode, just in case they got
            # set to eval.
            G.train()
            E.train()
            if config['ema']:
                E_ema.train()
            metrics = train()
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(E, 'E')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                vae_utils.save_weights([E], state_dict, config['weights_root'],
                                       experiment_name,
                                       'copy%d' % state_dict['save_num'],
                                       [E_ema if config['ema'] else None])
                # Rotate the copy index only when copies are configured; with
                # num_save_copies == 0 the modulo would divide by zero.
                if config['num_save_copies'] > 0:
                    state_dict['save_num'] = ((state_dict['save_num'] + 1) %
                                              config['num_save_copies'])
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Пример #5
0
def run(config):
    """Load a generator from saved weights and produce the requested outputs.

    Depending on the flags in ``config`` this saves an NPZ of samples
    (``sample_npz``), conditional sample sheets (``sample_sheets``),
    interpolation sheets (``sample_interps``), a random sample sheet
    (``sample_random``), and prints Inception Score / FID once
    (``sample_inception_metrics``) and/or for a sweep of noise variances
    (``sample_trunc_curves``, given as the string ``'start_step_end'``).

    Args:
        config: Dict of options. Derived entries (``resolution``,
            ``n_classes``, ``G_activation``, ``D_activation``, ``skip_init``,
            ``no_optim``) are written into it, and when
            ``config['config_from_name']`` is set most entries are
            overwritten from ``state_dict['config']`` after the state dict
            has been loaded.

    Returns:
        None. Results are written to disk and/or printed.
    """
    # Local import: only needed to make sure the sample directory exists.
    import os

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Items which we might want to overwrite from the command line are
        # kept as passed in; everything else comes from the loaded config.
        keep_from_command_line = ('z_var', 'base_root', 'batch_size',
                                  'G_batch_size', 'use_ema', 'G_eval_mode')
        for item in state_dict['config']:
            if item not in keep_from_command_line:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).to(device)
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    # NOTE(review): when config['use_ema'] is set but config['ema'] is not,
    # G is passed in neither slot below -- confirm load_weights handles that
    # case as intended.
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    # Sample function, shared by every sampling path below
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Make sure the output directory exists *before* sampling, so a
        # missing folder fails fast instead of after all images are drawn.
        os.makedirs('%s/%s' % (config['samples_root'], experiment_name),
                    exist_ok=True)
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        num_batches = int(
            np.ceil(config['sample_num_npz'] / float(G_batch_size)))
        for _ in trange(num_batches):
            with torch.no_grad():
                images, labels = sample()
            # Map images from [-1, 1] to [0, 255] before the uint8 cast
            x.append(np.uint8(255 * (images.cpu().numpy() + 1) / 2.))
            y.append(labels.cpu().numpy())
        # The last batch may overshoot the requested count; trim it.
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, x=x, y=y)

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        # Three sheets: vary both z and y, vary z only, vary y only
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device=device)
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        # No gradients are needed for an image dump (matches the npz branch)
        with torch.no_grad():
            images, _ = sample()
        os.makedirs('%s/%s' % (config['samples_root'], experiment_name),
                    exist_ok=True)
        torchvision.utils.save_image(images.float(),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    # NOTE: this is prepared unconditionally, even if neither metric flag
    # below is set.
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        """Compute IS/FID with the current z_ settings and print a summary."""
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        # Derive the number of points explicitly rather than relying on
        # np.arange(start, end + step, step): with a float step the length of
        # that range depends on rounding and can gain a spurious extra value
        # past the intended last point. The small tolerance absorbs that
        # rounding noise while keeping the same points otherwise.
        num_variances = int(np.ceil((end - start) / step - 1e-9)) + 1
        for var in start + step * np.arange(num_variances):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()