Ejemplo n.º 1
0
def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics,
         experiment_name, test_log):
  """Score the generator, checkpoint on improvement, and log the metrics.

  Calls ``get_inception_metrics`` to obtain (IS mean, IS std, FID). When
  neither the IS mean nor the FID is NaN, a new best IS and a new best FID
  each trigger their own ``utils.save_weights`` call (named 'best_IS<n>' and
  'best_FID<n>'); both share the rotating ``state_dict['save_best_num']``
  counter. ``state_dict['best_IS']`` and ``state_dict['best_FID']`` are then
  updated and the metrics are handed to ``test_log.log``.

  Returns:
    The tuple ``(IS_mean, IS_std, FID)``.
  """
  def checkpoint_best(label, name_template):
    # Save under the current slot, then advance the shared rotating slot index.
    print('%s improved over previous best, saving checkpoint...' % label)
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name,
                       name_template % state_dict['save_best_num'],
                       G_ema if config['ema'] else None)
    next_slot = state_dict['save_best_num'] + 1
    state_dict['save_best_num'] = next_slot % config['num_best_copies']

  print('Gathering inception metrics...')
  if config['accumulate_stats']:
    stats_G = G_ema if config['ema'] and config['use_ema'] else G
    utils.accumulate_standing_stats(stats_G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])

  IS_mean, IS_std, FID = get_inception_metrics(
      sample,
      step=state_dict['shown_images'],
      num_inception_images=config['num_inception_images'],
      num_splits=10)
  print('shown_images %d: Inception Score is %3.3f +/- %3.3f, FID is %5.4f' %
        (state_dict['shown_images'], IS_mean, IS_std, FID))

  # A NaN in either metric disables checkpointing for this evaluation.
  metrics_valid = not (math.isnan(IS_mean) or math.isnan(FID))
  if metrics_valid and IS_mean > state_dict['best_IS']:
    checkpoint_best('IS', 'best_IS%d')
  if metrics_valid and FID < state_dict['best_FID']:
    checkpoint_best('FID', 'best_FID%d')

  # Track the running bests regardless of whether a checkpoint was written.
  state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
  state_dict['best_FID'] = min(state_dict['best_FID'], FID)

  # Log results to file
  test_log.log(itr=int(state_dict['itr']),
               IS_mean=float(IS_mean),
               IS_std=float(IS_std),
               FID=float(FID))
  return IS_mean, IS_std, FID
Ejemplo n.º 2
0
def test(G, D, G_ema, z_, y_, state_dict, config, sample,
         get_inception_metrics, experiment_name, test_log):
    """Compute inception metrics, save a 'best' checkpoint, and log results.

    Only the metric selected by ``config['which_best']`` ('IS' or 'FID')
    decides whether ``utils.save_weights`` is called; any other value never
    saves. The running bests in ``state_dict`` are updated for both metrics
    either way, and the metrics are passed to ``test_log.log``.
    """
    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        stats_G = G_ema if config['ema'] and config['use_ema'] else G
        utils.accumulate_standing_stats(stats_G, z_, y_,
                                        config['n_classes'],
                                        config['num_standing_accumulations'])

    IS_mean, IS_std, FID = get_inception_metrics(
        sample, config['num_inception_images'], num_splits=10)
    report = (state_dict['itr'], IS_mean, IS_std, FID)
    print(
        'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
        % report)

    # Decide whether the tracked metric beat its previous best.
    tracked = config['which_best']
    if tracked == 'IS':
        improved = IS_mean > state_dict['best_IS']
    elif tracked == 'FID':
        improved = FID < state_dict['best_FID']
    else:
        improved = False

    if improved:
        print('%s improved over previous best, saving checkpoint...' % tracked)
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name,
                           'best%d' % state_dict['save_best_num'],
                           G_ema if config['ema'] else None)
        # Rotate through num_best_copies checkpoint slots.
        next_slot = state_dict['save_best_num'] + 1
        state_dict['save_best_num'] = next_slot % config['num_best_copies']

    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log results to file
    test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean),
                 IS_std=float(IS_std), FID=float(FID))
Ejemplo n.º 3
0
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                    state_dict, config, experiment_name):
    """Checkpoint the networks and write sample images for this iteration.

    Steps, in order:
      1. ``utils.save_weights`` for the main checkpoint, plus a rotating
         'copy<n>' checkpoint when ``config['num_save_copies'] > 0``.
      2. Optionally ``utils.accumulate_standing_stats`` on the sampling
         generator.
      3. Generate images from ``fixed_z`` / ``fixed_y`` and save them as
         ``<samples_root>/<experiment_name>/fixed_samples<itr>.jpg``.
      4. ``utils.sample_sheet`` and three ``utils.interp_sheet`` calls.

    The sample directory is created with ``os.makedirs(exist_ok=True)`` so a
    missing ``samples_root`` or a concurrent creator no longer raises.
    """
    ema_for_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, None, ema_for_save)
    # Save an additional copy to mitigate accidental corruption if the
    # process is killed during a save.
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name,
                           'copy%d' % state_dict['save_num'],
                           ema_for_save)
        state_dict['save_num'] = (
            state_dict['save_num'] + 1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Accumulate standing statistics on the generator we sample from.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Save a sample sheet generated from the fixed z and y
    with torch.no_grad():
        if config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_z, which_G.shared(fixed_y)))
        else:
            fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # makedirs(exist_ok=True) avoids the isdir/mkdir race and also creates a
    # missing samples_root instead of failing.
    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'],
                                                    experiment_name,
                                                    state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
                                 nrow=int(fixed_Gz.shape[0] ** 0.5),
                                 normalize=True)

    # For now, every time we save, also save sample sheets
    utils.sample_sheet(which_G,
                       classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
                       num_classes=config['n_classes'],
                       samples_per_class=10, parallel=config['parallel'],
                       samples_root=config['samples_root'],
                       experiment_name=experiment_name,
                       folder_number=state_dict['itr'],
                       z_=z_)

    # Also save interp sheets: (fix_z, fix_y) = (F, F), (F, T), (T, F)
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        utils.interp_sheet(which_G,
                           num_per_sheet=16,
                           num_midpoints=8,
                           num_classes=config['n_classes'],
                           parallel=config['parallel'],
                           samples_root=config['samples_root'],
                           experiment_name=experiment_name,
                           folder_number=state_dict['itr'],
                           sheet_number=0,
                           fix_z=fix_z, fix_y=fix_y, device='cuda')
Ejemplo n.º 4
0
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config,
                    experiment_name):
    """Checkpoint the networks and write sample images for this iteration.

    Writes the main checkpoint (and a rotating 'copy<n>' checkpoint when
    ``config['num_save_copies'] > 0``), optionally calls
    ``utils.accumulate_standing_stats`` (zeroing ``y_`` in place first when
    ``config['conditional']`` is falsy), saves an image generated from
    ``fixed_z`` / ``fixed_y``, and finally calls ``utils.sample_sheet`` and
    ``utils.interp_sheet``.

    The sample directory is created with ``os.makedirs(exist_ok=True)`` so a
    missing ``samples_root`` or a concurrent creator no longer raises.
    """
    ema_for_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, None, ema_for_save)
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, 'copy%d' % state_dict['save_num'],
                           ema_for_save)
        # Rotate through num_save_copies backup slots.
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']

    # Generator used for every sample below (EMA weights when enabled).
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    if config['accumulate_stats']:
        if not config['conditional']:
            # In-place: the caller's y_ tensor is zeroed as well.
            y_.zero_()
        utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    with torch.no_grad():
        if config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_z, which_G.shared(fixed_y)))
        else:
            fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # makedirs(exist_ok=True) avoids the isdir/mkdir race and also creates a
    # missing samples_root instead of failing.
    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)

    utils.sample_sheet(
        which_G,
        classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
        num_classes=config['n_classes'],
        samples_per_class=10,
        parallel=config['parallel'],
        samples_root=config['samples_root'],
        experiment_name=experiment_name,
        folder_number=state_dict['itr'],
        z_=z_)

    # Interp sheets for (fix_z, fix_y) = (F, F), (F, T), (T, F)
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        utils.interp_sheet(which_G,
                           num_per_sheet=16,
                           num_midpoints=8,
                           num_classes=config['n_classes'],
                           parallel=config['parallel'],
                           samples_root=config['samples_root'],
                           experiment_name=experiment_name,
                           folder_number=state_dict['itr'],
                           sheet_number=0,
                           fix_z=fix_z,
                           fix_y=fix_y,
                           device='cuda')
