Example #1
def sample_folder(args, config, G):
    args.sample_count = {}

    (z_, y_) = utils.prepare_z_y(args.batch_size,
                                 G.dim_z,
                                 config["n_classes"],
                                 device=config["device"],
                                 fp16=config["G_fp16"],
                                 z_var=config["z_var"],
                                 num_categories_to_sample=args.num_classes,
                                 per_category_to_sample=args.num_per_classes)

    out_dir = os.path.join(args.samples_dir, 'sample_folder')
    with torch.no_grad():
        count = 0
        while y_.next:
            z_.sample_()
            y_.sample_()
            if z_.shape[0] > y_.shape[0]:
                z_ = z_[:y_.shape[0], :]
            if z_.shape[0] < y_.shape[0]:
                y_ = y_[:z_.shape[0]]
            if args.parallel:
                # Keep the generated batch, mirroring the non-parallel call below.
                image_tensors = nn.parallel.data_parallel(G, (z_, G.shared(y_)))
            else:
                image_tensors = G(z_, G.shared(y_))  # batch_size, 3, h, w
            save_samples(args,
                         sample=image_tensors,
                         save_dir=out_dir,
                         meta={'y': y_})
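
All of the examples on this page assume that utils.prepare_z_y returns a pair of re-sampleable tensors: z_ refills itself from a Gaussian and y_ from a categorical distribution each time .sample_() is called. A minimal sketch of that behaviour, modelled loosely on the BigGAN-PyTorch utils.Distribution class (device/dtype handling and the extra keyword arguments used above are omitted), could look like this:

import torch

class Distribution(torch.Tensor):
    # Tensor subclass that re-fills itself in place on every .sample_() call.
    def init_distribution(self, dist_type, **kwargs):
        self.dist_type = dist_type
        self.dist_kwargs = kwargs

    def sample_(self):
        if self.dist_type == 'normal':
            self.normal_(0, self.dist_kwargs['var'])
        elif self.dist_type == 'categorical':
            self.random_(0, self.dist_kwargs['num_categories'])

def prepare_z_y(batch_size, dim_z, n_classes, z_var=1.0):
    z_ = Distribution(torch.randn(batch_size, dim_z, requires_grad=False))
    z_.init_distribution('normal', var=z_var)
    y_ = Distribution(torch.zeros(batch_size, requires_grad=False))
    y_.init_distribution('categorical', num_categories=n_classes)
    return z_, y_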
Example #2
def run_eval(config):
    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    model = __import__(config['model'])
    G = model.Generator(**config).cuda()
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])
    get_inception_metrics = inception_tf.prepare_inception_metrics(
        config['dataset'], config['parallel'], config)

    network_url = config['network'].replace(
        'mit-han-lab:',
        'https://hanlab.mit.edu/projects/data-efficient-gans/models/')
    G.load_state_dict(torch.load(dnnlib.util.open_file_or_url(network_url)))
    if config['G_eval_mode']:
        G.eval()
    else:
        G.train()

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    IS_list = []
    FID_list = []
    for _ in tqdm(range(config['repeat'])):
        IS, _, FID = get_inception_metrics(sample,
                                           config['num_inception_images'],
                                           num_splits=10,
                                           prints=False)
        IS_list.append(IS)
        FID_list.append(FID)

    if config['repeat'] > 1:
        print('IS mean: {}, std: {}'.format(np.mean(IS_list), np.std(IS_list)))
        print('FID mean: {}, std: {}'.format(np.mean(FID_list),
                                             np.std(FID_list)))
    else:
        print('IS: {}'.format(np.mean(IS_list)))
        print('FID: {}'.format(np.mean(FID_list)))
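
The sample closure built with functools.partial above is what get_inception_metrics calls repeatedly to draw generator batches. A rough sketch of what such a closure does, modelled on BigGAN-PyTorch's utils.sample (the data-parallel path is omitted, so this is an assumption about the interface rather than the exact implementation):

import torch

def sample(G, z_, y_, config):
    # Draw fresh latents and labels, then run the generator without gradients.
    with torch.no_grad():
        z_.sample_()
        y_.sample_()
        G_z = G(z_, G.shared(y_))
    return G_z, y_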
Example #3
    def get_random_inputs(self, bs=1, target=None, seed=None):
        if seed is not None:
            torch.manual_seed(seed)

        (z_, y_) = utils.prepare_z_y(
            bs,
            self.generator.dim_z,
            self.config["n_classes"],
            device=self.config["device"],
            fp16=self.config["G_fp16"],
            z_var=self.config["z_var"],
            target=target,
            range=self.config["range"],
        )
        return (z_, y_)
Example #4
    def __init__(self,
                 config,
                 model_name,
                 thr=None,
                 multi_gans=None,
                 gan_weights=None):
        # Updating settings
        G_batch_size = config['G_batch_size']
        n_classes = config['n_classes']

        # Loading GAN weights
        if multi_gans is None:
            self.G = utils.initialize(config, model_name)
        else:
            # Assuming that weight files follow the naming convention:
            # model_name_k, where k is in [0, multi_gans-1]
            self.G = [
                utils.initialize(config, model_name + "_%d" % k)
                for k in range(multi_gans)
            ]
        self.multi_gans = multi_gans
        self.gan_weights = gan_weights

        # Preparing sampling functions
        self.z_, self.y_ = utils.prepare_z_y(G_batch_size,
                                             config['dim_z'],
                                             n_classes,
                                             device='cuda',
                                             fp16=config['G_fp16'],
                                             z_var=config['z_var'],
                                             thr=thr)

        # Preparing fixed y tensors
        self.y_fixed = {
            y: utils.make_y(G_batch_size, y)
            for y in range(n_classes)
        }
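
utils.make_y is not shown in these snippets; judging from how self.y_fixed is used above, it presumably builds a label tensor filled with a single class index. A plausible, purely hypothetical minimal version:

import torch

def make_y(batch_size, class_idx, device='cuda'):
    # One fixed class label repeated across the whole batch.
    return torch.full((batch_size,), class_idx, dtype=torch.int64, device=device)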
Example #5
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    ## *** Added: resolution comes from the I128_hdf5 dataset; here we may need the C10 dataset instead
    config['resolution'] = utils.imsize_dict[config['dataset']]
    ## *** Added: nclass_dict gives the I128_hdf5 class count; here we may need C10's 10 classes instead
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    ## Load the activation functions for G and D, both ReLU; the key is lowercase 'relu', not sure whether it needs a capital R
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
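    ## For reference, these lookup tables live in utils.py and map a dataset name to the
    ## image resolution and class count, and an activation string to a module. Assuming
    ## the standard BigGAN-PyTorch utils, roughly:
    ##   imsize_dict:     {'I128_hdf5': 128, 'C10': 32, ...}
    ##   nclass_dict:     {'I128_hdf5': 1000, 'C10': 10, ...}
    ##   activation_dict: {'inplace_relu': nn.ReLU(inplace=True), 'relu': nn.ReLU(), ...}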

    ## Training from scratch with no checkpoint to resume, so nothing to change; the default is fine
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True

    ## Log/root setup; probably nothing to change here either
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    ## Set the initial RNG seed (all 0); *** needs to be replaced with the Paddle equivalent
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    ## Set up the root folders for logs; this should not need changes either
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    ## @@@ No change needed here, simply commented out; Paddle may not need this setting
    ## It speeds up networks whose structure stays fixed
    # torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    ## *** !!! Nice trick: dynamically imports the BigGAN model module; check the network architecture config inside BigGAN
    model = __import__(config['model'])
    ## No change needed; packs a series of config values into the experiment name
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    ## *** Build the models from the config; both constructors need porting
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    ## *** Off by default, so the EMA part can be left alone for now
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    ## C10 is small, so G and D can stay at the default precision for now
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    ## Wrap the configured G and D into the combined G_D module
    GD = model.G_D(G, D)
    ## *** These two prints could probably be dropped; they just come from nn.Module's printable representation
    print(G)
    print(D)
    ## *** parameters() here is likewise inherited from torch's nn.Module
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    ## Initialize the bookkeeping/state table; no change needed
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    ## Not using a pre-trained model for now, so this block needs no change
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    ## Can be ignored for now; GD is not parallelized by default
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    ## Logging hub; can probably be left alone. If needed, check whether the IS and FID results can be pulled out of it
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    ## This one matters: it is used to aggregate the results.
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    ## *** Data loading for D; get_data_loaders uses torchvision's transforms under the hood
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })
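    ## Worked example of the batch-size arithmetic above (illustrative numbers): with
    ## batch_size=64, num_D_steps=2 and num_D_accumulations=8, each loader iteration
    ## yields 64 * 2 * 8 = 1024 images, enough for one full D update cycle.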

    ## Prepare the evaluation metrics; the FID and IS pipeline can be computed with the numpy version, so no change needed
    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    ## Prepare the noise and the randomly sampled label arrays
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    ## *** Some torch/numpy usage here needs porting; obtains the noise and labels
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throughout training
    ## *** Some torch/numpy usage here needs porting; obtains the noise and labels
    ## TODO: why prepare a second set of noise and labels? (The fixed pair stays constant so the same samples can be tracked as training progresses.)
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])

    ## *** Sampling comes from the Distribution class; Gaussian and categorical sampling are both available
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    ## *** Instantiate the GAN_training_function training loop
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    ## If no training function is specified, run the dummy training function to debug the pipeline
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    ## *** Pre-bind part of the arguments of utils.sample with functools.partial and expose it as the new function sample
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            ## This part needs no porting
            ## !!! loaders[0] is the data-sampling object
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            ## *** train() here is inherited from nn.Module; it switches the modules to training mode
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            ## *** Feed the data and labels into the training function; train itself needs a lot of rewriting
            metrics = train(x, y)
            ## Logging: write all the metrics into the training log
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            ## Log how the singular values evolve
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            ## By default, results are saved every 2000 steps
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    ## *** Method inherited from nn.Module
                    G.eval()
                    ## If using the exponential moving average
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            ## By default, test every 5000 steps
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Example #6
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }
    # print(config)
    # exit()
    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        # print(config['weights_root'],config['experiment_name'], config['load_weights'])
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()

    # zht: my code
    # D = model.Discriminator(**config).cuda()
    from torch.nn import ReLU
    config_fixed = {
        'dataset': 'I128_hdf5',
        'augment': False,
        'num_workers': 0,
        'pin_memory': True,
        'shuffle': True,
        'load_in_mem': False,
        'use_multiepoch_sampler': True,
        'model': 'BigGAN',
        'G_param': 'SN',
        'D_param': 'SN',
        'G_ch': 96,
        'D_ch': 96,
        'G_depth': 1,
        'D_depth': 1,
        'D_wide': True,
        'G_shared': True,
        'shared_dim': 128,
        'dim_z': 120,
        'z_var': 1.0,
        'hier': True,
        'cross_replica': False,
        'mybn': False,
        'G_nl': 'inplace_relu',
        'D_nl': 'inplace_relu',
        'G_attn': '64',
        'D_attn': '64',
        'norm_style': 'bn',
        'seed': 0,
        'G_init': 'ortho',
        'D_init': 'ortho',
        'skip_init': True,
        'G_lr': 0.0001,
        'D_lr': 0.0004,
        'G_B1': 0.0,
        'D_B1': 0.0,
        'G_B2': 0.999,
        'D_B2': 0.999,
        'batch_size': 256,
        'G_batch_size': 64,
        'num_G_accumulations': 8,
        'num_D_steps': 1,
        'num_D_accumulations': 8,
        'split_D': False,
        'num_epochs': 100,
        'parallel': True,
        'G_fp16': False,
        'D_fp16': False,
        'D_mixed_precision': False,
        'G_mixed_precision': False,
        'accumulate_stats': False,
        'num_standing_accumulations': 16,
        'G_eval_mode': True,
        'save_every': 1000,
        'num_save_copies': 2,
        'num_best_copies': 5,
        'which_best': 'IS',
        'no_fid': False,
        'test_every': 2000,
        'num_inception_images': 50000,
        'hashname': False,
        'base_root': '',
        'data_root': 'data',
        'weights_root': 'weights',
        'logs_root': 'logs',
        'samples_root': 'samples',
        'pbar': 'mine',
        'name_suffix': '',
        'experiment_name': '',
        'config_from_name': False,
        'ema': True,
        'ema_decay': 0.9999,
        'use_ema': True,
        'ema_start': 20000,
        'adam_eps': 1e-06,
        'BN_eps': 1e-05,
        'SN_eps': 1e-06,
        'num_G_SVs': 1,
        'num_D_SVs': 1,
        'num_G_SV_itrs': 1,
        'num_D_SV_itrs': 1,
        'G_ortho': 0.0,
        'D_ortho': 0.0,
        'toggle_grads': True,
        'which_train_fn': 'GAN',
        'load_weights': '',
        'resume': False,
        'logstyle': '%3.3e',
        'log_G_spectra': False,
        'log_D_spectra': False,
        'sv_log_interval': 10,
        'sample_npz': True,
        'sample_num_npz': 50000,
        'sample_sheets': True,
        'sample_interps': True,
        'sample_sheet_folder_num': -1,
        'sample_random': True,
        'sample_trunc_curves': '0.05_0.05_1.0',
        'sample_inception_metrics': True,
        'resolution': 128,
        'n_classes': 1000,
        'G_activation': ReLU(inplace=True),
        'D_activation': ReLU(inplace=True),
        'no_optim': True
    }
    D = model.Discriminator(**config_fixed).cuda()
    utils.load_weights(None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       None,
                       strict=False,
                       load_optim=False)
    D.eval()

    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    #Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            # zht: show discriminator results
            print(images.size(), labels.size())
            dis_loss = D(x=images, y=labels)
            print(dis_loss.size())
            print(dis_loss)
            exit()

            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]

            plt.imshow(x[0][i, :, :, :].transpose((1, 2, 0)))
            plt.show()

        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})
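        # The saved archive can later be read back for TF-Inception evaluation, e.g.:
        #   data = np.load(npz_filename)
        #   x, y = data['x'], data['y']  # uint8 images in [0, 255] and integer labels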

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(images.float(),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
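
For reference, sample_trunc_curves encodes a start_step_end sweep over the latent variance. A quick check of how the default string above ('0.05_0.05_1.0') expands:

import numpy as np

start, step, end = [float(v) for v in '0.05_0.05_1.0'.split('_')]
variances = np.arange(start, end + step, step)
print(variances)  # 0.05, 0.10, ..., 1.00 (the endpoint may wobble slightly due to float rounding)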
Example #7
def run(config):
    if config['wandb_entity'] is not None:
        init_wandb(config, config['experiment_name'], config['wandb_entity'],
                   'imagenet')
    if config["G_path"] is None:  # Download a pre-trained G if necessary
        download_G()
        config["G_path"] = 'checkpoints/138k'
    G, state_dict, device, experiment_name = load_G(config)
    # If parallel, parallelize the GD module
    if config['parallel']:
        G = nn.DataParallel(DataParallelLoss(G))
        if config['cross_replica']:
            patch_replication_callback(G)

    num_gpus = torch.cuda.device_count()
    print(f'Using {num_gpus} GPUs')

    # If search_space != 'all', then we need to pad the z components that we are leaving alone:
    pad = get_direction_padding_fn(config)
    direction_size = config['dim_z'] if config[
        'search_space'] == 'all' else config['ndirs']
    # A is our (ndirs, |z|) matrix of directions, where ndirs indicates the number of directions we want to learn
    if config['load_A'] == 'coords':
        print('Initializing with standard basis directions')
        A = torch.nn.Parameter(torch.eye(config['ndirs'],
                                         direction_size,
                                         device=device),
                               requires_grad=True)
    elif config['load_A'] == 'random':
        print('Initializing with random directions')
        A = torch.nn.Parameter(torch.empty(config['ndirs'],
                                           direction_size,
                                           device=device),
                               requires_grad=True)
        torch.nn.init.kaiming_normal_(A)
    else:
        raise NotImplementedError
    # We only learn A; G is left frozen during training:
    optim = torch.optim.Adam(params=[A], lr=config['A_lr'])

    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.module.G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.module.G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    interp_z, interp_y = utils.prepare_z_y(config["n_samples"],
                                           G.module.G.dim_z,
                                           config['n_classes'],
                                           device=device,
                                           fp16=config['G_fp16'])
    interp_z.sample_()
    interp_y.sample_()

    if config['fix_class'] is not None:
        y_ = y_.new_full(y_.size(), config['fix_class'])
        fixed_y = fixed_y.new_full(fixed_y.size(), config['fix_class'])
        interp_y = interp_y.new_full(interp_y.size(), config['fix_class'])

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    iters_per_epoch = 1000
    dummy_loader = [None] * iters_per_epoch  # We don't need any real data

    path_size = config['path_size']
    # Simply stores a |z|-dimensional one-hot vector indicating each direction we are learning:
    direction_indicators = torch.eye(config['ndirs']).to(device)

    G.eval()

    G.module.optim = optim

    writer = SummaryWriter('%s/%s' % (config['logs_root'], experiment_name))
    sample_sheet = train_fns.save_and_sample(G.module.G, None, G.module.G, z_,
                                             y_, fixed_z, fixed_y, state_dict,
                                             config, experiment_name)
    writer.add_image('samples', sample_sheet, 0)

    interp_y_ = G.module.G.shared(interp_y)
    # Make directions orthogonal via Gram Schmidt:
    Q = pad(fast_gram_schmidt(A)) if not config["no_ortho"] else pad(A)

    if config["vis_during_training"]:
        print("Generating initial visualizations...")
        interp_vis = visualize_directions(G.module.G,
                                          interp_z,
                                          interp_y_,
                                          path_sizes=path_size,
                                          Q=Q,
                                          high_quality=False,
                                          npv=1)
        for w_ix in range(config['ndirs']):
            writer.add_video('G_ema/w%03d' % w_ix, interp_vis[w_ix], 0, fps=24)

    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(dummy_loader,
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(dummy_loader)
        for i, _ in enumerate(pbar):
            state_dict['itr'] += 1
            z_.sample_()
            if config['fix_class'] is None:
                y_.sample_()
            y = G.module.G.shared(y_)
            sampled_directions = torch.randint(low=0,
                                               high=config['ndirs'],
                                               size=(G_batch_size, ),
                                               device=device)
            # Distances are sampled from U[-path_size, path_size]:
            distances = torch.rand(G_batch_size, 1, device=device).mul(
                2 * path_size).add(-path_size)
            # w_sampled is an (N, ndirs)-shaped tensor. If i indexes batch elements and j indexes directions, then
            # w_sampled[i, j] represents how far we will move z[i] in the direction Q[j]. The final z[i] will be the sum
            # over all directions stored in the rows of Q.
            w_sampled = direction_indicators[sampled_directions] * distances
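            # Worked example (illustrative numbers): with ndirs=3, sampled_directions=[2, 0]
            # and distances=[[0.7], [-1.2]], w_sampled is [[0.0, 0.0, 0.7], [-1.2, 0.0, 0.0]],
            # i.e. each row moves one latent along exactly one learned direction.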
            # TODO: The Q.repeat below is a DataParallel hack to make sure each GPU gets the same copy of the Q matrix.
            # There is almost certainly a cleaner way to do this.
            # Hessian Penalty taken w.r.t. w_sampled, NOT z:
            penalty = G(z_, y, w=w_sampled, Q=Q.repeat(num_gpus, 1)).mean()

            optim.zero_grad()
            penalty.backward()
            optim.step()
            # re-orthogonalize A for visualizations and the next training iteration:
            Q = pad(fast_gram_schmidt(A)) if not config["no_ortho"] else pad(A)

            # Log metrics to TensorBoard/WandB:
            cur_training_iter = epoch * iters_per_epoch + i
            writer.add_scalar('Metrics/hessian_penalty', penalty.item(),
                              cur_training_iter)
            writer.add_scalar('Metrics/direction_norm',
                              A.pow(2).mean().pow(0.5).item(),
                              cur_training_iter)

            # Save directions and log visuals:
            if not (state_dict['itr'] % config['save_every']):
                torch.save(
                    A.cpu().detach(),
                    '%s/%s/A_%06d.pt' % (config['weights_root'],
                                         experiment_name, cur_training_iter))
                if config["vis_during_training"]:
                    interp_vis = visualize_directions(G.module.G,
                                                      interp_z,
                                                      interp_y_,
                                                      path_sizes=path_size,
                                                      Q=Q,
                                                      high_quality=False,
                                                      npv=1)
                    for w_ix in range(config['ndirs']):
                        writer.add_video('G_ema/w%03d' % w_ix,
                                         interp_vis[w_ix],
                                         cur_training_iter,
                                         fps=24)

        state_dict['epoch'] += 1
Example #8
def run(config):

  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  if config['base_root']:
    os.makedirs(config['base_root'],exist_ok=True)

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  
   # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None
  
  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None,
                       )
    if G.lr_sched is not None:
      G.lr_sched.step(state_dict['epoch'])
    if D.lr_sched is not None:
      D.lr_sched.step(state_dict['epoch'])

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname, 
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, 
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  if not config['on_kaggle']:
    get_inception_metrics = inception_utils.prepare_inception_metrics(config['base_root'],config['dataset'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  if config['use_dog_cnt']:
    y_dist='categorical_dog_cnt'
  else:
    y_dist = 'categorical'

  dim_z=G.dim_z*2 if config['mix_style'] else G.dim_z
  z_, y_ = utils.prepare_z_y(G_batch_size, dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'],z_dist=config['z_dist'],
                             threshold=config['truncated_threshold'],y_dist=y_dist)
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'],z_dist=config['z_dist'],
                                       threshold=config['truncated_threshold'],y_dist=y_dist)
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                              G=(G_ema if config['ema'] and config['use_ema']
                                 else G),
                              z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Saving by epoch is more convenient; if save_every <= 100 it is treated as an
  # epoch interval and we save/test per epoch instead of per iteration.
  by_epoch = config['save_every'] <= 100


  # Train for specified number of epochs, although we mostly track G iterations.
  start_time = time.time()
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['on_kaggle']:
      pbar = loaders[0]
    elif config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    epoch_start_time = time.time()
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      if isinstance(y, (list, tuple)):
        y = torch.cat([yi.unsqueeze(1) for yi in y], dim=1)

      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)
      
      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']), 
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['on_kaggle']:
        if i == len(loaders[0])-1:
          metrics_str = ', '.join(['%s : %+4.3f' % (key, metrics[key]) for key in metrics])
          epoch_time = (time.time()-epoch_start_time) / 60
          total_time = (time.time()-start_time) / 60
          print(f"[{epoch+1}/{config['num_epochs']}][{epoch_time:.1f}min/{total_time:.1f}min] {metrics_str}")
      elif config['pbar'] == 'mine':
        if D.lr_sched is None:
          print(', '.join(['epoch:%d' % (epoch+1),'itr: %d' % state_dict['itr']]
                       + ['%s : %+4.3f' % (key, metrics[key])
                       for key in metrics]), end=' ')
        else:
          print(', '.join(['epoch:%d' % (epoch+1),'lr:%.5f' % D.lr_sched.get_lr()[0] ,'itr: %d' % state_dict['itr']]
                       + ['%s : %+4.3f' % (key, metrics[key])
                       for key in metrics]), end=' ')
      if not by_epoch:
        # Save weights and copies as configured at specified interval
        if not (state_dict['itr'] % config['save_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switching G to eval mode...')
            G.eval()
            if config['ema']:
              G_ema.eval()
          train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                    state_dict, config, experiment_name)

        # Test every specified interval
        if not (state_dict['itr'] % config['test_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switching G to eval mode...')
            G.eval()
          train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                         get_inception_metrics, experiment_name, test_log)

    if by_epoch:
      # Save weights and copies as configured at specified interval
      if not ((epoch+1) % config['save_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not ((epoch+1) % config['test_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

      if G_ema is not None and (epoch+1) % config['test_every'] == 0 and not config['on_kaggle']:
        torch.save(G_ema.state_dict(),  '%s/%s/G_ema_epoch_%03d.pth' %
                   (config['weights_root'], config['experiment_name'], epoch+1))
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
    if G.lr_sched is not None:
      G.lr_sched.step()
    if D.lr_sched is not None:
      D.lr_sched.step()
  if config['on_kaggle']:
    train_fns.generate_submission(sample, config, experiment_name)
Example #9
    def __init__(self,
                 G_batch_size=100,
                 batch_size=100,
                 dim_z=128,
                 n_classes=1000,
                 sigma=0.5,
                 is_y_uniform=False,
                 prior_type='default',
                 G_fp16=False,
                 arch_aux=0,
                 G_param='SN',
                 D_param='SN',
                 device='cuda',
                 P_lr=2e-4,
                 G_lr=5e-5,
                 G_B1=0.0,
                 G_B2=0.999,
                 adam_eps=1e-8,
                 num_G_SVs=1,
                 num_G_SV_itrs=1,
                 SN_eps=1e-12,
                 G_mixed_precision=False,
                 num_D_SVs=1,
                 num_D_SV_itrs=1,
                 G_activation=nn.ReLU(inplace=True),
                 GMM_init='ortho',
                 sharpen=1.0,
                 optimizer_type='adam',
                 weight_decay=5e-4,
                 **kwargs):

        super(Prior, self).__init__()
        dtype = torch.float16 if G_fp16 else torch.float32

        self.dim_z = dim_z
        self.sigma = sigma
        self.n_classes = n_classes
        self.prior_type = prior_type
        self.is_y_uniform = is_y_uniform
        self.bs = max(G_batch_size, batch_size)
        self.sharpen = sharpen
        self.weight_decay = weight_decay

        import utils
        self.z_, self.y_ = utils.prepare_z_y(self.bs,
                                             dim_z,
                                             n_classes,
                                             device=device,
                                             fp16=G_fp16)
        self.eps_ = self.z_

        which_embedding = nn.Embedding
        self.sample_ = self.sample_default
        self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_default
        self.latent_classification = self.latent_classification_default
        G_activation = nn.ReLU(inplace=True)

        if prior_type == 'default':
            self.y_aux = (
                1 / self.n_classes *
                torch.arange(self.n_classes, dtype=torch.float).reshape(
                    1, n_classes)).cuda()

        elif prior_type == 'aux':
            if G_param == 'SN':
                which_linear = functools.partial(SNLinear,
                                                 num_svs=num_G_SVs,
                                                 num_itrs=num_G_SV_itrs,
                                                 eps=SN_eps)
            else:
                which_linear = nn.Linear

            if arch_aux == 0:
                self.gen_linear = which_linear(2 * dim_z, dim_z)
                latent_classification = nn.Sequential(
                    which_linear(dim_z, dim_z), G_activation,
                    which_linear(dim_z, n_classes), nn.Softmax())

            elif arch_aux == 1:
                self.gen_linear = nn.Sequential(which_linear(2 * dim_z, dim_z),
                                                nn.Tanh())  # nn.Tanh takes no inplace argument
                latent_classification = nn.Sequential(
                    which_linear(dim_z, dim_z), G_activation,
                    which_linear(dim_z, n_classes), nn.Softmax())

            self.first_embedding = which_embedding(n_classes, dim_z)
            self.sample_ = self.sample_aux
            self.latent_classification = latent_classification
            self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_aux

        elif prior_type == 'GMM':

            self.init = GMM_init
            self.mu_c = nn.Parameter(data=torch.zeros((n_classes, dim_z),
                                                      dtype=dtype),
                                     requires_grad=True)
            self.lv_c = nn.Parameter(data=torch.ones((n_classes, dim_z),
                                                     dtype=dtype),
                                     requires_grad=True)
            self.phi_c = nn.Parameter(data=self.sigma *
                                      torch.ones(n_classes, dtype=dtype),
                                      requires_grad=False)
            self.sample_ = self.sample_from_gmm
            self.latent_classification = self.gmm_membressy2
            self.obtain_latent_from_z_y = self.obtain_latent_from_z_y_gmm

            if self.init == 'ortho':
                init.orthogonal_(self.mu_c)
            elif self.init == 'N02':
                init.normal_(self.mu_c, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(self.mu_c)
            elif self.init == 'mu_sep':
                extra_dim = dim_z % n_classes
                reap_dim = dim_z // n_classes
                mu_init = 1
                gmm_mu = mu_init * (1 + self.sigma) * np.hstack(
                    (np.eye(n_classes).repeat(
                        reap_dim, 1), np.zeros((n_classes, extra_dim))))
                del self.mu_c
                self.mu_c = nn.Parameter(data=torch.tensor(gmm_mu,
                                                           dtype=dtype),
                                         requires_grad=True)
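                # Worked example (illustrative numbers): with dim_z=5, n_classes=2 and
                # sigma=0.5, reap_dim=2 and extra_dim=1, so gmm_mu becomes
                # 1.5 * [[1, 1, 0, 0, 0], [0, 0, 1, 1, 0]]: each class mean occupies
                # its own block of latent dimensions.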

        if prior_type == 'aux' or prior_type == 'GMM':
            self.lr, self.B1, self.B2, self.adam_eps = P_lr, G_B1, G_B2, adam_eps
            if G_mixed_precision:
                print('Using fp16 adam in Prior...')
                self.optim = utils.Adam16(params=self.parameters(),
                                          lr=self.lr,
                                          betas=(self.B1, self.B2),
                                          weight_decay=0,
                                          eps=self.adam_eps)
            else:
                if optimizer_type == 'adam':
                    self.optim = optim.Adam(params=self.parameters(),
                                            lr=self.lr,
                                            betas=(self.B1, self.B2),
                                            weight_decay=0,
                                            eps=self.adam_eps)
                elif optimizer_type == 'radam':
                    self.optim = optimizers.RAdam(
                        params=self.parameters(),
                        lr=self.lr,
                        betas=(self.B1, self.B2),
                        weight_decay=self.weight_decay,
                        eps=self.adam_eps)

                elif optimizer_type == 'ranger':
                    self.optim = optimizers.Ranger(
                        params=self.parameters(),
                        lr=self.lr,
                        betas=(self.B1, self.B2),
                        weight_decay=self.weight_decay,
                        eps=self.adam_eps)

        if is_y_uniform:
            del self.y_
            self.y_ = torch.arange(n_classes).repeat(
                self.bs // n_classes).to(
                    device, torch.float16 if G_fp16 else torch.float32)
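The GMM branch above keeps one mean (mu_c) and log-variance (lv_c) per class, and a sample_from_gmm-style method then draws z for each sampled class. The snippet below is a minimal, self-contained sketch of that idea; the reparameterized draw and the helper name are assumptions about how such a prior is typically sampled, not the repository's exact implementation.

import torch

def sample_from_class_gmm(mu_c, lv_c, y):
    """Draw z ~ N(mu_c[y], exp(lv_c[y])) with the reparameterization trick.

    mu_c, lv_c: (n_classes, dim_z) per-class means / log-variances.
    y: (batch,) tensor of class indices.
    """
    mu = mu_c[y]                      # (batch, dim_z) class means
    std = (0.5 * lv_c[y]).exp()       # (batch, dim_z) class std-devs
    eps = torch.randn_like(std)       # standard normal noise
    return mu + eps * std

# Hypothetical usage with toy sizes:
n_classes, dim_z, batch = 10, 128, 4
mu_c = torch.zeros(n_classes, dim_z)
lv_c = torch.zeros(n_classes, dim_z)          # log-variance 0 -> unit variance
y = torch.randint(n_classes, (batch,))
z = sample_from_class_gmm(mu_c, lv_c, y)      # shape (4, 128)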
Beispiel #10
0
def run(config):
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  
  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  
  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([np.prod(p.shape) for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  #test_log = utils.MetricsLogger(test_metrics_fname, 
  #                               reinitialize=(not config['resume']))
  test_log = LogWriter(logdir='%s/%s_log' % (config['logs_root'],
                                             experiment_name))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, 
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
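  # e.g. (hypothetical values) batch_size=64, num_D_steps=2,
  # num_D_accumulations=8 -> the loader yields 64 * 2 * 8 = 1024 images per
  # iteration, enough for one full D step including all accumulations.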
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'])
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'])  
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                              G=(G_ema if config['ema'] and config['use_ema']
                                 else G),
                              z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):    
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0],
                            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      x, y = x, y.astype(np.int64)  # special handling for paddle dataloader
      if config['ema']:
        G_ema.train()

      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      for tag in metrics:
        try:
          test_log.add_scalar(step=int(state_dict['itr']), tag="train/" + tag,
                              value=float(metrics[tag]))
        except Exception:
          pass

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']), 
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
          print(', '.join(['itr: %d' % state_dict['itr']] 
                           + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]), end=' ')
      else:
          pbar.set_description(', '.join(['itr: %d' % state_dict['itr']] 
                           + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]))

      # Save weights and copies as configured at specified interval
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, 
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
Beispiel #11
0
def run(config):
    if config["G_path"] is None:  # Download a pre-trained G if necessary
        download_G()
        config["G_path"] = f'checkpoints/138k'
    G, state_dict, device, experiment_name = load_G(config)
    # If parallel, parallelize the GD module
    if config['parallel']:
        G = nn.DataParallel(G)
        if config['cross_replica']:
            patch_replication_callback(G)
    pad = get_direction_padding_fn(config)
    ndirs = config["ndirs"] if config["directions_to_vis"] is None else len(
        config["directions_to_vis"])

    path_sizes = torch.tensor([config["path_size"]] * ndirs,
                              dtype=torch.float32)

    interp_z, interp_y = utils.prepare_z_y(config["n_samples"],
                                           G.module.dim_z,
                                           config['n_classes'],
                                           device=device,
                                           fp16=config['G_fp16'])
    interp_z.sample_()
    interp_y.sample_()

    if config['fix_class'] is not None:
        interp_y = interp_y.new_full(interp_y.size(), config['fix_class'])

    interp_y_ = G.module.shared(interp_y)

    direction_size = config["dim_z"] if config[
        "search_space"] == "all" else config["ndirs"]
    if config['load_A'] == 'random':
        print('Visualizing RANDOM directions')
        A = torch.randn(ndirs, direction_size)
        A_name = 'random'
        nn.init.kaiming_normal_(A)
    elif config['load_A'] == 'coord':
        print('Visualizing COORDINATE directions')
        A = torch.eye(ndirs, direction_size)
        A_name = 'coord'
    else:
        print('Visualizing PRE-TRAINED directions')
        A = torch.load(config["load_A"])
        A_name = 'pretrained'

    A = A.cuda()
    Q = pad(fast_gram_schmidt(A)) if not config["no_ortho"] else pad(A)

    visuals_dir = f'visuals/{experiment_name}/{A_name}'
    os.makedirs(visuals_dir, exist_ok=True)
    print('Generating interpolation videos...')
    visualize_directions(G,
                         interp_z,
                         interp_y_,
                         path_sizes=path_sizes,
                         Q=Q,
                         base_path=visuals_dir,
                         interp_steps=180,
                         interp_mode='smooth_center',
                         high_quality=True,
                         quiet=False,
                         minibatch_size=config["val_minibatch_size"],
                         directions_to_vis=config["directions_to_vis"])
Beispiel #12
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, E, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # # Prepare inception metrics: FID and IS
    # get_inception_metrics = inception_utils.prepare_inception_metrics(
    #     config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'],
                             train_dataset.num_class,
                             save_dir=os.path.join(config['imgbuffer_root'],
                                                   experiment_name),
                             resume_buffer=config['resume_buffer'])
    else:
        img_pool = None
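    # (Hypothetical note: an image pool of this kind usually stores previously
    # generated samples and replays a random subset of them to D, so the
    # discriminator is not only shown the generator's most recent outputs.)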

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning training at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))
    # # Train for specified number of epochs, although we mostly track G iterations.
    # for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

    for i, (x, y) in enumerate(pbar):
        # Increment the iteration counter
        state_dict['itr'] += 1
        # Note: unlike the standard BigGAN loop, G and D are switched to eval
        # mode here and are not flipped back to train mode each iteration.
        G.eval()
        D.eval()
        if config['ema']:
            G_ema.eval()
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)

        # Run one training step; the metrics dict is used by the logging below.
        metrics = train(x, y)

        # Every sv_log_interval, log singular values
        if (config['sv_log_interval'] >
                0) and (not (state_dict['itr'] % config['sv_log_interval'])):
            train_log.log(itr=int(state_dict['itr']),
                          **{
                              **utils.get_SVs(G, 'G'),
                              **utils.get_SVs(D, 'D')
                          })

        # If using my progbar, print metrics.
        if config['pbar'] == 'mine':
            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

        # Save weights and copies as configured at specified interval
        if (not state_dict['itr'] % config['save_img_every']) or (
                not state_dict['itr'] % config['save_model_every']):
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            save_weights = config['save_weights']
            if state_dict['itr'] % config['save_model_every']:
                save_weights = False
            train_fns.save_and_sample(G,
                                      D,
                                      E,
                                      G_ema,
                                      fixed_x,
                                      fixed_y_of_x,
                                      z_,
                                      y_,
                                      state_dict,
                                      config,
                                      experiment_name,
                                      img_pool,
                                      save_weights=save_weights)

        # # Test every specified interval
        # if not (state_dict['itr'] % config['test_every']):
        #     if config['G_eval_mode']:
        #         print('Switching G to eval mode...')
        #         G.eval()
        #     train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
        #                    get_inception_metrics, experiment_name, test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] = state_dict['itr'] * D_batch_size / (
            len(train_dataset))
        print("Finished Epoch {} (iteration {})".format(
            state_dict['epoch'], state_dict['itr']))
Beispiel #13
0
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['n_channels'] = utils.nchannels_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    # utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    if True or config['get_test_error'] or config['get_train_error'] or config[
            'get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Gets the "classifications" from D.
      
      y: the correct labels
      
      In the case of projection discrimination we have to pass in all the labels
      as conditionings to get the class specific affinity.
      """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
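                # Column 0 of yhat is D's affinity for the true label y and
                # column i the affinity for the shifted label (y + i); a sample
                # counts as correct when the argmax over columns is 0.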
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode...')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    brief_expt_name = config['experiment_name'][-30:]

    # load results dict always
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """make/load history files in each
    """
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                'use_test_set': config['get_test_error']
            })
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error across %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                total += config['batch_size']
                if loader_total > total and total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['get_generator_error']:

        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7'))
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                    'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(
                os.path.join(
                    '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                    'ckpt.copy.t7'))
            #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                    'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    fake_input[bi] = minimal_trans(np.moveaxis(
                        fake[bi], 0, -1))
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print(
            '[%s][%06d] %s accuracy: %f.' %
            (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})
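        # (Hypothetical usage note: the saved file can later be read back with
        # data = np.load(npz_filename); data['x'] holds uint8 images of shape
        # (N, 3, H, W) and data['y'] the matching class labels.)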

    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # this is for using the downloaded one from
        # https://github.com/bioinf-jku/TTUR
        #mdata, sdata = f['mu'][:], f['sigma'][:]

        # this one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or (
            fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size

            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            x[start:end] = np.moveaxis(fake, 1, -1)
            #y += [labels.cpu().numpy()]

    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print(
                '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.'
                % (brief_expt_name, state_dict['itr'], mis, sis))

    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=mem_fraction)
            inception_path = fid.check_or_download_inception(None)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=gpu_options)) as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' %
            (config['samples_root'], experiment_name, config['load_weights']),
            nrow=int(G_batch_size**0.5),
            normalize=True)

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        # Get Inception Score and FID
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
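        # e.g. (hypothetical value) sample_trunc_curves='0.05_0.05_1.0' sweeps
        # the z variance from 0.05 up to 1.0 in steps of 0.05 via the loop below.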
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Beispiel #14
0
def run(config):
    logger = logging.getLogger('tl')
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = importlib.import_module(config['model'])
    # model = __import__(config['model'])
    experiment_name = 'exp'
    # experiment_name = (config['experiment_name'] if config['experiment_name']
    #                      else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config, cfg=getattr(global_cfg, 'generator',
                                              None)).to(device)
    D = model.Discriminator(**config,
                            cfg=getattr(global_cfg, 'discriminator',
                                        None)).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        },
                                cfg=getattr(global_cfg, 'generator',
                                            None)).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    logger.info(G)
    logger.info(D)
    logger.info('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G=G,
                           D=D,
                           state_dict=state_dict,
                           weights_root=global_cfg.resume_cfg.weights_root,
                           experiment_name='',
                           name_suffix=config['load_weights']
                           if config['load_weights'] else None,
                           G_ema=G_ema if config['ema'] else None)
        logger.info(f"Resume IS={state_dict['best_IS']}")
        logger.info(f"Resume FID={state_dict['best_FID']}")

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr'],
            **getattr(global_cfg, 'train_dataloader', {})
        })

    val_loaders = None
    if hasattr(global_cfg, 'val_dataloader'):
        val_loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                **global_cfg.val_dataloader
            })[0]
        val_loaders = iter(val_loaders)
    # Prepare inception metrics: FID and IS
    if global_cfg.get('use_unofficial_FID', False):
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['inception_file'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = inception_utils.prepare_FID_IS(global_cfg)
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config,
                                                val_loaders)
    # Else, assume debugging and use the dummy train fn
    elif config['which_train_fn'] == 'dummy':
        train = train_fns.dummy_training_function()
    else:
        train_fns_module = importlib.import_module(config['which_train_fn'])
        train = train_fns_module.GAN_training_function(G, D, GD, z_, y_, ema,
                                                       state_dict, config,
                                                       val_loaders)

    # Prepare Sample function for use with inception metrics
    if global_cfg.get('use_unofficial_FID', False):
        sample = functools.partial(
            utils.sample,
            G=(G_ema if config['ema'] and config['use_ema'] else G),
            z_=z_,
            y_=y_,
            config=config)
    else:
        sample = functools.partial(
            utils.sample_imgs,
            G=(G_ema if config['ema'] and config['use_ema'] else G),
            z_=z_,
            y_=y_,
            config=config)

    state_dict['shown_images'] = state_dict['itr'] * D_batch_size

    if global_cfg.get('resume_cfg', {}).get('eval', False):
        logger.info('Evaluating model.')
        G_ema.eval()
        G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
        return

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  desc=f'Epoch:{epoch}, Itr: ',
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            default_dict = train(x, y)

            state_dict['shown_images'] += D_batch_size

            metrics = default_dict['D_loss']
            train_log.log(itr=int(state_dict['itr']), **metrics)

            summary_defaultdict2txtfig(default_dict=default_dict,
                                       prefix='train',
                                       step=state_dict['shown_images'],
                                       textlogger=textlogger)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ',
                      flush=True)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
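            # The last clause below also triggers a test roughly once per
            # `test_every_images` shown images: shown_images advances by
            # D_batch_size per step, so its remainder modulo test_every_images
            # drops below D_batch_size about once per interval.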
            if state_dict['itr'] == 1 or \
                  (config['test_every'] > 0 and state_dict['itr'] % config['test_every'] == 0) or \
                  (state_dict['shown_images'] % global_cfg.get('test_every_images', float('inf'))) < D_batch_size:
                if config['G_eval_mode']:
                    print('Switching G to eval mode...', flush=True)
                    G.eval()
                print('\n' + config['tl_outdir'])
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Beispiel #15
0
def run(config):
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}
                
  # Optionally, get the configuration from the state dict. This allows for
  # recovery of the config provided only a state dict and experiment name,
  # and can be convenient for writing less verbose sample shell scripts.
  if config['config_from_name']:
    utils.load_weights(None, None, state_dict, config['weights_root'], 
                       config['experiment_name'], config['load_weights'], None,
                       strict=False, load_optim=False)
    # Ignore items which we might want to overwrite from the command line
    for item in state_dict['config']:
      if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']:
        config[item] = state_dict['config'][item]
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  
  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  utils.count_parameters(G)
  
  # Load weights
  print('Loading weights...')
  # Here is where we deal with the ema--load ema weights or load normal weights
  utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, 
                     config['weights_root'], experiment_name, config['load_weights'],
                     G if config['ema'] and config['use_ema'] else None,
                     strict=False, load_optim=False)
  # Update batch size setting used for G
  G_batch_size = max(config['G_batch_size'], config['batch_size']) 
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'], 
                             z_var=config['z_var'])
  
  if config['G_eval_mode']:
    print('Putting G in eval mode...')
    G.eval()
  else:
    print('G is in %s mode...' % ('training' if G.training else 'eval'))
    
  # Sample function
  sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)  
  if config['accumulate_stats']:
    print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations'])
    utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                    config['num_standing_accumulations'])
    
  
  # Sample a number of images and save them to an NPZ, for use with TF-Inception
  if config['sample_npz']:
    # Lists to hold images and labels for images
    x, y = [], []
    print('Sampling %d images and saving them to npz...' % config['sample_num_npz'])
    for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
      with torch.no_grad():
        images, labels = sample()
      x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
      y += [labels.cpu().numpy()]
    x = np.concatenate(x, 0)[:config['sample_num_npz']]
    y = np.concatenate(y, 0)[:config['sample_num_npz']]    
    print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
    npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name)
    print('Saving npz to %s...' % npz_filename)
    np.savez(npz_filename, **{'x' : x, 'y' : y})
  
  # Prepare sample sheets
  if config['sample_sheets']:
    print('Preparing conditional sample sheets...')
    utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], 
                         num_classes=config['n_classes'], 
                         samples_per_class=10, parallel=config['parallel'],
                         samples_root=config['samples_root'], 
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'],
                         z_=z_,)
  # Sample interp sheets
  if config['sample_interps']:
    print('Preparing interp sheets...')
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
      utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8,
                         num_classes=config['n_classes'], 
                         parallel=config['parallel'], 
                         samples_root=config['samples_root'], 
                         experiment_name=experiment_name,
                         folder_number=config['sample_sheet_folder_num'], 
                         sheet_number=0,
                         fix_z=fix_z, fix_y=fix_y, device='cuda')
  # Sample random sheet
  if config['sample_random']:
    print('Preparing random sample sheet...')
    images, labels = sample()
    print("labels size", labels)    
    torchvision.utils.save_image(images.float(),
                                 '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name),
                                 nrow=int(G_batch_size**0.5),
                                 normalize=True)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None
  
  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  #print(G)
  #print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  
  D_fake = D(images[1, :, :, :], labels[0])
  print("D_fake ", D_fake)
Beispiel #16
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    # config['n_classes'] = utils.nclass_dict[config['dataset']]

    # NOTE: setting n_classes to 1 except in conditional case to train as unconditional model
    config['n_classes'] = 1
    if config['conditional']:
        config['n_classes'] = 2
    print('n classes: {}'.format(config['n_classes']))
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(
        G, D,
        config['conditional'])  # check if labels are 0's if "unconditional"
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num_fair': 0,
        'save_best_num_fid': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_fair_d': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.json' % (config['logs_root'],
                                             experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        config, **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               true_prop=config['true_prop'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # NOTE: "unconditional" GAN
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])

        # iterate through the dataloaders
        for i, (x, y, ratio) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y, ratio = x.to(device).half(), y.to(device), ratio.to(
                    device)
            else:
                x, y, ratio = x.to(device), y.to(device), ratio.to(device)
            metrics = train(x, y, ratio)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

        # Test every epoch (rather than at a specified iteration interval)
        if (epoch >= config['start_eval']):
            # First, find correct inception moments
            data_moments = '../../fid_stats/unbiased_all_gender_fid_stats.npz'
            if config['multi']:
                data_moments = '../../fid_stats/unbiased_all_multi_fid_stats.npz'
                fid_type = 'multi'
            else:
                fid_type = 'gender'

            # load appropriate moments
            print('Using data moments at: {}'.format(data_moments))
            experiment_name = (config['experiment_name']
                               if config['experiment_name'] else
                               utils.name_from_config(config))

            # eval mode for FID computation
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            utils.sample_inception(
                G_ema if config['ema'] and config['use_ema'] else G, config,
                str(epoch))
            # Get saved sample path
            folder_number = str(epoch)
            sample_moments = '%s/%s/%s/samples.npz' % (
                config['samples_root'], experiment_name, folder_number)
            # Calculate FID
            FID = fid_score.calculate_fid_given_paths(
                [data_moments, sample_moments],
                batch_size=100,
                cuda=True,
                dims=2048)
            print("FID calculated")
            train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                 experiment_name, test_log,
                                 epoch)  # added epoch logging
        # Increment epoch counter at end of epoch
        print('Completed epoch {}'.format(epoch))
        state_dict['epoch'] += 1
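Several of these snippets build a second generator G_ema and keep it updated through utils.ema. The rule is an exponential moving average of G's weights that only starts decaying after ema_start iterations; a minimal sketch of that behaviour, assuming both models share parameter names (the real helper may differ in details):

import torch

class SimpleEMA(object):
    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr
        target.load_state_dict(source.state_dict())   # start from a copy

    @torch.no_grad()
    def update(self, itr):
        # plain copy before start_itr, decayed average afterwards
        decay = self.decay if itr >= self.start_itr else 0.0
        src = dict(self.source.named_parameters())
        for name, p in self.target.named_parameters():
            p.copy_(decay * p + (1.0 - decay) * src[name])

G_toy = torch.nn.Linear(4, 4)
G_ema_toy = torch.nn.Linear(4, 4)
ema_toy = SimpleEMA(G_toy, G_ema_toy, decay=0.999, start_itr=100)
ema_toy.update(itr=1)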
Beispiel #17
0
def run(config):
  def len_parallelloader(self):
    return len(self._loader._loader)
  pl.PerDeviceLoader.__len__ = len_parallelloader

  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    xm.master_print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  xm.master_print('Experiment name is %s' % experiment_name)

  device = xm.xla_device(devkind='TPU')

  # Next, build the model
  G = model.Generator(**config)
  D = model.Discriminator(**config)

  # If using EMA, prepare it
  if config['ema']:
    xm.master_print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True,
                               'no_optim': True})
  else:
    xm.master_print('Not using ema...')
    G_ema, ema = None, None

  # FP16?
  if config['G_fp16']:
    xm.master_print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    xm.master_print('Casting D to fp16...')
    D = D.half()

  # Prepare state dict, which holds things like itr #
  state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    xm.master_print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name,
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # move everything to TPU
  G = G.to(device)
  D = D.to(device)

  G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                       betas=(G.B1, G.B2), weight_decay=0,
                       eps=G.adam_eps)
  D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                       betas=(D.B1, D.B2), weight_decay=0,
                       eps=D.adam_eps)


  #for key, val in G.optim.state.items():
  #  G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
  #  G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

  #for key, val in D.optim.state.items():
  #  D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
  #  D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)


  if config['ema']:
    G_ema = G_ema.to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

  # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  xm.master_print(G)
  xm.master_print(D)
  xm.master_print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  xm.master_print('Test Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname,
                                 reinitialize=(not config['resume']))
  xm.master_print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname,
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])

  if xm.is_master_ordinal():
      # Write metadata
      utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  xm.master_print('Preparing data...')
  loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  xm.master_print('Preparing metrics...')

  get_inception_metrics = inception_utils.prepare_inception_metrics(
      config['dataset'], config['parallel'],
      no_inception=config['no_inception'],
      no_fid=config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])

  sample = lambda: utils.prepare_z_y(G_batch_size, G.dim_z,
                                     config['n_classes'], device=device,
                                     fp16=config['G_fp16'])
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = sample()

  train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                          config)

  xm.master_print('Beginning training...')

  if xm.is_master_ordinal():
    pbar = tqdm(total=config['total_steps'])
    pbar.n = state_dict['itr']
    pbar.refresh()

  xm.rendezvous('training_starts')
  while (state_dict['itr'] < config['total_steps']):
    pl_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)

    for i, (x, y) in enumerate(pl_loader):
      if xm.is_master_ordinal():
          # Increment the iteration counter
          pbar.update(1)

      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()

      xm.rendezvous('data_collection')
      metrics = train(x, y)

      # train_log.log(itr=int(state_dict['itr']), **metrics)

      # Every sv_log_interval, log singular values
      #if ((config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval']))) and xm.is_master_ordinal():
        #train_log.log(itr=int(state_dict['itr']),
         #             **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # Save weights and copies as configured at specified interval
      if (not (state_dict['itr'] % config['save_every'])):
        if config['G_eval_mode']:
          xm.master_print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, sample, fixed_z, fixed_y, state_dict, config, experiment_name)

      # Test every specified interval
      if (not (state_dict['itr'] % config['test_every'])):

        which_G = G_ema if config['ema'] and config['use_ema'] else G
        if config['G_eval_mode']:
          xm.master_print('Switching G to eval mode...')
          which_G.eval()

        def G_sample():
            z, y = sample()
            return which_G(z, which_G.shared(y))

        train_fns.test(G, D, G_ema, sample, state_dict, config, G_sample,
                       get_inception_metrics, experiment_name, test_log)

      if state_dict['itr'] >= config['total_steps']:
          break
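The len_parallelloader patch at the top of the TPU snippet exists because pl.PerDeviceLoader does not define __len__, which progress bars and epoch arithmetic expect. The same pattern on stand-in classes, so it runs without torch_xla (the _loader._loader attribute chain mirrors the snippet and is an assumption about the wrapper internals):

class InnerLoader(object):              # stands in for the torch DataLoader
    def __init__(self, n):
        self._items = list(range(n))
    def __len__(self):
        return len(self._items)

class ParallelLoaderStandIn(object):    # stands in for pl.ParallelLoader
    def __init__(self, loader):
        self._loader = loader

class PerDeviceLoaderStandIn(object):   # stands in for pl.PerDeviceLoader
    def __init__(self, parallel_loader):
        self._loader = parallel_loader

def len_parallelloader(self):
    return len(self._loader._loader)

PerDeviceLoaderStandIn.__len__ = len_parallelloader

pdl = PerDeviceLoaderStandIn(ParallelLoaderStandIn(InnerLoader(10)))
print(len(pdl))                         # -> 10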
Beispiel #18
0
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = 64
    config['n_classes'] = 120
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else 'generative_dog_images')
    print('Experiment name is %s' % experiment_name)

    G = BigGAN.Generator(**config).to(device)
    D = BigGAN.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = BigGAN.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    GD = BigGAN.G_D(G, D)
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = dataset.get_data_loaders(data_root=config['data_root'],
                                       label_root=config['label_root'],
                                       batch_size=D_batch_size,
                                       num_workers=config['num_workers'],
                                       shuffle=config['shuffle'],
                                       pin_memory=config['pin_memory'],
                                       drop_last=True)

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.create_train_fn(G, D, GD, z_, y_, ema, state_dict,
                                      config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    start_time = time.perf_counter()
    total_iters = config['num_epochs'] * len(loaders[0])

    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        for i, (x, y) in enumerate(loaders[0]):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            x, y = x.to(device), y.to(device)
            metrics = train(x, y)

            if not (state_dict['itr'] % config['log_interval']):
                curr_time = time.perf_counter()
                curr_time_str = datetime.datetime.fromtimestamp(
                    curr_time).strftime('%H:%M:%S')
                elapsed = str(
                    datetime.timedelta(seconds=(curr_time - start_time)))
                log = ("[{}] [{}] [{} / {}] Ep {}, ".format(
                    curr_time_str, elapsed, state_dict['itr'], total_iters,
                    epoch) + ', '.join([
                        '%s : %+4.3f' % (key, metrics[key]) for key in metrics
                    ]))
                print(log)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    # if config['ema']:
                    #     G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
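As the data-loading comments above note, the loader batch size has to cover a full discriminator iteration. A quick worked example with hypothetical settings:

batch_size = 64                 # hypothetical per-step batch
num_D_steps = 2                 # D updates per G update
num_D_accumulations = 8         # gradient accumulations per D step
D_batch_size = batch_size * num_D_steps * num_D_accumulations
print(D_batch_size)             # 1024 images fetched per loader iteration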
Beispiel #19
0
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    GDE = model.G_D_E(G, D, E)

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))

    print('Loading weights...')
    utils.load_weights(
        G,
        D,
        E,
        state_dict,
        config['weights_root'],
        experiment_name,
        config['load_weights'] if config['load_weights'] else None,
        None,
        strict=False,
        load_optim=False)

    # ==============================================================================
    # prepare the data
    loaders, train_dataset = utils.get_data_loaders(**config)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device)

    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device)

    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))

    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    evaluate_sample(config,
                    fixed_x,
                    fixed_y,
                    G,
                    E,
                    experiment_name,
                    attack=True)
Beispiel #20
0
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    G3 = model.Generator(**config).to(device)
    D3 = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    GD3 = model.G_D(G3, D3, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    # Usage: --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    # Use: --select_dataset mnist --abnormal_class 1 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #'../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the data set.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.eval()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch==0 and i==0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']), **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
        state_dict['epoch'] += 1
    #utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'be01Bes01Best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None)
    utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
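fid_score.calculate_fid_given_paths above compares the Inception-feature moments stored in the two npz files. For reference, a minimal sketch of the Frechet distance it is based on, given precomputed means and covariances (a sketch of the standard formula, not the library code):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)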
Beispiel #21
0
def run(config):

    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    model = BigGAN

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it (exponential moving average of the parameters)
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        ema = None

    GD = model.G_D(G, D)

    print('Number of params in G: {} D: {}'.format(
        *
        [sum([p.data.nelement() for p in net.parameters()])
         for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))

    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict,
                                            config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):

        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)

            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
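When config['parallel'] is set, these snippets wrap the joint G_D module in nn.DataParallel so each batch is split across the visible GPUs and the outputs are gathered back. A minimal, device-agnostic sketch of the same wrapping on a toy module, purely for illustration:

import torch
import torch.nn as nn

toy = nn.Linear(8, 2)
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    toy = nn.DataParallel(toy)      # scatter inputs, gather outputs
out = toy(torch.randn(4, 8))
print(out.shape)                    # torch.Size([4, 2])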
Beispiel #22
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, E, state_dict,
                           config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset, experiment_name, config, device=device)
    

    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'], train_dataset.num_class,\
                                    save_dir=os.path.join(config['imgbuffer_root'], experiment_name),
                                    resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE,
                                                ema, state_dict, config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema']
                                   else G),
                               z_=z_, y_=y_, config=config)

    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning testing at Epoch {} (iteration {})".format(state_dict['epoch'], state_dict['itr']))

    if config['G_eval_mode']:
        print('Switching G to eval mode...')
        G.eval()
        if config['ema']:
            G_ema.eval()
    # vc visualization
    # # print("VC visualization ===============")
    # activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
    #                             state_dict, config, experiment_name, device, normal_eval=False, eval_vc=True, return_mask=False)
    # normal activation
    print("Normal activation ===============")
    activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
                                state_dict, config, experiment_name, device, normal_eval=True, eval_vc=False, return_mask=False) # produce normal fully activated images
Beispiel #23
0
def sample():
    return utils.prepare_z_y(
        G_batch_size, G.dim_z, config['n_classes'],
        device=device, fp16=config['G_fp16'])


# Prepare a fixed z & y to see individual sample evolution throughout training
fixed_z, fixed_y = sample()
Beispiel #24
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cpu'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    experiment_name = "test_{}".format(experiment_name)
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)
    # print(G)
    # print(D)
    # print(E)
    print("Model Created!")
    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Load weights from a pre-trained model
    print('Loading weights...')
    utils.load_weights(
        G, D, E, state_dict, config['weights_root'],
        config['load_experiment_name'],
        config['load_weights'] if config['load_weights'] else None,
        G_ema if config['ema'] else None)
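    # Reset the state dict after loading: presumably so this test run starts from itr/epoch 0
    # rather than the counters stored in the checkpoint (an interpretation, not stated in the source).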
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }
    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': 0
    })

    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    G.eval()
    E.eval()
    print("check1 -------------------------------")
    print("state_dict['itr']", state_dict['itr'])
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

    print("state_dict['itr']", state_dict['itr'])
    for i, (x, y) in enumerate(pbar):
        state_dict['itr'] += 1
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)
        print("x.shape", x.shape)
        print("y.shape", y.shape)

        activation_extract(G,
                           D,
                           E,
                           G_ema,
                           x,
                           y,
                           z_,
                           y_,
                           state_dict,
                           config,
                           experiment_name,
                           save_weights=False)
        if state_dict['itr'] == 20:
            break
Beispiel #25
0
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    num_devices = torch.cuda.device_count()
    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.ImageDiscriminator(**config).to(device)
    if config['no_Dv'] == False:
        Dv = model.VideoDiscriminator(**config).to(device)
    else:
        Dv = None

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        if config['no_Dv'] == False:
            Dv = Dv.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(
        G, D, Dv, config['k'],
        config['T_into_B'])  #xiaodan: add an argument k and T_into_B
    # print('GD.k in train.py line 91',GD.k)
    # print(G) # xiaodan: print disabled by xiaodan; too much output
    # print(D)
    if config['no_Dv'] == False:
        print('Number of params in G: {} D: {} Dv: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, Dv]
        ]))
    else:
        print('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained BigGAN model, load weights
    if config['biggan_init']:
        print('Loading weights from pre-trained BigGAN...')
        utils.load_biggan_weights(G,
                                  D,
                                  state_dict,
                                  config['biggan_weights_root'],
                                  G_ema if config['ema'] else None,
                                  load_optim=False)

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, Dv, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    if config['dataset'] == 'C10':
        loaders = utils.get_video_cifar_data_loader(
            **{
                **config, 'batch_size': D_batch_size,
                'start_itr': state_dict['itr']
            })
    else:
        loaders = utils.get_video_data_loaders(**{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })
    # print(loaders)
    # print(loaders[0])
    print('D loss weight:', config['D_loss_weight'])
    # Prepare inception metrics: FID and IS
    if config['skip_testing'] == False:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(
        config['G_batch_size'], config['batch_size']
    )  # * num_devices #xiaodan: num_devices added by xiaodan
    # print('num_devices:',num_devices,'G_batch_size:',G_batch_size)
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # print('z_,y_ shapes after prepare_z_y:',z_.shape,y_.shape)
    # print('z_,y_ size:',z_.shape,y_.shape)
    # print('G.dim_z:',G.dim_z)
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, Dv, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    unique_id = datetime.datetime.now().strftime('%Y%m-%d%H-%M%S-')
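    # e.g. 2024-06-15 12:30:45 would be rendered as '202406-1512-3045-' by this format string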
    tensorboard_path = os.path.join(config['logs_root'], 'tensorboard_logs',
                                    unique_id)
    os.makedirs(tensorboard_path)
    # Train for specified number of epochs, although we mostly track G iterations.
    writer = SummaryWriter(log_dir=tensorboard_path)
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        iteration = epoch * len(pbar)
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter

            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['no_Dv'] == False:
                Dv.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y, writer, iteration + i)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                if config['no_Dv'] == False:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{
                                      **utils.get_SVs(G, 'G'),
                                      **utils.get_SVs(D, 'D'),
                                      **utils.get_SVs(Dv, 'Dv')
                                  })
                else:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{
                                      **utils.get_SVs(G, 'G'),
                                      **utils.get_SVs(D, 'D')
                                  })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, Dv, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            #xiaodan: Disabled test for now because we don't have inception data
            # Test every specified interval
            if not (state_dict['itr'] %
                    config['test_every']) and config['skip_testing'] == False:
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                IS_mean, IS_std, FID = train_fns.test(
                    G, D, Dv, G_ema, z_, y_, state_dict, config, sample,
                    get_inception_metrics, experiment_name, test_log)
                writer.add_scalar('Inception/IS', IS_mean, iteration + i)
                writer.add_scalar('Inception/IS_std', IS_std, iteration + i)
                writer.add_scalar('Inception/FID', FID, iteration + i)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
Beispiel #26
0
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode...')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    # Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
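
    # Hedged note (an assumption about accumulate_standing_stats, in the spirit of BigGAN's
    # "standing statistics"): it presumably runs G forward in train mode on several batches of
    # freshly sampled (z, y) so the BatchNorm running stats settle before evaluation.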

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold sampled images and their labels
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
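        # e.g. requesting 10000 samples with G_batch_size=50 means ceil(10000 / 50) = 200 loops
        # (illustrative numbers only)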
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(images.float(),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves

    def get_metrics():
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
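        # e.g. '0.2_0.2_1.0' sweeps the z variance over 0.2, 0.4, 0.6, 0.8, 1.0
        # (illustrative values; np.arange(start, end + step, step) builds the grid below)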
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Beispiel #27
0
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num_fair': 0,
        'save_best_num_fid': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_fair_d': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = 1
    if config['conditional']:
        config['n_classes'] = 2
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'
    config['sample_num_npz'] = 10000
    print(config['ema_start'])

    # Seed RNG
    # utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    assert config['mode'] in ['fair', 'fid']
    print('sampling from model with best FID scores...')
    # can change this to 'fair', but this assumes access to ground-truth attribute classifier (and labels)
    config['mode'] = 'fid'

    # find best weights for either FID or fair checkpointing
    weights_root = config['weights_root']
    ckpts = glob.glob(
        os.path.join(weights_root, experiment_name,
                     'state_dict_best_{}*'.format(config['mode'])))
    best_ckpt = 'best_{}{}'.format(config['mode'], len(ckpts) - 1)
    config['load_weights'] = best_ckpt
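    # e.g. with three checkpoints matching 'state_dict_best_fid*', best_ckpt resolves to 'best_fid2'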

    # load weights to sample from generator
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       weights_root,
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)

    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])
    print('Putting G in eval mode...')
    G.eval()

    # Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    sample_path = '%s/%s/' % (config['samples_root'], experiment_name)
    print('looking in sample path {}'.format(sample_path))
    if not os.path.exists(sample_path):
        print('creating sample path: {}'.format(sample_path))
        os.mkdir(sample_path)

    # Lists to hold sampled images and their labels
    print('saving samples from best FID checkpoint')
    # sampling 10 sets of 10K samples
    for k in range(10):
        npz_filename = '%s/%s/fid_samples_%s.npz' % (config['samples_root'],
                                                     experiment_name, k)
        if os.path.exists(npz_filename):
            print('samples already exist, skipping...')
            continue
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('checking labels: {}'.format(y.sum()))
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/fid_samples_%s.npz' % (config['samples_root'],
                                                     experiment_name, k)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # classify proportions
    metrics = {'l2': 0, 'l1': 0, 'kl': 0}
    l2_db = np.zeros(10)
    l1_db = np.zeros(10)
    kl_db = np.zeros(10)
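    # Hedged assumption about utils.fairness_discrepancy: it presumably compares the histogram of
    # predicted attribute classes against a uniform target and reports the L2 distance, L1 distance,
    # and KL divergence of that gap for each set of samples.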

    # output file
    fname = '%s/%s/fair_disc_fid_samples.p' % (config['samples_root'],
                                               experiment_name)

    # load classifier
    if not config['multi']:
        print('Pre-loading pre-trained single-attribute classifier...')
        clf_state_dict = torch.load(CLF_PATH)['state_dict']
        clf_classes = 2
    else:
        # multi-attribute
        print('Pre-loading pre-trained multi-attribute classifier...')
        clf_state_dict = torch.load(MULTI_CLF_PATH)['state_dict']
        clf_classes = 4
    # load attribute classifier here
    clf = ResNet18(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=clf_classes,
                   grayscale=False)
    clf.load_state_dict(clf_state_dict)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    clf = clf.to(device)
    clf.eval()  # turn off batch norm

    # classify examples and get probabilities
    n_classes = 2
    if config['multi']:
        n_classes = 4

    # number of classes
    probs_db = np.zeros((10, 10000, n_classes))
    for i in range(10):
        # grab appropriate samples
        npz_filename = '{}/{}/{}_samples_{}.npz'.format(
            config['samples_root'], experiment_name, config['mode'], i)
        preds, probs = classify_examples(clf, npz_filename)
        l2, l1, kl = utils.fairness_discrepancy(preds, clf_classes)

        # save metrics
        l2_db[i] = l2
        l1_db[i] = l1
        kl_db[i] = kl
        probs_db[i] = probs
        print('fair_disc for iter {} is: l2:{}, l1:{}, kl:{}'.format(
            i, l2, l1, kl))
    metrics['l2'] = l2_db
    metrics['l1'] = l1_db
    metrics['kl'] = kl_db
    print('fairness discrepancies saved in {}'.format(fname))
    print(l2_db)

    # save all metrics
    with open(fname, 'wb') as fp:
        pickle.dump(metrics, fp)
    np.save(
        os.path.join(config['samples_root'], experiment_name, 'clf_probs.npy'),
        probs_db)