def run(config_parser):
    # Getting config
    experiment_name = model_name = config_parser['experiment_name']
    model_path = '%s/%s' % (config_parser['weights_path'], model_name)
    logs_path = '%s/%s' % (config_parser['logs_path'], model_name)
    config_path = '%s/metalog.txt' % logs_path
    new_file = 'saved_stuff'
    save_path = '%s/%s' % (model_path, new_file)
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    device = 'cuda'

    # Recover the training config dict from the metalog file.
    with open(config_path, 'r') as f:
        all_file = f.read()
    fs1 = all_file.find('{')
    fs2 = all_file.find('}')
    config = all_file[fs1:fs2 + 1]
    import ast
    config = config.replace(", 'G_activation': ReLU()", "")
    config = config.replace(", 'D_activation': ReLU()", "")
    config = ast.literal_eval(config)
    config['samples_root'] = 'samples_test'
    config['skip_init'] = True
    # config['no_optim'] = True

    # Loading model
    config['weights_root'] = config_parser['weights_path']
    model = __import__(config['model'])
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['is_encoder']:
        E = model.Encoder(**{**config, 'D': D}).to(device)
    Prior = layers.Prior(**config).to(device)
    GE = model.G_E(G, E, Prior)
    utils.load_weights(
        G, None, '', config['weights_root'], model_name,
        config_parser['name_suffix'],
        G if config['ema'] else None,
        E=None if not config['is_encoder'] else E,
        Prior=Prior if not config['prior_type'] == 'default' else None)

    # Sample functions
    sample = functools.partial(utils.sample, G=G, Prior=Prior, config=config)

    # Accumulate stats?
    samples_name = 'samples'
    if config_parser['accumulate_stats']:
        samples_name = 'samples_acc'
        utils.accumulate_standing_stats(G, Prior, config['n_classes'],
                                        config['num_standing_accumulations'])
    if config_parser['name_suffix'] is not None:
        samples_name += config_parser['name_suffix']

    # Sample and save in npz
    config['sample_npz'] = True
    config['sample_num_npz'] = 50000
    G_batch_size = Prior.bs
    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...'
              % config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/%s.npz' % (config['weights_root'],
                                         experiment_name, samples_name)
        print('Num %d' % len(x))
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Reconstruction metrics
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    config_aux = config.copy()
    config_aux['augment'] = False
    dataloader_noaug = utils.get_data_loaders(
        **{**config_aux, 'batch_size': D_batch_size})
    if config_parser['accumulate_stats']:
        utils.accumulate_standing_stats_E(E, Prior, dataloader_noaug, device,
                                          config)
    test_acc, test_acc_iter, error_rec = train_fns.test_accuracy(
        GE, dataloader_noaug, device, config['D_fp16'], config)
    json_metric_name = samples_name + '_json'
    if not os.path.isfile('%s/%s.json' % (model_path, json_metric_name)):
        metric_dict = {'test_acc': test_acc, 'error_rec': error_rec}
        json.dump(metric_dict,
                  open('%s/%s.json' % (model_path, json_metric_name), 'w'))
    else:
        metric_dict = json.load(
            open('%s/%s.json' % (model_path, json_metric_name)))
        metric_dict['inception_mean'] = test_acc
        metric_dict['inception_std'] = error_rec
        json.dump(metric_dict,
                  open('%s/%s.json' % (model_path, json_metric_name), 'w'))
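
# Aside: a minimal, self-contained sketch of the tensor-to-uint8 conversion and
# .npz round trip used above. The scaling assumes generator outputs in [-1, 1];
# the shapes below are illustrative, not taken from the config.
import numpy as np

def to_uint8(images):
    # Map float images in [-1, 1] to uint8 in [0, 255].
    return np.uint8(255 * (images + 1) / 2.)

demo_x = to_uint8(np.random.uniform(-1, 1, size=(4, 3, 32, 32)).astype(np.float32))
demo_y = np.random.randint(0, 10, size=(4,))
np.savez('demo_samples.npz', **{'x': demo_x, 'y': demo_y})
loaded = np.load('demo_samples.npz')
assert loaded['x'].dtype == np.uint8 and loaded['y'].shape == (4,)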
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    ## *** Added: resolution comes from the I128_hdf5 dataset; we may need the C10 dataset here
    config['resolution'] = utils.imsize_dict[config['dataset']]
    ## *** Added: nclass_dict loads the I128_hdf5 class count; we may need C10's 10 classes here
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    ## Load the G/D activations, both ReLU; the key here is lowercase relu, unsure whether it needs a capital R
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    ## Training from scratch with no stored parameters; no change needed, the default is fine
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    ## Log path setup; this shouldn't need changes
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    ## Set the initial RNG seed (all 0); *** needs porting to the Paddle equivalent
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    ## Set up the log root directories; this shouldn't need changes either
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    ## @@@ No change needed here, just comment it out; Paddle may not need this setting
    ## (it speeds up networks with a fixed structure)
    # torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    ## *** !!! Neat trick: dynamically imports the BigGAN model; check the network structure config inside BigGAN
    model = __import__(config['model'])
    ## No change needed; packs a series of config values into the experiment name
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    # Next, build the model
    ## *** Builds the models from the config; two methods need porting
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    # If using EMA, prepare it
    ## *** Off by default; the EMA part can be left unchanged for now
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # FP16?
    ## C10 is fairly small; the G and D parts can stay at default precision for now
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    # Consider automatically reducing SN_eps?
    ## Wrap the configured G and D into the combined G_D model
    GD = model.G_D(G, D)
    ## *** These two prints could probably be dropped; they only show nn.Module's repr
    print(G)
    print(D)
    ## *** parameters() is likewise inherited from torch's nn.Module
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    ## Initialize the bookkeeping state dict; no change needed
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    ## Not using pre-training for now, so this block needs no changes
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # If parallel, parallelize the GD module
    ## Can be ignored for now; GD is not parallelized by default
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    ## Logging hub; can probably be left alone. If needed, see whether the IS and FID results can be extracted
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    ## This one matters: it is used for collecting results.
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    ## *** D's data loading; get_data_loaders uses torchvision's transforms along the way
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']})
    ## Prepare the evaluation metrics; the FID and IS pipelines can use the numpy versions, no change needed
    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])
    ## Prepare the noise and the randomly sampled label arrays
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    ## *** Some torch/numpy usage here needs porting; obtains noise and labels
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    ## *** Some torch/numpy usage here needs porting; obtains noise and labels
    ## TODO: two sets of noise and labels are obtained; what is the purpose?
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    ## *** Sampling comes from the Distribution class; Gaussian and categorical sampling are available
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    ## *** Instantiate the GAN_training_function training procedure
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    ## If no training function is specified, run a dummy pass through the pipeline for debugging
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    ## *** Pre-bind some of utils.sample's arguments, defining the new function sample
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            ## This branch needs no porting
            ## !!! loaders[0] is the data sampling object
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            ## *** train() is inherited from nn.Module
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            ## *** Feed the data and labels into the training function; train itself needs substantial rewriting
            metrics = train(x, y)
            ## Logging: write all the metrics to the log
            train_log.log(itr=int(state_dict['itr']), **metrics)
            # Every sv_log_interval, log singular values
            ## Log the evolution of the singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'),
                                 **utils.get_SVs(D, 'D')})
            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                    end=' ')
            # Save weights and copies as configured at specified interval
            ## By default, record results every 2000 steps
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    ## *** eval() is an nn.Module method
                    G.eval()
                    ## If using the exponential moving average
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            # Test every specified interval
            ## By default, test every 5000 steps
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config,
                               sample, get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
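
# Aside: the sample() wrappers in these scripts rely on functools.partial to
# pre-bind keyword arguments, as the porting notes above point out. The same
# pattern on a stand-in function (scale() is hypothetical, not part of utils):
import functools

def scale(z, gain, config):
    # Multiply every element of z by gain; config is ignored here.
    return [gain * zi for zi in z]

sample_demo = functools.partial(scale, gain=2, config={'dummy': True})
print(sample_demo(z=[1, 2, 3]))  # -> [2, 4, 6]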
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    model = BigGAN
    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    # If using EMA, prepare it (an Exponential Moving Average of G's parameters)
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    GD = model.G_D(G, D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']})
    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict,
                                            config)
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                end=' ')
            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config,
                               sample, get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
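
# Aside: "EMA" above is an exponential moving average of G's parameters, not an
# earth-mover quantity. A minimal sketch of the update rule that utils.ema
# presumably applies (assuming source and target share parameter order):
import torch

@torch.no_grad()
def ema_update(target, source, decay=0.9999):
    # target <- decay * target + (1 - decay) * source, parameter by parameter
    for p_tgt, p_src in zip(target.parameters(), source.parameters()):
        p_tgt.mul_(decay).add_(p_src, alpha=1 - decay)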
def run(config):
    def len_parallelloader(self):
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different
    # files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)
    device = xm.xla_device(devkind='TPU')
    # Next, build the model
    G = model.Generator(**config)
    D = model.Discriminator(**config)
    # If using EMA, prepare it
    if config['ema']:
        xm.master_print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        G_ema, ema = None, None
    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()
    # Prepare state dict, which holds things like itr #
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0,
                  'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # Move everything to the TPU
    G = G.to(device)
    D = D.to(device)
    G.optim = optim.Adam(params=G.parameters(), lr=G.lr, betas=(G.B1, G.B2),
                         weight_decay=0, eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr, betas=(D.B1, D.B2),
                         weight_decay=0, eps=D.adam_eps)
    # for key, val in G.optim.state.items():
    #     G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #     G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)
    # for key, val in D.optim.state.items():
    #     D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #     D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)
    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    if xm.is_master_ordinal():
        # Write metadata
        utils.write_metadata(config['logs_root'], experiment_name, config,
                             state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})
    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'], no_fid=config['no_fid'])
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    def sample():
        return utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                                 device=device, fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throughout
    # training
    fixed_z, fixed_y = sample()
    train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                            config)
    xm.master_print('Beginning training...')
    if xm.is_master_ordinal():
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()
    xm.rendezvous('training_starts')
    while (state_dict['itr'] < config['total_steps']):
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)
        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Increment the progress bar on the master process only
                pbar.update(1)
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set
            # to eval. For D, which typically doesn't have BN, this shouldn't
            # matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            xm.rendezvous('data_collection')
            metrics = train(x, y)
            # train_log.log(itr=int(state_dict['itr']), **metrics)
            # Every sv_log_interval, log singular values
            if ((config['sv_log_interval'] > 0) and
                    (not (state_dict['itr'] % config['sv_log_interval']))):
                if xm.is_master_ordinal():
                    train_log.log(itr=int(state_dict['itr']),
                                  **{**utils.get_SVs(G, 'G'),
                                     **utils.get_SVs(D, 'D')})
                xm.rendezvous('Log SVs.')
            # Save weights and copies as configured at specified interval
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, sample, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):
                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    which_G.eval()

                def G_sample():
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(G, D, G_ema, sample, state_dict, config,
                               G_sample, get_inception_metrics,
                               experiment_name, test_log)
            # Debug: message print
            # if True:
            #     xm.master_print(met.metrics_report())
            if state_dict['itr'] >= config['total_steps']:
                break
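
# Aside: the TPU variant opens by monkey-patching __len__ onto pl.PerDeviceLoader
# so progress bars can query its length. The same trick on a plain class:
class Wrapped:
    def __init__(self, items):
        self._items = items

def wrapped_len(self):
    return len(self._items)

Wrapped.__len__ = wrapped_len  # attach after the class is defined
assert len(Wrapped([1, 2, 3])) == 3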
def run(config):
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    if config['cuda']:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Import the model--this line allows us to dynamically select different files.
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    batch_size = config['batch_size']
    num_class = config['num_class']
    num_domain = config['num_domain']
    dim_z = config['dim_z']
    state_dict = {'epoch': 0, 'iterations': 0, 'best_score': 0,
                  'final_score': 0, 'config': config}
    if config['trainer'] == 'DA_Infer_TAC_Adv':
        trainer = DA_Infer_TAC_Adv(config)
    if config['trainer'] == 'DA_Infer_AC_Adv':
        trainer = DA_Infer_AC_Adv(config)
    trainer.to(device)
    if config['resume']:
        state_dict = utils.resume_state(
            os.path.join(config['weights_root'], experiment_name))
        trainer.resume(os.path.join(config['weights_root'], experiment_name))
        trainer.to(device)
    iterations = state_dict['iterations']
    # Configure the tensorboard writer
    log_folder = os.path.join(config['logs_root'], experiment_name)
    writer = SummaryWriter(log_folder)
    # Load datasets
    train_dataset_specs = {
        'class_name': config['dataset'], 'seed': config['seed'],
        'train': True, 'root': join(config['data_root'], config['dataset']),
        'num_train': config['num_train'], 'num_domain': config['num_domain'],
        'num_class': config['num_class'], 'dim': config['idim'],
        'dim_d': config['dim_d'], 'dag_mat_file': config['dag_mat_file'],
        'useMB': config['useMB'], 'tar_id': config['tar_id'],
        'resolution': config['resolution']}
    train_loader = utils.get_data_loader(train_dataset_specs, batch_size,
                                         config['num_workers'])
    test_dataset_specs = {
        'class_name': config['dataset'], 'seed': config['seed'],
        'train': False, 'root': join(config['data_root'], config['dataset']),
        'num_train': config['num_train'], 'num_domain': config['num_domain'],
        'num_class': config['num_class'], 'dim': config['idim'],
        'dim_d': config['dim_d'], 'dag_mat_file': config['dag_mat_file'],
        'useMB': config['useMB'], 'tar_id': config['tar_id'],
        'resolution': config['resolution']}
    test_loader = utils.get_data_loader(test_dataset_specs, batch_size,
                                        config['num_workers'])
    # Get the sample size in each domain
    do_ss = torch.zeros((num_domain, 1))
    for do_i in range(num_domain):
        do_ss[do_i] = torch.Tensor([
            (train_loader.dataset.labels[:, 1] == do_i).sum()]).item()
    config['do_ss'] = do_ss
    # Fixed noise for illustration
    fixed_noise = torch.randn(100, dim_z, device=device)
    label_gpu = torch.zeros(100, device=device, dtype=torch.int64)
    for i in range(num_class):
        label_gpu[i * num_class:(i + 1) * num_class] = i
    label_fixed = label_gpu.repeat(num_domain)
    y_onehot = torch.nn.functional.one_hot(label_fixed, num_class).float()
    label_fixed = y_onehot
    fixed_noise = fixed_noise.repeat(num_domain, 1)
    # Training
    Diters = config['num_D_steps']
    Giters = config['num_G_steps']
    Diter = 0
    Giter = Giters
    best_score = state_dict['best_score']
    for ep in range(state_dict['epoch'], config['num_epochs']):
        state_dict['epoch'] = ep
        # test_acc_target_c = test_acc(trainer.dis, test_loader, device=device)
        # writer.add_scalar('test_acc_target_ac', test_acc_target_c, ep)
        # # writer.add_scalar('test_acc_target_tac', test_acc_target_ct, ep)
        # if test_acc_target_c > best_score:
        #     best_score = test_acc_target_c
        #     state_dict['best_score'] = best_score
        # if ep == config['num_epochs'] - 1:
        #     state_dict['final_score'] = test_acc_target_c
        for it, (x, y) in enumerate(train_loader):
            if x.size(0) != batch_size:
                continue
            trainer.gen.train()
            trainer.dis.train()
            x = x.to(device)
            y = y.to(device).view(x.size(0), 2)
            if Diter < Diters:
                trainer.dis_update(x, y, config, state_dict, device)
                Diter += 1
                if Diter == Diters:
                    Giter = 0
            if Giter < Giters and Diter > Diters:
                trainer.gen_update(x, y, config, state_dict, device)
                Giter += 1
                if Giter == Giters:
                    Diter = 0
            if Diter == Diters:
                Diter += 1
            # Dump training stats in log file
            if (iterations + 1) % config['save_every'] == 0:
                writer.add_scalar('gan_loss', trainer.gan_loss, iterations)
                writer.add_scalar('aux_loss_c_fake', trainer.aux_loss_c, iterations)
                writer.add_scalar('aux_loss_c_real', trainer.aux_loss_c1, iterations)
                writer.add_scalar('aux_loss_ct', trainer.aux_loss_ct, iterations)
                writer.add_scalar('aux_loss_d_fake', trainer.aux_loss_d, iterations)
                writer.add_scalar('aux_loss_d_real', trainer.aux_loss_d1, iterations)
                writer.add_scalar('aux_loss_dt', trainer.aux_loss_dt, iterations)
                writer.add_scalar('aux_loss_cls_real', trainer.aux_loss_cls1, iterations)
                writer.add_scalar('aux_loss_cls_fake', trainer.aux_loss_cls, iterations)
            if (iterations + 1) % config['display_every'] == 0:
                print("Epoch: %04d, Iteration: %08d, gan loss: %.2f, "
                      "ac_fake: %.2f, ac_real: %.2f, ac_twin: %.2f, "
                      "ad_fake: %.2f, ad_real: %.2f, ad_twin: %.2f, "
                      "cls_real: %.2f, cls_fake: %.2f" %
                      (ep + 1, iterations + 1, trainer.gan_loss,
                       trainer.aux_loss_c, trainer.aux_loss_c1,
                       trainer.aux_loss_ct, trainer.aux_loss_d,
                       trainer.aux_loss_d1, trainer.aux_loss_dt,
                       trainer.aux_loss_cls1, trainer.aux_loss_cls))
            # Save network weights
            if (iterations + 1) % config['save_every'] == 0:
                trainer.save(os.path.join(config['weights_root'],
                                          experiment_name), state_dict)
            iterations += 1
            state_dict['iterations'] = iterations
        # Print variational parameters
        if config['estimate'] == 'Bayesian':
            print(trainer.gen.mu.squeeze())
            print(trainer.gen.sigma.squeeze())
        # Save sample images
        one = torch.ones(100, 1, device=device, dtype=torch.int64)
        label_d = one * 0
        for i in range(1, num_domain):
            label_d = torch.cat((label_d, one * i))
        label_d_onehot = torch.nn.functional.one_hot(label_d.squeeze(),
                                                     num_domain).float()
        trainer.gen.eval()
        with torch.no_grad():
            if config['estimate'] == 'ML':
                fake_img = trainer.gen(fixed_noise, label_fixed,
                                       label_d_onehot)
            elif config['estimate'] == 'Bayesian':
                noise_d = torch.randn(num_domain, config['dim_d']).to(device)
                fake_img, _ = trainer.gen(fixed_noise, label_fixed,
                                          label_d_onehot, noise_d)
        img_name = os.path.join(config['samples_root'],
                                experiment_name + '_gen.jpg')
        torchvision.utils.save_image(fake_img.mul(0.5).add(0.5), img_name,
                                     nrow=10)
        trainer.gen.train()
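
# Aside: the Diter/Giter bookkeeping above alternates num_D_steps discriminator
# updates with num_G_steps generator updates per incoming batch. A stand-alone
# trace of the same counter logic (Diters=3, Giters=1 are illustrative values):
Diters, Giters = 3, 1
Diter, Giter = 0, Giters
schedule = []
for step in range(8):
    if Diter < Diters:
        schedule.append('D')
        Diter += 1
        if Diter == Diters:
            Giter = 0
    if Giter < Giters and Diter > Diters:
        schedule.append('G')
        Giter += 1
        if Giter == Giters:
            Diter = 0
    if Diter == Diters:
        Diter += 1
print(schedule)  # ['D', 'D', 'D', 'G', 'D', 'D', 'D', 'G']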
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    GDE = model.G_D_E(G, D, E)
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]]))
    print('Loading weights...')
    utils.load_weights(
        G, D, E, state_dict, config['weights_root'], experiment_name,
        config['load_weights'] if config['load_weights'] else None,
        None, strict=False, load_optim=False)
    # ==========================================================================
    # Prepare the data
    loaders, train_dataset = utils.get_data_loaders(**config)
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device)
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device)
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)
    evaluate_sample(config, fixed_x, fixed_y, G, E, experiment_name,
                    attack=True)
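
# Aside: z_/y_ and fixed_z/fixed_y from utils.prepare_z_y are buffers carrying an
# in-place .sample_() method. A loose stand-in with the same interface, assuming
# a standard-normal prior for z and a uniform categorical prior for y:
import torch

class SimplePrior:
    def __init__(self, batch, dim_z, n_classes):
        self.z = torch.zeros(batch, dim_z)
        self.y = torch.zeros(batch, dtype=torch.long)
        self.n_classes = n_classes

    def sample_(self):
        self.z.normal_()                   # z ~ N(0, I)
        self.y.random_(0, self.n_classes)  # y ~ Uniform{0, ..., n_classes-1}
        return self.z, self.y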
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)
    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, E, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # If parallel, parallelize the GDE module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{**config, 'batch_size': D_batch_size,
           'start_itr': state_dict['itr']})
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)
    # Build image pool to prevent mode collapse
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'], train_dataset.num_class,
                             save_dir=os.path.join(config['imgbuffer_root'],
                                                   experiment_name),
                             resume_buffer=config['resume_buffer'])
    else:
        img_pool = None
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)
    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning testing at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))
    if config['G_eval_mode']:
        print('Switching G to eval mode...')
        G.eval()
        if config['ema']:
            G_ema.eval()
    fixed_x, fixed_Gz, intermediates = activation_extract(
        G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_, state_dict, config,
        experiment_name, save_weights=config['save_weights'])
    plot_channel_activation(intermediates, config['img_index'],
                            fixed_Gz[config['img_index']], experiment_name,
                            config, state_dict)
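
# Aside: ImagePool above buffers past generator outputs so D also sees stale
# fakes, which the comment credits with preventing mode collapse. A minimal
# history buffer in that spirit (the real ImagePool also tracks classes and can
# persist its buffer to disk):
import random

class TinyImagePool:
    def __init__(self, size):
        self.size, self.images = size, []

    def query(self, image):
        if len(self.images) < self.size:   # fill the buffer first
            self.images.append(image)
            return image
        if random.random() > 0.5:          # half the time, trade with history
            idx = random.randrange(self.size)
            old, self.images[idx] = self.images[idx], image
            return old
        return image                       # otherwise pass through unchanged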
def load_G(config):
    """
    Loads a pre-trained BigGAN generator network
    (exponential moving average variant).
    """
    # config['resolution'] = utils.imsize_dict[config['dataset']]
    config['resolution'] = config['resolution']
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    G = model.Generator(**{**config, 'skip_init': True,
                           'no_optim': True}).to(device)
    print('Loading the model...')
    # FP16? (Note: half/mixed-precision is untested with the direction
    # discovery code)
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
    # print(G)
    print('Number of params in G: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # Load the pre-trained G_ema model as "G"
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            None, None, state_dict, None, None,
            config['load_weights'] if config['load_weights'] else None,
            G, load_optim=False, strict=False, direct_path=config['G_path'])
    G.to(device)
    # Override G's optimizer to only optimize the direction matrix A:
    # freeze all of G's own parameters.
    for param in G.parameters():
        param.requires_grad = False
    G.optim = None
    G.eval()
    return G, state_dict, device, experiment_name
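
# Aside: load_G freezes every generator weight so that only an external
# direction matrix is trained. The same freeze on a toy module, with a check:
import torch.nn as nn

net = nn.Linear(4, 4)
for p in net.parameters():
    p.requires_grad = False
assert all(not p.requires_grad for p in net.parameters())
print(sum(p.requires_grad for p in net.parameters()))  # 0 trainable tensors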
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    utils.load_weights(
        G, D, state_dict,
        '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights',
        'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    # Use: config['abnormal_class']
    # print(config['abnormal_class'])
    # abnormal_class = config['abnormal_class']
    # select_dataset = config['select_dataset']
    # print(config['select_dataset'])
    # print(abnormal_class)
    # print(select_dataset)
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    loaders = utils.get_data_loaders(
        **{**config, 'batch_size': D_batch_size,
           'start_itr': state_dict['itr'],
           'abnormal_class': abnormal_class,
           'select_dataset': select_dataset})
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        # data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the image data set.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            # print(config['seed'])
            if epoch == 0 and i == 0:
                print(config['seed'])
            # We double the learning rate if we double the batch size.
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'),
                                 **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                    end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            experiment_name = (config['experiment_name']
                               if config['experiment_name']
                               else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (
                    epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G,
                    config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (
                    config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths(
                    [data_moments, sample_moments],
                    batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                     experiment_name, test_log)
        state_dict['epoch'] += 1
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, 'last%d' % 0,
                       G_ema if config['ema'] else None)
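
# Aside: fid_score.calculate_fid_given_paths compares Inception moments from the
# two .npz files. The metric itself is the Frechet distance between Gaussians;
# a direct numpy/scipy sketch of it, given precomputed moments (mu, sigma):
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):  # discard tiny imaginary parts from sqrtm
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)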
def run(config):
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    # Next, build the model
    D = model.Discriminator(**config).to(device)
    # FP16?
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    # Consider automatically reducing SN_eps?
    print(D)
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'config': config}
    # If parallel, parallelize the D module
    if config['parallel']:
        D = nn.DataParallel(D)
        if config['cross_replica']:
            patch_replication_callback(D)
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Set up the tensorboard logger
    tb_logdir = '%s/%s/tblogs' % (config['logs_root'], experiment_name)
    if os.path.exists(tb_logdir):
        # Remove previous event logs
        for filename in os.listdir(tb_logdir):
            if filename.startswith('events'):
                os.remove(os.path.join(tb_logdir, filename))
    tb_writer = SummaryWriter(log_dir=tb_logdir)
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']})
    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'MINE':
        train = train_fns.MINE_training_function(D, state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure D is in training mode, just in case it got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            D.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            print(metrics)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            for metric_name in metrics:
                tb_writer.add_scalar('Train/%s' % metric_name,
                                     metrics[metric_name], state_dict['itr'])
            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                    end=' ')
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
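
# Aside: MINE maximizes a Donsker-Varadhan lower bound on mutual information,
# I(X;Y) >= E_p(x,y)[T(x,y)] - log E_p(x)p(y)[exp(T(x,y))]. A sketch of the
# per-batch estimate, where T is any statistics network (hypothetical here; the
# actual objective lives inside train_fns.MINE_training_function):
import torch

def mine_lower_bound(T, x, y):
    joint = T(x, y).mean()                 # expectation under the joint
    y_perm = y[torch.randperm(y.size(0))]  # shuffle y to mimic the product of marginals
    log_mean_exp = torch.logsumexp(T(x, y_perm), dim=0) - torch.log(
        torch.tensor(float(y.size(0))))
    return joint - log_mean_exp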
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None, None, state_dict, config['weights_root'],
                           config['experiment_name'], config['load_weights'],
                           None, strict=False, load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in ['z_var', 'base_root', 'batch_size',
                            'G_batch_size', 'use_ema', 'G_eval_mode']:
                config[item] = state_dict['config'][item]
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    # Seed RNG
    utils.seed_rng(config['seed'])
    # Prepare root folders if necessary
    utils.prepare_root(config)
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True
    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)
    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    utils.count_parameters(G)
    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None, None,
                       state_dict, config['weights_root'], experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False, load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'],
                               z_var=config['z_var'])
    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))
    # Sample function
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...'
              % config['sample_num_npz'])
        for i in trange(int(np.ceil(config['sample_num_npz'] /
                                    float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})
    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_)
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z, fix_y=fix_y, device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        print("labels size", labels)
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/random_samples.jpg' % (config['samples_root'],
                                          experiment_name),
            nrow=int(G_batch_size**0.5),
            normalize=True)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    # print(G)
    # print(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    # (note: this re-initializes the state dict prepared above)
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)
    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    D_fake = D(images[1, :, :, :], labels[0])
    print("D_fake ", D_fake)
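
# Aside: the random-sheet branch above relies on torchvision's grid writer. A
# self-contained sketch of the same call on dummy data in G's output range:
import torch
import torchvision

fake = torch.rand(16, 3, 32, 32) * 2 - 1  # stand-in for G output in [-1, 1]
torchvision.utils.save_image(fake.float(), 'random_samples_demo.jpg',
                             nrow=int(16 ** 0.5), normalize=True)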
def run(config):
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    if config['cuda']:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    batch_size = config['batch_size']
    state_dict = {'epoch': 0, 'iterations': 0, 'best_score': 0,
                  'final_score': 0, 'config': config}

    if config['trainer'] == 'DA_Poolnn':
        trainer = DA_Poolnn(config)
        trainer.to(device)
    else:
        raise ValueError('Unknown trainer: %s' % config['trainer'])

    if config['resume']:
        state_dict = utils.resume_state(os.path.join(config['weights_root'],
                                                     experiment_name))
        trainer.resume(os.path.join(config['weights_root'], experiment_name))
        trainer.to(device)
    iterations = state_dict['iterations']

    # Configure the tensorboard writer
    log_folder = os.path.join(config['logs_root'], experiment_name)
    writer = SummaryWriter(log_folder)

    # Load datasets; train and test share every spec except the 'train' flag.
    dataset_specs = {'class_name': config['dataset'], 'seed': config['seed'],
                     'root': join(config['data_root'], config['dataset']),
                     'num_train': config['num_train'],
                     'num_domain': config['num_domain'],
                     'num_class': config['num_class'], 'dim': config['idim'],
                     'dim_d': config['dim_d'],
                     'dag_mat_file': config['dag_mat_file'],
                     'useMB': config['useMB'], 'tar_id': config['tar_id'],
                     'resolution': config['resolution']}
    train_loader = utils.get_data_loader({**dataset_specs, 'train': True},
                                         batch_size, config['num_workers'])
    test_loader = utils.get_data_loader({**dataset_specs, 'train': False},
                                        batch_size, config['num_workers'])

    # Training
    best_score = state_dict['best_score']
    for ep in range(state_dict['epoch'], config['num_epochs']):
        state_dict['epoch'] = ep
        test_acc_target_c = test_acc(trainer.dis, test_loader, device=device)
        writer.add_scalar('test_acc_target_ac', test_acc_target_c, ep)
        if test_acc_target_c > best_score:
            best_score = test_acc_target_c
            state_dict['best_score'] = best_score
        if ep == config['num_epochs'] - 1:
            state_dict['final_score'] = test_acc_target_c
        for it, (x, y) in enumerate(train_loader):
            # Skip ragged final batches so every update sees a full batch.
            if x.size(0) != batch_size:
                continue
            trainer.dis.train()
            x = x.to(device)
            y = y.to(device).view(x.size(0), 2)
            # trainer.gen_update(x, y, config, device)
            trainer.dis_update(x, y, config, device)

            # Dump training stats in log file
            if (iterations + 1) % config['save_every'] == 0:
                writer.add_scalar('aux_loss_c', trainer.aux_loss_c, iterations)
            if (iterations + 1) % config['display_every'] == 0:
                print('Epoch: %04d, Iteration: %08d, ce loss: %.2f'
                      % (ep + 1, iterations + 1, trainer.aux_loss_c))

            # Save network weights
            if (iterations + 1) % config['save_every'] == 0:
                trainer.save(os.path.join(config['weights_root'],
                                          experiment_name), state_dict)
            iterations += 1
            state_dict['iterations'] = iterations
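
# Minimal sketch of an accuracy helper in the spirit of the test_acc() used
# above (the real one is imported elsewhere in this repo). It assumes, as the
# training loop's y.view(batch, 2) suggests, that y[:, 0] holds the class
# label; the helper name and everything else here are illustrative only.
def sketch_test_acc(classifier, loader, device='cpu'):
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).view(x.size(0), 2)
            pred = classifier(x).argmax(dim=1)
            correct += (pred == y[:, 0]).sum().item()
            total += x.size(0)
    return correct / max(total, 1)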
def run(config):
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    if config['cuda']:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    batch_size = config['batch_size']
    num_class = config['num_class']
    num_domain = config['num_domain']
    dim_z = config['dim_z']
    state_dict = {'epoch': 0, 'iterations': 0, 'best_score': 0,
                  'final_score': 0, 'config': config}

    if config['trainer'] == 'DA_Infer_TAC':
        trainer = DA_Infer_TAC(config)
        trainer.to(device)
    else:
        raise ValueError('Unknown trainer: %s' % config['trainer'])

    if config['resume']:
        state_dict = utils.resume_state(os.path.join(config['weights_root'],
                                                     experiment_name))
        trainer.resume(os.path.join(config['weights_root'], experiment_name))
        trainer.to(device)
    iterations = state_dict['iterations']

    # Configure the tensorboard writer
    log_folder = os.path.join(config['logs_root'], experiment_name)
    writer = SummaryWriter(log_folder)

    # Load datasets; train and test share every spec except the 'train' flag.
    dataset_specs = {'class_name': config['dataset'], 'seed': config['seed'],
                     'root': join(config['data_root'], config['dataset']),
                     'num_train': config['num_train'],
                     'num_domain': config['num_domain'],
                     'num_class': config['num_class'], 'dim': config['idim'],
                     'dim_d': config['dim_d'],
                     'dag_mat_file': config['dag_mat_file'],
                     'useMB': config['useMB'], 'tar_id': config['tar_id']}
    train_loader = utils.get_data_loader({**dataset_specs, 'train': True},
                                         batch_size, config['num_workers'])
    test_loader = utils.get_data_loader({**dataset_specs, 'train': False},
                                        batch_size, config['num_workers'])

    # Compute median pairwise distances to set the kernel widths
    pair_dist = pairwise_distances(train_loader.dataset.data)
    config['base_x'] = np.median(pair_dist)
    if config['is_reg']:
        pair_dist = pairwise_distances(train_loader.dataset.y)
        config['base_y'] = np.median(pair_dist)

    # Get the sample size in each domain (labels[:, 1] holds the domain index)
    do_ss = torch.zeros((num_domain, 1))
    for do_i in range(num_domain):
        do_ss[do_i] = float((train_loader.dataset.labels[:, 1] == do_i).sum())
    config['do_ss'] = do_ss

    # Training: alternate num_D_steps discriminator batches with num_G_steps
    # generator batches, tracked by the Diter/Giter counters below.
    Diters = config['num_D_steps']
    Giters = config['num_G_steps']
    Diter = 0
    Giter = Giters
    best_score = state_dict['best_score']
    for ep in range(state_dict['epoch'], config['num_epochs']):
        state_dict['epoch'] = ep
        test_acc_target_c = test_acc(trainer.dis, test_loader, device=device)
        writer.add_scalar('test_acc_target_ac', test_acc_target_c, ep)
        # writer.add_scalar('test_acc_target_tac', test_acc_target_ct, ep)
        if test_acc_target_c > best_score:
            best_score = test_acc_target_c
            state_dict['best_score'] = best_score
        if ep == config['num_epochs'] - 1:
            state_dict['final_score'] = test_acc_target_c
        for it, (x, y) in enumerate(train_loader):
            # Skip ragged final batches so every update sees a full batch.
            if x.size(0) != batch_size:
                continue
            trainer.gen.train()
            trainer.dis.train()
            x = x.to(device)
            y = y.to(device).view(x.size(0), 2)
            if Diter < Diters:
                trainer.dis_update(x, y, config, device)
                Diter += 1
                if Diter == Diters:
                    Giter = 0
            if Giter < Giters and Diter > Diters:
                trainer.gen_update(x, y, config, device)
                Giter += 1
                if Giter == Giters:
                    Diter = 0
            # Push Diter past Diters so G takes over on the next batch.
            if Diter == Diters:
                Diter += 1

            # Dump training stats in log file
            if (iterations + 1) % config['save_every'] == 0:
                writer.add_scalar('gan_loss', trainer.gan_loss, iterations)
                writer.add_scalar('aux_loss_c_fake', trainer.aux_loss_c, iterations)
                writer.add_scalar('aux_loss_c_real', trainer.aux_loss_c1, iterations)
                writer.add_scalar('aux_loss_ct', trainer.aux_loss_ct, iterations)
                writer.add_scalar('aux_loss_d_fake', trainer.aux_loss_d, iterations)
                writer.add_scalar('aux_loss_d_real', trainer.aux_loss_d1, iterations)
                writer.add_scalar('aux_loss_dt', trainer.aux_loss_dt, iterations)
                writer.add_scalar('aux_loss_cls_real', trainer.aux_loss_cls1, iterations)
                writer.add_scalar('aux_loss_cls_fake', trainer.aux_loss_cls, iterations)
            if (iterations + 1) % config['display_every'] == 0:
                print('Epoch: %04d, Iteration: %08d, gan loss: %.2f, '
                      'ac_fake: %.2f, ac_real: %.2f, ac_twin: %.2f, '
                      'ad_fake: %.2f, ad_real: %.2f, ad_twin: %.2f, '
                      'cls_real: %.2f, cls_fake: %.2f'
                      % (ep + 1, iterations + 1, trainer.gan_loss,
                         trainer.aux_loss_c, trainer.aux_loss_c1,
                         trainer.aux_loss_ct, trainer.aux_loss_d,
                         trainer.aux_loss_d1, trainer.aux_loss_dt,
                         trainer.aux_loss_cls1, trainer.aux_loss_cls))

            # Save network weights
            if (iterations + 1) % config['save_every'] == 0:
                trainer.save(os.path.join(config['weights_root'],
                                          experiment_name), state_dict)
            iterations += 1
            state_dict['iterations'] = iterations

    # Print variational parameters
    if config['estimate'] == 'Bayesian':
        print(trainer.gen.mu.squeeze())
        print(trainer.gen.sigma.squeeze())
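
# Standalone sketch of the D/G alternation implemented with the Diter/Giter
# counters above: D updates for num_D_steps batches, then G for num_G_steps,
# repeating. The counter bookkeeping is slightly non-obvious, so running this
# (pure-Python, no dependencies) prints the schedule for verification.
def sketch_alternation(num_D_steps, num_G_steps, num_batches):
    Diter, Giter, schedule = 0, num_G_steps, []
    for _ in range(num_batches):
        if Diter < num_D_steps:
            schedule.append('D')
            Diter += 1
            if Diter == num_D_steps:
                Giter = 0
        if Giter < num_G_steps and Diter > num_D_steps:
            schedule.append('G')
            Giter += 1
            if Giter == num_G_steps:
                Diter = 0
        if Diter == num_D_steps:
            Diter += 1
    return schedule

# e.g. sketch_alternation(2, 1, 6) -> ['D', 'D', 'G', 'D', 'D', 'G']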
def run(config):
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        # Define both so later references (e.g. save_and_sample) don't fail.
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict, config['weights_root'],
                           experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0],
                displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and \
                    (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key])
                                   for key in metrics]), end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name, test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
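
# Hedged sketch of the bookkeeping utils.ema(G, G_ema, decay, start) presumably
# performs (the real implementation lives in utils.py): after each G step past
# itr >= start, every G_ema parameter is moved toward the live G parameter
# with the configured decay. Class and method names here are illustrative.
class EMASketch(object):
    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr
        # Initialize the EMA copy with the source weights.
        self.target.load_state_dict(self.source.state_dict())

    def update(self, itr):
        # Before start_itr, copy weights directly (decay 0) to warm-start EMA.
        decay = self.decay if itr >= self.start_itr else 0.0
        with torch.no_grad():
            for p_ema, p in zip(self.target.parameters(),
                                self.source.parameters()):
                p_ema.copy_(decay * p_ema + (1. - decay) * p)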
# Checkpoint-selection fragment: `input_weight_folder` is assumed to be defined
# earlier in this script, and glob imported near the top (e.g. from glob import glob).
config['G_activation'] = utils.activation_dict[config['G_nl']]
config['D_activation'] = utils.activation_dict[config['D_nl']]
# By default, skip init if resuming training.
if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
config = utils.update_config_roots(config)
device = 'cuda'

# Seed RNG
utils.seed_rng(config['seed'])

# Prepare root folders if necessary
utils.prepare_root(config)

# Import the model--this line allows us to dynamically select different files.
model = __import__(config['model'])

for torch_fn in glob('%s/*pth' % input_weight_folder):
    if 'optim' in torch_fn:
        continue  # skip optimizer files
    if 'best3' not in torch_fn:
        continue  # only keep the 'best3' checkpoints
    torch_state_dict = torch.load(torch_fn)
    # Next, build the model
    print(torch_fn)
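
# Hedged sketch of how a checkpoint selected by the glob loop above could be
# applied to a freshly built generator. strict=False mirrors the tolerant
# loading used elsewhere in this codebase; the function name is illustrative.
def apply_checkpoint(model_module, config, torch_state_dict, device='cuda'):
    G_ckpt = model_module.Generator(**{**config, 'skip_init': True,
                                       'no_optim': True}).to(device)
    # Ignore missing/unexpected keys, e.g. EMA-only or optimizer entries.
    G_ckpt.load_state_dict(torch_state_dict, strict=False)
    return G_ckpt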
def run(config):
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = 64
    config['n_classes'] = 120
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else 'generative_dog_images')
    print('Experiment name is %s' % experiment_name)

    G = BigGAN.Generator(**config).to(device)
    D = BigGAN.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = BigGAN.Generator(**{**config, 'skip_init': True,
                                    'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    GD = BigGAN.G_D(G, D)
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict, config['weights_root'],
                           experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders = dataset.get_data_loaders(data_root=config['data_root'],
                                       label_root=config['label_root'],
                                       batch_size=D_batch_size,
                                       num_workers=config['num_workers'],
                                       shuffle=config['shuffle'],
                                       pin_memory=config['pin_memory'],
                                       drop_last=True)

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # Loaders are loaded, prepare the training function
    train = train_fns.create_train_fn(G, D, GD, z_, y_, ema, state_dict, config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    start_time = time.perf_counter()
    total_iters = config['num_epochs'] * len(loaders[0])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        for i, (x, y) in enumerate(loaders[0]):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            x, y = x.to(device), y.to(device)
            metrics = train(x, y)

            if not (state_dict['itr'] % config['log_interval']):
                curr_time = time.perf_counter()
                # perf_counter() is not epoch-based, so take the wall-clock
                # timestamp from datetime.now() rather than fromtimestamp().
                curr_time_str = datetime.datetime.now().strftime('%H:%M:%S')
                elapsed = str(datetime.timedelta(seconds=(curr_time - start_time)))
                log = ('[{}] [{}] [{} / {}] Ep {}, '.format(
                           curr_time_str, elapsed, state_dict['itr'],
                           total_iters, epoch)
                       + ', '.join(['%s : %+4.3f' % (key, metrics[key])
                                    for key in metrics]))
                print(log)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    # if config['ema']:
                    #     G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
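
# Small illustrative helper: given the elapsed seconds and the iteration
# progress logged above ([itr / total_iters]), estimate remaining time.
# Purely a sketch; nothing in the training loop depends on it.
def eta_seconds(elapsed, itr, total_iters):
    # Guard against division by zero on the first logged iteration.
    return elapsed * (total_iters - itr) / max(itr, 1)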
def run(config):
    # Update the config dict as necessary.
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    experiment_name = 'test_{}'.format(experiment_name)
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)
    print('Model Created!')
    print('Number of params in G: {} D: {} E: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D, E]]))

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    print('Loading weights...')
    utils.load_weights(G, D, E, state_dict, config['weights_root'],
                       config['load_experiment_name'],
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

    # If parallel, parallelize the GDE module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{**config, 'batch_size': D_batch_size,
           'start_itr': state_dict['itr']})

    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print('fixed_y original: {} {}'.format(fixed_y.shape, fixed_y[:10]))
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_, y_=y_, config=config)

    G.eval()
    E.eval()
    print('Extracting activations...')
    activation_extract(G, D, E, G_ema, fixed_x, fixed_y_of_x, z_, y_,
                       state_dict, config, experiment_name, save_weights=False)
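
# Hedged sketch of what an activation-extraction pass over the fixed inputs
# might look like. activation_extract() itself is defined elsewhere in this
# repo; this is only a minimal illustration using the encoder E built above,
# and the function name here is hypothetical.
def sketch_extract_activations(E, fixed_x, batch_size=64):
    E.eval()
    feats = []
    with torch.no_grad():
        # Walk the fixed inputs in chunks so large batches fit in memory.
        for i in range(0, fixed_x.size(0), batch_size):
            feats.append(E(fixed_x[i:i + batch_size]).cpu())
    return torch.cat(feats, 0)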