Ejemplo n.º 5
0
def test(G, D, GD, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics,
         experiment_name, test_log, acc_metrics, acc_itrs):
    """Measure D accuracy on both data splits, score G, checkpoint, and log.

    For the validation loader and then the training loader, averages over
    batches the fraction of ``GD(None, None, x=..., dy=...)`` outputs that
    are > 0. It then computes inception metrics, saves a 'best' checkpoint
    when the metric named by ``config['which_best']`` improved, updates the
    running bests in ``state_dict``, and logs everything (including each
    ``acc_metrics`` value divided by ``acc_itrs``) through ``test_log.log``.
    """
    def positive_output_rate(use_train_split):
        # D is put in eval mode for the pass and back in train mode afterwards.
        D.eval()
        loader = utils.get_data_loaders(
            **{**config, 'train': use_train_split,
               'use_multiepoch_sampler': False, 'load_in_mem': False})[0]
        with torch.no_grad():
            per_batch = [
                (GD(None, None, x=x, dy=y, policy=config['DiffAugment']) > 0)
                .float().mean().item()
                for x, y in loader
            ]
        D.train()
        return np.mean(per_batch)

    print('Calculating validation accuracy...')
    D_acc_val = positive_output_rate(False)

    print('Calculating training accuracy...')
    D_acc_train = positive_output_rate(True)

    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        stats_G = G_ema if config['ema'] and config['use_ema'] else G
        utils.accumulate_standing_stats(stats_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
    IS_mean, IS_std, FID = get_inception_metrics(
        sample, config['num_inception_images'], num_splits=10)
    print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' %
          (state_dict['itr'], IS_mean, IS_std, FID))

    # Decide whether the tracked metric beat its previous best.
    tracked = config['which_best']
    if tracked == 'IS':
        improved = IS_mean > state_dict['best_IS']
    elif tracked == 'FID':
        improved = FID < state_dict['best_FID']
    else:
        improved = False

    if improved:
        print('%s improved over previous best, saving checkpoint...' % tracked)
        # A numbered suffix is only used when more than one copy is kept.
        if config['num_best_copies'] > 1:
            checkpoint_name = 'best%d' % state_dict['save_best_num']
        else:
            checkpoint_name = 'best'
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, checkpoint_name,
                           G_ema if config['ema'] else None)
        next_slot = state_dict['save_best_num'] + 1
        state_dict['save_best_num'] = next_slot % config['num_best_copies']

    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log results to file; accumulated metrics are averaged over acc_itrs.
    averaged_metrics = {name: total / acc_itrs
                        for name, total in acc_metrics.items()}
    test_log.log(itr=int(state_dict['itr']),
                 IS_mean=float(IS_mean),
                 IS_std=float(IS_std),
                 FID=float(FID),
                 D_acc_val=D_acc_val,
                 D_acc_train=D_acc_train,
                 **averaged_metrics)
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config,
                    experiment_name):
    """Checkpoint the networks and save an image generated from fixed inputs.

    Writes the main checkpoint (and a rotating 'copy<n>' checkpoint when
    ``config['num_save_copies'] > 0``), optionally calls
    ``utils.accumulate_standing_stats``, then generates images from
    ``fixed_z`` / ``fixed_y`` and saves them as
    ``<samples_root>/<experiment_name>/fixed_samples<itr>.jpg``.

    When ``config['use_dog_cnt']`` and ``config['G_shared']`` are both set,
    the two columns of ``fixed_y`` are embedded separately (``shared`` and
    ``dog_cnt_shared``) and concatenated along dim 1.

    Changes from the previous version:
      * class embeddings come from the same generator that produces the
        images (``which_G``) rather than always from ``G``, so EMA sampling
        no longer mixes EMA and non-EMA weights;
      * the sample directory is created with ``os.makedirs(exist_ok=True)``,
        which tolerates concurrent creation and a missing ``samples_root``.
    """
    ema_for_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, None, ema_for_save)
    # Save an additional copy to mitigate accidental corruption if the
    # process is killed during a save.
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, 'copy%d' % state_dict['save_num'],
                           ema_for_save)
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Accumulate standing statistics on the generator we sample from.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Save a sample sheet generated from the fixed z and y
    with torch.no_grad():
        # Embed the labels with the generator that will consume them.
        if config['use_dog_cnt'] and config['G_shared']:
            ye0 = which_G.shared(fixed_y[:, 0])
            ye1 = which_G.dog_cnt_shared(fixed_y[:, 1])
            gyc = torch.cat([ye0, ye1], 1)
        else:
            gyc = which_G.shared(fixed_y)

        if config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, gyc))
        else:
            fixed_Gz = which_G(fixed_z, gyc)

    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)
    # NOTE(review): this snippet appears to be cut off after the next line;
    # num_classes is not used by anything visible here.
    num_classes = int(np.minimum(120, config['n_classes']))
Ejemplo n.º 7
0
def test(G, D, G_ema, z_, y_, state_dict, config, sample,
         get_inception_metrics, experiment_name, test_log):
    """Compute inception metrics, checkpoint on improvement, log scalars.

    A NaN FID is replaced by 99 before it is printed, compared, or logged.
    When the metric named by ``config['which_best']`` improves, two
    checkpoints are written: the rotating 'best<n>' slot and 'currentbest'.
    IS mean, IS std and FID are then logged through ``test_log.add_scalar``
    at step ``state_dict['itr']``.
    """
    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        stats_G = G_ema if config['ema'] and config['use_ema'] else G
        utils.accumulate_standing_stats(stats_G, z_, y_,
                                        config['n_classes'],
                                        config['num_standing_accumulations'])

    IS_mean, IS_std, FID = get_inception_metrics(
        sample, config['num_inception_images'], num_splits=10)
    # Substitute a sentinel for an undefined FID.
    if np.isnan(float(FID)):
        FID = 99
    print(
        'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
        % (state_dict['itr'], IS_mean, IS_std, FID))

    # Decide whether the tracked metric beat its previous best.
    tracked = config['which_best']
    if tracked == 'IS':
        improved = IS_mean > state_dict['best_IS']
    elif tracked == 'FID':
        improved = FID < state_dict['best_FID']
    else:
        improved = False

    if improved:
        print('%s improved over previous best, saving checkpoint...' % tracked)
        # Write the rotating slot first, then the fixed 'currentbest' name.
        for checkpoint_name in ('best%d' % state_dict['save_best_num'],
                                'currentbest'):
            utils.save_weights(G, D, state_dict, config['weights_root'],
                               experiment_name, checkpoint_name,
                               G_ema if config['ema'] else None)
        next_slot = state_dict['save_best_num'] + 1
        state_dict['save_best_num'] = next_slot % config['num_best_copies']

    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log each metric as its own scalar series.
    for tag, value in (("IS_mean", IS_mean), ("IS_std", IS_std), ("FID", FID)):
        test_log.add_scalar(step=int(state_dict['itr']),
                            tag=tag,
                            value=float(value))
Ejemplo n.º 8
0
def save_and_sample(G, D, Dv, G_ema, z_, y_, fixed_z, fixed_y, state_dict,
                    config, experiment_name):
    """Checkpoint G, D and Dv, then optionally accumulate standing stats.

    Writes the main checkpoint and, when ``config['num_save_copies'] > 0``,
    a rotating 'copy<n>' checkpoint. If ``config['accumulate_stats']`` is
    set, ``utils.accumulate_standing_stats`` is called on the EMA generator
    (when both ``config['ema']`` and ``config['use_ema']`` are set) or on
    ``G`` otherwise.
    """
    ema_for_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, Dv, state_dict, config['weights_root'],
                       experiment_name, None, ema_for_save)

    # Keep a rotating backup in case the process dies in the middle of a save.
    if config['num_save_copies'] > 0:
        backup_name = 'copy%d' % state_dict['save_num']
        utils.save_weights(G, D, Dv, state_dict, config['weights_root'],
                           experiment_name, backup_name, ema_for_save)
        next_slot = state_dict['save_num'] + 1
        state_dict['save_num'] = next_slot % config['num_save_copies']

    # EMA generator or the raw one?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Standing statistics are accumulated on that same generator.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
Ejemplo n.º 9
0
def save_and_sample(G, D, G_ema, sample, fixed_z, fixed_y, state_dict, config,
                    experiment_name):
    """Checkpoint the networks and save an image generated from fixed inputs.

    Writes the main checkpoint (and a rotating 'copy<n>' checkpoint when
    ``config['num_save_copies'] > 0``), optionally calls
    ``utils.accumulate_standing_stats``, then generates images from
    ``fixed_z`` / ``fixed_y`` and saves them as
    ``<samples_root>/<experiment_name>/fixed_samples<itr>.jpg``.

    The EMA generator is only used for sampling once ``state_dict['itr']``
    exceeds ``config['ema_start']``.

    The sample directory is created with ``os.makedirs(exist_ok=True)``,
    replacing a bare ``except: pass`` around ``os.mkdir`` that hid every
    error (including KeyboardInterrupt), not just "already exists".
    """
    ema_for_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, None, ema_for_save)
    # Save an additional copy to mitigate accidental corruption if the
    # process is killed during a save.
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, 'copy%d' % state_dict['save_num'],
                           ema_for_save)
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] and (
        state_dict['itr'] > config['ema_start']) else G

    # Accumulate standing statistics?
    # NOTE(review): the target below ignores the ema_start condition used for
    # which_G, so before ema_start the stats go to G_ema while samples come
    # from G. Left unchanged; confirm whether which_G was intended.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(
            G_ema if config['ema'] and config['use_ema'] else G, sample,
            config['n_classes'], config['num_standing_accumulations'])

    # Save a sample sheet generated from the fixed z and y
    with torch.no_grad():
        fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # exist_ok=True tolerates another process creating the directory first.
    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz.float().cpu(),
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)
Ejemplo n.º 10
0
def sample(G, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name):
  """Write sample images for the current iteration without checkpointing.

  Optionally calls ``utils.accumulate_standing_stats``, generates images from
  ``fixed_z`` / ``fixed_y`` and saves them as
  ``<samples_root>/<experiment_name>/fixed_samples<itr>.jpg``, then calls
  ``utils.sample_sheet``.

  The sample directory is created with ``os.makedirs(exist_ok=True)`` so a
  missing ``samples_root`` or a concurrent creator no longer raises.
  """
  # Use EMA G for samples or non-EMA?
  which_G = G_ema if config['ema'] and config['use_ema'] else G

  # Accumulate standing statistics on the generator we sample from.
  if config['accumulate_stats']:
    utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])

  # Save a sample sheet generated from the fixed z and y
  with torch.no_grad():
    if config['parallel']:
      fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y)))
    else:
      fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

  # makedirs(exist_ok=True) avoids the isdir/mkdir race and also creates a
  # missing samples_root instead of failing.
  sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
  os.makedirs(sample_dir, exist_ok=True)
  image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'],
                                                  experiment_name,
                                                  state_dict['itr'])
  torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
                               nrow=int(fixed_Gz.shape[0] ** 0.5), normalize=True)
  # For now, every time we save, also save sample sheets
  utils.sample_sheet(which_G,
                     classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
                     num_classes=config['n_classes'],
                     samples_per_class=10, parallel=config['parallel'],
                     samples_root=config['samples_root'],
                     experiment_name=experiment_name,
                     folder_number=state_dict['itr'],
                     z_=z_)
  # Also save interp sheets
  # AG don't interpolate (save memory and time)
  '''
Ejemplo n.º 11
0
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['n_channels'] = utils.nchannels_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    # utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    if True or config['get_test_error'] or config['get_train_error'] or config[
            'get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Gets the "classifications" from D.
      
      y: the correct labels
      
      In the case of projection discrimination we have to pass in all the labels
      as conditionings to get the class specific affinity.
      """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    brief_expt_name = config['experiment_name'][-30:]

    # load results dict always
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """make/load history files in each
    """
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                'use_test_set': config['get_test_error']
            })
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error accross %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                total += config['batch_size']
                if loader_total > total and total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['get_generator_error']:

        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7'))
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                    'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(
                os.path.join(
                    '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                    'ckpt.copy.t7'))
            #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                    'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    fake_input[bi] = minimal_trans(np.moveaxis(
                        fake[bi], 0, -1))
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print(
            '[%s][%06d] %s accuracy: %f.' %
            (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # this is for using the downloaded one from
        # https://github.com/bioinf-jku/TTUR
        #mdata, sdata = f['mu'][:], f['sigma'][:]

        # this one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or (
            fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size

            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            x[start:end] = np.moveaxis(fake, 1, -1)
            #y += [labels.cpu().numpy()]

    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print(
                '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.'
                % (brief_expt_name, state_dict['itr'], mis, sis))

    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            """Compute Inception activation statistics for ``images``.

            Args:
                images: Batch of images forwarded unchanged to
                    ``fid.calculate_activation_statistics``.
                mem_fraction: Value passed to TensorFlow as
                    ``per_process_gpu_memory_fraction`` for the session.

            Returns:
                The ``(mu, sigma)`` pair produced by
                ``fid.calculate_activation_statistics``.
            """
            session_config = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    per_process_gpu_memory_fraction=mem_fraction))
            # Load the Inception graph into the current TF graph before the
            # session is opened.
            fid.create_inception_graph(fid.check_or_download_inception(None))
            with tf.Session(config=session_config) as session:
                session.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, session, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' %
            (config['samples_root'], experiment_name, config['load_weights']),
            nrow=int(G_batch_size**0.5),
            normalize=True)

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        """Compute the unofficial PyTorch IS/FID for G and print a summary line.

        Reads ``config``, ``G``, ``z_``, ``y_``, ``G_batch_size`` and
        ``state_dict`` from the enclosing scope. Returns nothing; the result
        is only printed.
        """
        # A fresh metrics function and sampler are built on every call.
        metrics_fn = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        draw = functools.partial(utils.sample, G=G, z_=z_, y_=y_,
                                 config=config)
        IS_mean, IS_std, FID = metrics_fn(draw,
                                          config['num_inception_images'],
                                          num_splits=10,
                                          prints=False)

        # Assemble the report piece by piece, then print it in one go.
        weights_kind = 'ema' if config['use_ema'] else 'non-ema'
        g_mode = 'eval' if config['G_eval_mode'] else 'training'
        pieces = [
            'Using %s weights ' % weights_kind,
            'in %s mode, ' % g_mode,
            'with noise variance %3.3f, ' % z_.var,
            'over %d images, ' % config['num_inception_images'],
        ]
        # Batch size is only reported when standing stats are accumulated
        # or G is not in eval mode.
        if config['accumulate_stats'] or not config['G_eval_mode']:
            pieces.append('with batch size %d, ' % G_batch_size)
        if config['accumulate_stats']:
            pieces.append('using %d standing stat accumulations, ' %
                          config['num_standing_accumulations'])
        pieces.append(
            'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
            % (state_dict['itr'], IS_mean, IS_std, FID))
        print(''.join(pieces))

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Ejemplo n.º 12
0
def run(config):
  """Build G and D from ``config``, emit the requested samples, then probe D.

  Steps, in order:
    1. Optionally recover the stored configuration from a checkpoint
       (``config['config_from_name']``).
    2. Derive resolution, class count and activations from the dataset and
       nonlinearity names, fix up root paths, seed the RNG.
    3. Instantiate the generator and discriminator from the module named by
       ``config['model']`` and load generator weights (the EMA weights are
       loaded into G when both ``ema`` and ``use_ema`` are set).
    4. Depending on the ``sample_*`` flags: save an npz of samples, class
       sample sheets, interpolation sheets and/or a random sample sheet.
    5. Prepare the EMA copy, optional fp16 casts, the combined G_D module,
       optional resume-loading and DataParallel wrapping.
    6. Run the discriminator on one generated image/label and print the result.

  Args:
    config: dict of settings. It is updated in place with derived entries
      and is also passed through ``utils.update_config_roots``.

  Returns:
    None. Results are written to disk and/or printed.
  """
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # Optionally, get the configuration from the state dict. This allows for
  # recovery of the config provided only a state dict and experiment name,
  # and can be convenient for writing less verbose sample shell scripts.
  if config['config_from_name']:
    utils.load_weights(None, None, state_dict, config['weights_root'],
                       config['experiment_name'], config['load_weights'], None,
                       strict=False, load_optim=False)
    # Ignore items which we might want to overwrite from the command line
    for item in state_dict['config']:
      if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']:
        config[item] = state_dict['config'][item]
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  utils.count_parameters(G)

  # Load weights
  print('Loading weights...')
  # Here is where we deal with the ema--load ema weights or load normal weights
  utils.load_weights(G if not (config['use_ema']) else None, None, state_dict,
                     config['weights_root'], experiment_name, config['load_weights'],
                     G if config['ema'] and config['use_ema'] else None,
                     strict=False, load_optim=False)
  # Update batch size setting used for G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'],
                             z_var=config['z_var'])

  if config['G_eval_mode']:
    print('Putting G in eval mode..')
    G.eval()
  else:
    print('G is in %s mode...' % ('training' if G.training else 'eval'))

  # Sample function
  sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
  if config['accumulate_stats']:
    print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations'])
    utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])

  # Most recent generated batch; stays None unless a sampling branch below
  # runs. The discriminator probe at the end of this function relies on it.
  images, labels = None, None

  # Sample a number of images and save them to an NPZ, for use with TF-Inception
  if config['sample_npz']:
    # Lists to hold images and labels for images
    x, y = [], []
    print('Sampling %d images and saving them to npz...' % config['sample_num_npz'])
    for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
      with torch.no_grad():
        images, labels = sample()
      x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
      y += [labels.cpu().numpy()]
    # Drop the surplus from the last batch so exactly sample_num_npz remain.
    x = np.concatenate(x, 0)[:config['sample_num_npz']]
    y = np.concatenate(y, 0)[:config['sample_num_npz']]
    print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
    npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name)
    print('Saving npz to %s...' % npz_filename)
    np.savez(npz_filename, **{'x' : x, 'y' : y})

  # Prepare sample sheets
  if config['sample_sheets']:
    print('Preparing conditional sample sheets...')
    utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
                         num_classes=config['n_classes'],
                         samples_per_class=10, parallel=config['parallel'],
                         samples_root=config['samples_root'],
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'],
                         z_=z_,)
  # Sample interp sheets
  if config['sample_interps']:
    print('Preparing interp sheets...')
    # Three sheets: nothing fixed, y fixed, z fixed.
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
      utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8,
                         num_classes=config['n_classes'],
                         parallel=config['parallel'],
                         samples_root=config['samples_root'],
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'],
                         sheet_number=0,
                         fix_z=fix_z, fix_y=fix_y, device='cuda')
  # Sample random sheet
  if config['sample_random']:
    print('Preparing random sample sheet...')
    images, labels = sample()
    print("labels size", labels)
    torchvision.utils.save_image(images.float(),
                                 '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name),
                                 nrow=int(G_batch_size**0.5),
                                 normalize=True)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True,
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  #print(G)
  #print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  # NOTE(review): this rebinds state_dict to a fresh dict, so anything the
  # earlier load_weights calls may have stored in the previous one is no
  # longer referenced from here on -- confirm this is intended.
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name,
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # The probe below needs a generated batch. When neither 'sample_npz' nor
  # 'sample_random' ran, nothing was sampled (this used to raise NameError),
  # so draw one batch now.
  if images is None:
    with torch.no_grad():
      images, labels = sample()

  # NOTE(review): this pairs the image at index 1 with the label at index 0
  # and passes them without a batch dimension -- confirm that is what the
  # discriminator expects.
  D_fake = D(images[1,:,:,:],labels[0])
  print("D_fake ",D_fake)
Ejemplo n.º 13
0
def test(G,
         D,
         G_ema,
         Prior,
         state_dict,
         config,
         sample,
         get_inception_metrics,
         experiment_name,
         test_log,
         E=None,
         test_acc=None):
    """Evaluate the current model, checkpoint on improvement, and log results.

    Inception Score / FID are only computed when ``config['dataset']`` is
    ``'C10'``; for any other dataset zeros are logged as placeholders and the
    best-metric trackers in ``state_dict`` are left untouched.

    Args:
        G, D: Generator and discriminator, passed to ``utils.save_weights``.
        G_ema: EMA generator; saved when ``config['ema']`` is set and used for
            standing-stat accumulation when ``ema`` and ``use_ema`` are set.
        Prior: Prior object; saved unless ``config['prior_type']`` is
            ``'default'``.
        state_dict: Training state. Reads ``itr``, ``best_IS``, ``best_FID``,
            ``save_best_num`` and (when ``test_acc`` is given)
            ``best_error_rec``; the best-* entries and ``save_best_num`` are
            updated in place.
        config: Settings dict.
        sample: Sampling callable handed to ``get_inception_metrics``.
        get_inception_metrics: Callable returning ``(IS_mean, IS_std, FID)``.
        experiment_name: Name used when saving weights.
        test_log: Logger exposing ``log(**fields)``.
        E: Optional encoder; saved only when ``config['is_encoder']`` is set.
        test_acc: Optional 5-sequence
            ``(acc, acc_iter, error_rec, p_mse, p_lik)``. When given, the
            reconstruction error drives a ``best_error_rec`` checkpoint and
            all five values are logged.

    Returns:
        None.
    """

    def _save_checkpoint(name_suffix):
        # Config lookups stay inside this helper so they only happen when a
        # checkpoint is actually written.
        utils.save_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            name_suffix,
            G_ema if config['ema'] else None,
            E=None if not config['is_encoder'] else E,
            Prior=None if config['prior_type'] == 'default' else Prior)

    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(
            G_ema if config['ema'] and config['use_ema'] else G, Prior,
            config['n_classes'], config['num_standing_accumulations'])

    metrics_computed = config['dataset'] in ['C10']
    if metrics_computed:
        IS_mean, IS_std, FID = get_inception_metrics(
            sample, config['num_inception_images'], num_splits=10)
        print(
            'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f'
            % (state_dict['itr'], IS_mean, IS_std, FID))
        # If improved over previous best metric, save appropriate copy
        improved = (
            (config['which_best'] == 'IS' and IS_mean > state_dict['best_IS'])
            or
            (config['which_best'] == 'FID' and FID < state_dict['best_FID']))
        if improved:
            print('%s improved over previous best, saving checkpoint...' %
                  config['which_best'])
            _save_checkpoint('best%d' % state_dict['save_best_num'])
            # Rotate through a fixed number of "best" checkpoint slots.
            state_dict['save_best_num'] = (state_dict['save_best_num'] +
                                           1) % config['num_best_copies']
    else:
        # Placeholders for logging only; they are not real measurements.
        IS_mean, IS_std, FID = 0, 0, 0

    if test_acc is not None:
        # Unpack once under fresh names instead of rebinding the parameter.
        acc, acc_iter, error_rec, p_mse, p_lik = test_acc
        if error_rec < state_dict['best_error_rec']:
            _save_checkpoint('best_error_rec')
        state_dict['best_error_rec'] = min(state_dict['best_error_rec'],
                                           error_rec)

    # Only real measurements may move the best-metric trackers; otherwise the
    # placeholder FID of 0 would be recorded as the best FID ever seen.
    if metrics_computed:
        state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
        state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log results to file
    log_fields = dict(itr=int(state_dict['itr']),
                      IS_mean=float(IS_mean),
                      IS_std=float(IS_std),
                      FID=float(FID))
    if test_acc is not None:
        log_fields.update(Acc=float(acc),
                          Acc_iter=float(acc_iter),
                          error_rec=float(error_rec),
                          p_mse=float(p_mse),
                          p_lik=float(p_lik))
    test_log.log(**log_fields)
Ejemplo n.º 14
0
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }
    # print(config)
    # exit()
    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        # print(config['weights_root'],config['experiment_name'], config['load_weights'])
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()

    # zht: my code
    # D = model.Discriminator(**config).cuda()
    from torch.nn import ReLU
    config_fixed = {
        'dataset': 'I128_hdf5',
        'augment': False,
        'num_workers': 0,
        'pin_memory': True,
        'shuffle': True,
        'load_in_mem': False,
        'use_multiepoch_sampler': True,
        'model': 'BigGAN',
        'G_param': 'SN',
        'D_param': 'SN',
        'G_ch': 96,
        'D_ch': 96,
        'G_depth': 1,
        'D_depth': 1,
        'D_wide': True,
        'G_shared': True,
        'shared_dim': 128,
        'dim_z': 120,
        'z_var': 1.0,
        'hier': True,
        'cross_replica': False,
        'mybn': False,
        'G_nl': 'inplace_relu',
        'D_nl': 'inplace_relu',
        'G_attn': '64',
        'D_attn': '64',
        'norm_style': 'bn',
        'seed': 0,
        'G_init': 'ortho',
        'D_init': 'ortho',
        'skip_init': True,
        'G_lr': 0.0001,
        'D_lr': 0.0004,
        'G_B1': 0.0,
        'D_B1': 0.0,
        'G_B2': 0.999,
        'D_B2': 0.999,
        'batch_size': 256,
        'G_batch_size': 64,
        'num_G_accumulations': 8,
        'num_D_steps': 1,
        'num_D_accumulations': 8,
        'split_D': False,
        'num_epochs': 100,
        'parallel': True,
        'G_fp16': False,
        'D_fp16': False,
        'D_mixed_precision': False,
        'G_mixed_precision': False,
        'accumulate_stats': False,
        'num_standing_accumulations': 16,
        'G_eval_mode': True,
        'save_every': 1000,
        'num_save_copies': 2,
        'num_best_copies': 5,
        'which_best': 'IS',
        'no_fid': False,
        'test_every': 2000,
        'num_inception_images': 50000,
        'hashname': False,
        'base_root': '',
        'data_root': 'data',
        'weights_root': 'weights',
        'logs_root': 'logs',
        'samples_root': 'samples',
        'pbar': 'mine',
        'name_suffix': '',
        'experiment_name': '',
        'config_from_name': False,
        'ema': True,
        'ema_decay': 0.9999,
        'use_ema': True,
        'ema_start': 20000,
        'adam_eps': 1e-06,
        'BN_eps': 1e-05,
        'SN_eps': 1e-06,
        'num_G_SVs': 1,
        'num_D_SVs': 1,
        'num_G_SV_itrs': 1,
        'num_D_SV_itrs': 1,
        'G_ortho': 0.0,
        'D_ortho': 0.0,
        'toggle_grads': True,
        'which_train_fn': 'GAN',
        'load_weights': '',
        'resume': False,
        'logstyle': '%3.3e',
        'log_G_spectra': False,
        'log_D_spectra': False,
        'sv_log_interval': 10,
        'sample_npz': True,
        'sample_num_npz': 50000,
        'sample_sheets': True,
        'sample_interps': True,
        'sample_sheet_folder_num': -1,
        'sample_random': True,
        'sample_trunc_curves': '0.05_0.05_1.0',
        'sample_inception_metrics': True,
        'resolution': 128,
        'n_classes': 1000,
        'G_activation': ReLU(inplace=True),
        'D_activation': ReLU(inplace=True),
        'no_optim': True
    }
    # config_fixed = {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 0, 'pin_memory': True, 'shuffle': True, 'load_in_mem': False, 'use_multiepoch_sampler': True, 'model': 'BigGAN', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 64, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 1000, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'data_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': '', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'sample_npz': True, 'sample_num_npz': 50000, 'sample_sheets': True, 'sample_interps': True, 'sample_sheet_folder_num': -1, 'sample_random': True, 'sample_trunc_curves': '0.05_0.05_1.0', 'sample_inception_metrics': True, 
'resolution': 128, 'n_classes': 1000, 'no_optim': True}
    D = model.Discriminator(**config_fixed).cuda()
    utils.load_weights(None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       None,
                       strict=False,
                       load_optim=False)
    D.eval()

    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    #Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            # zht: show discriminator results
            print(images.size(), labels.size())
            dis_loss = D(x=images, y=labels)
            print(dis_loss.size())
            print(dis_loss)
            exit()

            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]

            plt.imshow(x[0][i, :, :, :].transpose((1, 2, 0)))
            plt.show()

        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(images.float(),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Ejemplo n.º 15
0
def save_and_sample(G, D, E, G_ema, fixed_x, fixed_y, z_, y_,
                    state_dict, config, experiment_name, img_pool, save_weights):
    """Optionally checkpoint the networks, then write reconstruction sheets.

    The function runs these stages in order:

    1. Decode ``config['sparsity_mode']`` into a ``(name, value)`` pair.
    2. When ``save_weights`` is true: save the weights through
       ``utils.save_weights`` (plus a rotating ``copy%d`` backup) and, if
       ``config['accumulate_stats']`` is set, accumulate standing statistics
       on the sampling generator.
    3. Split ``fixed_x`` / ``fixed_y`` into chunks of ``config['batch_size']``,
       encode every chunk with ``E``, decode it with the sampling generator
       and save one image grid per chunk under
       ``<samples_root>/<experiment_name>/<itr>/``.
    4. When a sparsity mode is active, save one grid of per-channel
       activations for every block listed in the intermediates returned for
       the *last* chunk, and print the fraction of non-zero entries.
    5. Call ``img_pool.save()`` unless ``config['img_pool_size']`` is 0.

    Args:
        G, D, E: generator, discriminator and encoder. ``E(x)`` must return a
            3-tuple; only its first element is used (as the latent code).
        G_ema: EMA copy of the generator; used for sampling when both
            ``config['ema']`` and ``config['use_ema']`` are truthy.
        fixed_x, fixed_y: objects exposing ``split(chunk_size)``
            (presumably tensors of inputs and labels - confirm with caller).
        z_, y_: forwarded unchanged to ``utils.accumulate_standing_stats``.
        state_dict: training state; ``'itr'`` is read and ``'save_num'`` is
            advanced when a backup copy is written.
        config: experiment configuration dictionary.
        experiment_name: sub-directory name for weights and samples.
        img_pool: object exposing ``save()``.
        save_weights: whether stage 2 runs at all.

    Returns:
        None.
    """
    # Stage 1: the numeric value is whatever follows the last underscore.
    mode_spec = config['sparsity_mode']
    sparsity_mode = ("", 0.0)
    for mode_name in ('gradient_topk', 'regularized_sparse_hw_topk'):
        if mode_name + '_mode_' in mode_spec:
            sparsity_mode = (mode_name, float(mode_spec.split('_')[-1]))
            break

    # Sample from the EMA generator only when it exists and is requested.
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Stage 2: checkpointing.
    if save_weights:
        ema_to_save = G_ema if config['ema'] else None
        utils.save_weights(G, D, E, state_dict, config['weights_root'],
                           experiment_name, None, ema_to_save)
        # A second, rotating copy protects against a checkpoint that gets
        # corrupted because the process was killed in the middle of a save.
        if config['num_save_copies'] > 0:
            copy_tag = 'copy%d' % state_dict['save_num']
            utils.save_weights(G, D, E, state_dict, config['weights_root'],
                               experiment_name, copy_tag, ema_to_save)
            next_slot = state_dict['save_num'] + 1
            state_dict['save_num'] = next_slot % config['num_save_copies']

        # NOTE(review): standing stats are only accumulated when
        # save_weights is true - confirm this nesting is intended.
        if config['accumulate_stats']:
            utils.accumulate_standing_stats(
                which_G, z_, y_, config['n_classes'],
                config['num_standing_accumulations'])

    # Stage 3: reconstruct the fixed inputs chunk by chunk.
    # TODO: change here to encode fixed x into z and feed z with fixed y into G
    chunk_size = config['batch_size']
    x_chunks = fixed_x.split(chunk_size)
    y_chunks = fixed_y.split(chunk_size)
    for idx, x_chunk in enumerate(x_chunks):
        with torch.no_grad():
            y_chunk = y_chunks[idx]
            fixed_z, _, _ = E(x_chunk)
            output = which_G(fixed_z, which_G.shared(y_chunk),
                             state_dict['itr'], return_inter_activation=True)
            fixed_Gz = output[0].detach()
            intermediates = output[1]

        # Unpacking four dimensions also checks that the output is 4-D.
        chunk_len, _, _, _ = fixed_Gz.shape

        save_dir = '%s/%s/%s' % (config['samples_root'], experiment_name,
                                 state_dict['itr'])
        os.makedirs(save_dir, exist_ok=True)
        image_filename = '%s/fixed_samples%d-batch-%d.jpg' % (
            save_dir, state_dict['itr'], idx)

        # Roughly square grid.
        nrows_image = int(chunk_len ** 0.5)
        print(f"WHHHHHHH {nrows_image}")
        torchvision.utils.save_image(fixed_Gz, image_filename,
                                     nrow=nrows_image, normalize=True)

    # Stage 4: per-channel activation sheets (uses the last chunk's data).
    print(f"sparsity_mode {sparsity_mode}")
    if sparsity_mode[0]:
        image_indexs = config['eval_image_indexs']
        block_names = intermediates['sparse_block']
        print(f"block_names {block_names}")
        for block in block_names:
            activation = intermediates[block].cpu()
            for image_index in image_indexs:
                channel_image_filename = (
                    '%s/eval_iter_%s_img_%s_block_%s_channels_.jpg'
                    % (save_dir, state_dict['itr'], str(image_index), block))

                # One grey-scale tile per channel, replicated to 3 planes
                # so it can be written as an RGB image grid.
                activation_i = (activation[image_index]
                                .unsqueeze(1)
                                .repeat(1, 3, 1, 1))
                grid_cols = int(activation_i.shape[0] ** 0.5)
                torchvision.utils.save_image(activation_i.float(),
                                             channel_image_filename,
                                             nrow=grid_cols, normalize=True)

                # The printed value is the share of NON-zero entries.
                zero_count = (activation_i == 0).sum().item()
                entry_count = (activation_i.shape[0] * 3
                               * activation_i.shape[2]
                               * activation_i.shape[3])
                print(f"Sparsity: {1 - zero_count / entry_count}")

    # TODO: why the mask looks all the same for every images? Is the latent
    # code all the same? for different n? should you use vae?

    # Stage 5: persist the image pool when one is in use.
    if config['img_pool_size'] != 0:
        img_pool.save()
Ejemplo n.º 16
0
def run(config):
    """Sample from the best-FID generator checkpoint and score its fairness.

    The script

    1. optionally restores the stored configuration from a checkpoint,
    2. builds the generator and loads the most recent ``best_fid`` weights,
    3. draws ``num_sample_sets`` sets of ``config['sample_num_npz']`` images
       (sets whose ``.npz`` already exists are skipped),
    4. classifies every set with a pre-trained attribute classifier and
       computes the fairness discrepancy (l2 / l1 / kl) per set,
    5. pickles the discrepancies and saves the class probabilities.

    Args:
        config: experiment configuration dictionary. It is updated in place
            (resolution, class count, activations, roots, ``mode``,
            ``load_weights``, ``sample_num_npz`` ...).

    Raises:
        FileNotFoundError: if no ``state_dict_best_<mode>*`` checkpoint
            exists for the experiment.
    """
    # Number of independent sample sets that are drawn and then scored.
    num_sample_sets = 10

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num_fair': 0,
        'save_best_num_fid': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_fair_d': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Items we might want to overwrite from the command line are kept.
        cli_overridable = ('z_var', 'base_root', 'batch_size',
                           'G_batch_size', 'use_ema', 'G_eval_mode')
        for item in state_dict['config']:
            if item not in cli_overridable:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = 2 if config['conditional'] else 1
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'
    config['sample_num_npz'] = 10000
    print(config['ema_start'])

    # Seed RNG
    # utils.seed_rng(config['seed'])  # config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    assert config['mode'] in ['fair', 'fid']
    print('sampling from model with best FID scores...')
    # can change this to 'fair', but this assumes access to ground-truth
    # attribute classifier (and labels)
    config['mode'] = 'fid'

    # find best weights for either FID or fair checkpointing
    weights_root = config['weights_root']
    ckpt_pattern = os.path.join(weights_root, experiment_name,
                                'state_dict_best_{}*'.format(config['mode']))
    ckpts = glob.glob(ckpt_pattern)
    if not ckpts:
        # Without this guard the suffix below would silently become
        # 'best_<mode>-1' and the failure would surface far from its cause.
        raise FileNotFoundError(
            'no checkpoint matches {}'.format(ckpt_pattern))
    # NOTE(review): assumes the highest-numbered copy is the newest best
    # checkpoint - confirm against how the training loop rotates copies.
    best_ckpt = 'best_{}{}'.format(config['mode'], len(ckpts) - 1)
    config['load_weights'] = best_ckpt

    # load weights to sample from generator
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       weights_root,
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)

    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])
    print('Putting G in eval mode..')
    G.eval()

    # Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Make sure the per-experiment sample directory exists (including any
    # missing parent directories).
    sample_path = '%s/%s/' % (config['samples_root'], experiment_name)
    print('looking in sample path {}'.format(sample_path))
    if not os.path.exists(sample_path):
        print('creating sample path: {}'.format(sample_path))
        os.makedirs(sample_path, exist_ok=True)

    def npz_path(set_index):
        """Return the .npz file used for sample set ``set_index``."""
        return '%s/%s/%s_samples_%s.npz' % (config['samples_root'],
                                            experiment_name, config['mode'],
                                            set_index)

    # Sample a number of images and save them to an NPZ, for use with
    # TF-Inception and with the attribute classifier below.
    print('saving samples from best FID checkpoint')
    samples_per_set = config['sample_num_npz']
    num_batches = int(np.ceil(samples_per_set / float(G_batch_size)))
    for k in range(num_sample_sets):
        npz_filename = npz_path(k)
        if os.path.exists(npz_filename):
            print('samples already exist, skipping...')
            continue
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              samples_per_set)
        for _ in trange(num_batches):
            with torch.no_grad():
                images, labels = sample()
            # Map generator output from [-1, 1] to uint8 [0, 255].
            x.append(np.uint8(255 * (images.cpu().numpy() + 1) / 2.))
            y.append(labels.cpu().numpy())
        # The last batch may overshoot; trim to the requested size.
        x = np.concatenate(x, 0)[:samples_per_set]
        y = np.concatenate(y, 0)[:samples_per_set]
        print('checking labels: {}'.format(y.sum()))
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # output file
    fname = '%s/%s/fair_disc_fid_samples.p' % (config['samples_root'],
                                               experiment_name)

    # load classifier
    if config['multi']:
        # multi-attribute
        print('Pre-loading pre-trained multi-attribute classifier...')
        clf_state_dict = torch.load(MULTI_CLF_PATH)['state_dict']
        clf_classes = 4
    else:
        print('Pre-loading pre-trained single-attribute classifier...')
        clf_state_dict = torch.load(CLF_PATH)['state_dict']
        clf_classes = 2
    # load attribute classifier here
    clf = ResNet18(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=clf_classes,
                   grayscale=False)
    clf.load_state_dict(clf_state_dict)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    clf = clf.to(device)
    clf.eval()  # turn off batch norm

    # classify examples, get probabilities and per-set discrepancies
    l2_db = np.zeros(num_sample_sets)
    l1_db = np.zeros(num_sample_sets)
    kl_db = np.zeros(num_sample_sets)
    probs_db = np.zeros((num_sample_sets, samples_per_set, clf_classes))
    for i in range(num_sample_sets):
        # grab appropriate samples
        preds, probs = classify_examples(clf, npz_path(i))
        l2, l1, kl = utils.fairness_discrepancy(preds, clf_classes)

        # save metrics
        l2_db[i] = l2
        l1_db[i] = l1
        kl_db[i] = kl
        probs_db[i] = probs
        print('fair_disc for iter {} is: l2:{}, l1:{}, kl:{}'.format(
            i, l2, l1, kl))
    metrics = {'l2': l2_db, 'l1': l1_db, 'kl': kl_db}
    print('fairness discrepancies saved in {}'.format(fname))
    print(l2_db)

    # save all metrics
    with open(fname, 'wb') as fp:
        pickle.dump(metrics, fp)
    np.save(
        os.path.join(config['samples_root'], experiment_name, 'clf_probs.npy'),
        probs_db)
Ejemplo n.º 17
0
def run(config_parser):
    """Sample from a trained model and record its reconstruction metrics.

    The training configuration is recovered from the experiment's
    ``metalog.txt``; the generator, discriminator, optional encoder and prior
    are rebuilt and their weights loaded. The script then

    1. samples 50000 images into ``<weights_path>/<experiment>/<name>.npz``,
    2. evaluates ``train_fns.test_accuracy`` on a non-augmented data loader,
    3. stores ``test_acc`` and ``error_rec`` in
       ``<weights_path>/<experiment>/<name>_json.json``, merging with the
       file's existing content when it is already there.

    Args:
        config_parser: mapping with the keys ``experiment_name``,
            ``weights_path``, ``logs_path``, ``name_suffix`` and
            ``accumulate_stats``.

    Raises:
        ValueError: if ``metalog.txt`` contains no ``{...}`` config literal.
    """
    import ast

    # Getting config
    experiment_name = model_name = config_parser['experiment_name']
    model_path = '%s/%s' % (config_parser['weights_path'], model_name)
    logs_path = '%s/%s' % (config_parser['logs_path'], model_name)

    config_path = '%s/metalog.txt' % logs_path
    save_path = '%s/%s' % (model_path, 'saved_stuff')
    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(save_path, exist_ok=True)

    device = 'cuda'

    # The config is the first {...} literal found in the metalog.
    with open(config_path, 'r') as metalog:
        all_file = metalog.read()
    fs1 = all_file.find('{')
    fs2 = all_file.find('}')
    if fs1 == -1 or fs2 < fs1:
        raise ValueError('no config dict found in %s' % config_path)
    config = all_file[fs1:fs2 + 1]
    # Activation modules are logged by repr and are not Python literals,
    # so they are stripped before parsing.
    config = config.replace(", 'G_activation': ReLU()", "")
    config = config.replace(", 'D_activation': ReLU()", "")
    config = ast.literal_eval(config)

    config['samples_root'] = 'samples_test'

    config['skip_init'] = True
    #config['no_optim']     = True

    # Loading Model
    config['weights_root'] = config_parser['weights_path']

    model = __import__(config['model'])
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)

    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    # E stays None for encoder-less models so the names used below are
    # always bound (previously this path raised NameError).
    E = None
    if config['is_encoder']:
        E = model.Encoder(**{**config, 'D': D}).to(device)
    Prior = layers.Prior(**config).to(device)
    GE = model.G_E(G, E, Prior)

    utils.load_weights(
        G,
        None,
        '',
        config['weights_root'],
        model_name,
        config_parser['name_suffix'],
        G if config['ema'] else None,
        E=E,
        Prior=Prior if not config['prior_type'] == 'default' else None)

    # Sample functions
    sample = functools.partial(utils.sample, G=G, Prior=Prior, config=config)

    # Accumulate stats?
    samples_name = 'samples'
    if config_parser['accumulate_stats']:
        samples_name = 'samples_acc'
        utils.accumulate_standing_stats(G, Prior, config['n_classes'],
                                        config['num_standing_accumulations'])

    if config_parser['name_suffix'] is not None:
        samples_name += config_parser['name_suffix']

    # Sample and save in npz
    config['sample_npz'] = True
    config['sample_num_npz'] = 50000
    G_batch_size = Prior.bs
    # Sample a number of images and save them to an NPZ, for use with TF-Inception

    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        num_batches = int(
            np.ceil(config['sample_num_npz'] / float(G_batch_size)))
        for _ in trange(num_batches):
            with torch.no_grad():
                images, labels = sample()
            # Map generator output from [-1, 1] to uint8 [0, 255].
            x.append(np.uint8(255 * (images.cpu().numpy() + 1) / 2.))
            y.append(labels.cpu().numpy())
        # The last batch may overshoot; trim to the requested size.
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/%s.npz' % (config['weights_root'],
                                         experiment_name, samples_name)
        print('Num %d' % len(x))
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Reconstruction metrics
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    config_aux = config.copy()
    config_aux['augment'] = False
    dataloader_noaug = utils.get_data_loaders(
        **{
            **config_aux, 'batch_size': D_batch_size
        })

    if config_parser['accumulate_stats']:
        utils.accumulate_standing_stats_E(E, Prior, dataloader_noaug, device,
                                          config)
    test_acc, _, error_rec = train_fns.test_accuracy(
        GE, dataloader_noaug, device, config['D_fp16'], config)

    # Merge the reconstruction metrics into the (possibly pre-existing)
    # metrics file. They are always stored under their own keys so that
    # inception metrics already present in the file are not overwritten.
    json_path = '%s/%s.json' % (model_path, samples_name + '_json')
    if os.path.isfile(json_path):
        with open(json_path) as fp:
            metric_dict = json.load(fp)
    else:
        metric_dict = {}
    metric_dict['test_acc'] = test_acc
    metric_dict['error_rec'] = error_rec
    with open(json_path, 'w') as fp:
        json.dump(metric_dict, fp)
Ejemplo n.º 18
0
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                    state_dict, config, experiment_name, sample_only=False,
                    use_real=False, real_batch=None, id="", mixed=False,
                    target_map=None):
    """Either checkpoint the networks or write sample / discriminator sheets.

    With ``sample_only=False`` the weights are saved (plus a rotating
    ``copy%d`` backup) and nothing else happens.

    With ``sample_only=True`` the function

    1. optionally accumulates standing statistics on the sampling generator,
    2. obtains a batch - ``real_batch`` when ``use_real`` is set, otherwise
       the generator output for ``fixed_z`` / ``fixed_y``,
    3. saves that batch as ``fixed_samples<itr><id>.jpg``,
    4. saves the sigmoid of the discriminator's per-pixel output, rescaled to
       [0, 1] per sample, as ``fixed_samples_D<itr><id>.jpg``,
    5. for generated batches, also writes the class sample sheets and, on
       the ``coco`` dataset, the interpolation sheets.

    Args:
        G, D, G_ema: generator, discriminator and EMA generator. ``D(x, y)``
            must return a pair whose first element is the per-pixel map.
        z_, y_: forwarded to ``utils.accumulate_standing_stats`` and
            ``utils.sample_sheet``.
        fixed_z, fixed_y: fixed latent codes and labels used for the sheet.
        state_dict: training state; ``'itr'`` is read, ``'save_num'`` is
            advanced when a backup copy is written.
        config: experiment configuration dictionary.
        experiment_name: sub-directory name; ``"_real"`` is appended locally
            when ``use_real`` is set.
        sample_only: select sampling (True) or checkpointing (False).
        use_real: visualise ``real_batch`` instead of generated images.
        real_batch: batch used when ``use_real`` is set.
        id: suffix inserted into the image file names.
        mixed: when true, the image sheet is still written for real batches
            after ``config['sample_every']`` iterations.
        target_map: accepted for backward compatibility; no longer read.

    Returns:
        None.
    """
    if not sample_only:
        ema_to_save = G_ema if config['ema'] else None
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, None, ema_to_save)
        # Save an additional copy to mitigate accidental corruption if the
        # process is killed during a save.
        if config['num_save_copies'] > 0:
            utils.save_weights(G, D, state_dict, config['weights_root'],
                               experiment_name,
                               'copy%d' % state_dict['save_num'],
                               ema_to_save)
            state_dict['save_num'] = ((state_dict['save_num'] + 1)
                                      % config['num_save_copies'])
        return

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Accumulate standing statistics?
    if config['accumulate_stats']:
        print("accumulating stats!")
        utils.accumulate_standing_stats(which_G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Batch to visualise: real images or samples with fixed z and y.
    with torch.no_grad():
        if use_real:
            fixed_Gz = real_batch
            experiment_name += "_real"
        elif config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_z, which_G.shared(fixed_y)))
        else:
            fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # makedirs also creates a missing samples_root and tolerates an
    # already existing directory.
    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)

    # Roughly square grid.
    grid_rows = int(fixed_Gz.shape[0] ** 0.5)

    # Real batches are only dumped during the first `sample_every`
    # iterations unless they are mixed batches.
    skip_image_sheet = (state_dict["itr"] > config["sample_every"]
                        and use_real and not mixed)
    if not skip_image_sheet:
        image_filename = '%s/fixed_samples%d%s.jpg' % (
            sample_dir, state_dict['itr'], id)
        torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
                                     nrow=grid_rows, normalize=True)

    with torch.no_grad():
        # Only the per-pixel map is visualised; the second output is unused.
        D_map, _ = D(fixed_Gz, fixed_y)
        # torch.sigmoid computes the same values as F.sigmoid without the
        # deprecation warning older PyTorch releases emit for the latter.
        D_map = torch.sigmoid(D_map)

        # Rescale every map to [0, 1] independently. A constant map has a
        # zero range; it is left at zero instead of becoming NaN.
        for i in range(D_map.size(0)):
            shifted = D_map[i] - D_map[i].min()
            peak = shifted.max()
            D_map[i] = shifted / peak if peak > 0 else shifted

        map_filename = '%s/fixed_samples_D%d%s.jpg' % (
            sample_dir, state_dict['itr'], id)
        torchvision.utils.save_image(D_map.float().cpu(), map_filename,
                                     nrow=grid_rows, normalize=False)

    if use_real:
        return

    # For now, every time we save, also save sample sheets
    utils.sample_sheet(
        which_G,
        classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
        num_classes=config['n_classes'],
        samples_per_class=10,
        parallel=config['parallel'],
        samples_root=config['samples_root'],
        experiment_name=experiment_name,
        folder_number=state_dict['itr'],
        z_=z_)

    # Also save interp sheets
    if config["dataset"] == "coco":
        # (num_per_sheet, num_midpoints) per resolution; resolutions not
        # listed fall back to the 128-pixel layout instead of failing on
        # unbound names.
        interp_layout = {64: (32, 16), 128: (16, 8), 256: (8, 4)}
        num_per_sheet, num_midpoints = interp_layout.get(
            config["resolution"], (16, 8))
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(which_G,
                               num_per_sheet=num_per_sheet,
                               num_midpoints=num_midpoints,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=state_dict['itr'],
                               sheet_number=0,
                               fix_z=fix_z, fix_y=fix_y, device='cuda',
                               config=config)
Ejemplo n.º 19
0
def run(config):
    """Load a trained generator and produce the outputs requested in ``config``.

    Depending on the flags set in ``config`` this will, in order:

    * ``sample_npz``: draw ``sample_num_npz`` images and save them (with
      their labels) to ``<samples_root>/<experiment_name>/samples.npz``;
    * ``sample_sheets``: write conditional sample sheets;
    * ``sample_interps``: write interpolation sheets for three
      ``fix_z`` / ``fix_y`` combinations;
    * ``sample_random``: save one batch of samples as ``random_samples.jpg``;
    * ``sample_inception_metrics``: print Inception Score / FID once;
    * ``sample_trunc_curves``: print the same metrics for a sweep of noise
      variances given as a ``"start_step_end"`` string.

    Args:
        config: dict of settings. It is updated in place with derived
            entries (``resolution``, ``n_classes``, ``G_activation``,
            ``D_activation``, ``skip_init``, ``no_optim``) and, when
            ``config_from_name`` is set, with the values stored in the
            loaded checkpoint.

    Returns:
        None. Results are written to disk and/or printed.
    """
    # Function-scope import: only needed to create the output directory.
    import os

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Keep the values of items which we might want to overwrite from the
        # command line; everything else comes from the stored config.
        command_line_items = ('z_var', 'base_root', 'batch_size',
                              'G_batch_size', 'use_ema', 'G_eval_mode')
        for item in state_dict['config']:
            if item not in command_line_items:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal
    # weights. G is passed either as the plain generator or as the EMA
    # generator, never both.
    # NOTE(review): when 'use_ema' is set but 'ema' is not, both slots are
    # None, so it looks like no generator weights get loaded -- confirm
    # against utils.load_weights.
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    # Sample function (shared by every sampling path below)
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Directory that receives the files this function writes directly.
    samples_dir = '%s/%s' % (config['samples_root'], experiment_name)

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        num_batches = int(
            np.ceil(config['sample_num_npz'] / float(G_batch_size)))
        for _ in trange(num_batches):
            with torch.no_grad():
                images, labels = sample()
            # Map generator output from [-1, 1] to uint8 [0, 255].
            x.append(np.uint8(255 * (images.cpu().numpy() + 1) / 2.))
            y.append(labels.cpu().numpy())
        # The last batch may overshoot; trim to the requested count.
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        # np.savez does not create missing parent directories.
        os.makedirs(samples_dir, exist_ok=True)
        np.savez(npz_filename, x=x, y=y)

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device=device)
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        # save_image needs the target directory to exist already.
        os.makedirs(samples_dir, exist_ok=True)
        torchvision.utils.save_image(images.float(),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves

    def get_metrics():
        """Compute IS/FID with the current G and z_ settings and print them."""
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        # The setting is a "start_step_end" string of three numbers.
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        # The stop value is exclusive, hence end + step so that `end` itself
        # is meant to be part of the sweep.
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Ejemplo n.º 20
0
def save_and_sample(G, D, G_ema, z_, y_, yd_, fixed_z, fixed_y, fixed_yd,
                    state_dict, config, experiment_name):
    """Checkpoint the networks, then write a sample sheet for this iteration.

    The weights of ``G`` and ``D`` (plus ``G_ema`` when ``config['ema']`` is
    set) are saved under ``config['weights_root']``. When
    ``config['num_save_copies']`` is positive, an extra rotating copy named
    ``copy<save_num>`` is written and ``state_dict['save_num']`` is advanced.
    Afterwards the generator selected for sampling is switched to eval mode,
    a sample sheet is produced with ``state_dict['itr']`` as folder number,
    and the generator is switched back to train mode.

    ``fixed_z``, ``fixed_y`` and ``fixed_yd`` are accepted but not used by
    the active code.
    """
    # The EMA generator is handed to the saver only when EMA is enabled.
    ema_to_save = G_ema if config['ema'] else None
    weights_root = config['weights_root']

    utils.save_weights(G, D, state_dict, weights_root, experiment_name, None,
                       ema_to_save)

    # Save an additional copy to mitigate accidental corruption if the
    # process is killed during a save.
    num_copies = config['num_save_copies']
    if num_copies > 0:
        copy_name = 'copy%d' % state_dict['save_num']
        utils.save_weights(G, D, state_dict, weights_root, experiment_name,
                           copy_name, ema_to_save)
        state_dict['save_num'] = (state_dict['save_num'] + 1) % num_copies

    # Use EMA G for samples or non-EMA?
    if config['ema'] and config['use_ema']:
        which_G = G_ema
    else:
        which_G = G

    # Accumulate standing statistics on the generator we sample from.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(which_G, z_, y_, yd_,
                                        config['n_classes'],
                                        config['num_standing_accumulations'])

    which_G.eval()
    # Saving fixed-z/fixed-y sample grids, the alternative sample sheet and
    # the interpolation sheets is currently disabled; only the sample sheet
    # below is written every time we save.
    sheet_options = dict(
        classes_per_sheet=10,
        num_classes=config['n_classes'],
        num_domain=config['n_domain'],
        samples_per_class=10,
        parallel=config['parallel'],
        samples_root=config['samples_root'],
        experiment_name=experiment_name,
        folder_number=state_dict['itr'],
        z_=z_,
    )
    utils.sample_sheet(which_G, **sheet_options)

    which_G.train()