def main(): # parse command line and run torch.autograd.set_detect_anomaly(True) parser = utils.prepare_parser() config = vars(parser.parse_args()) print(config) run(config)
def get_config(args): parser = utils.prepare_parser() parser = utils.add_sample_parser(parser) config = vars(parser.parse_args()) resolution = utils.imsize_dict[args.dataset] attn_dict = {128: '64', 256: '128', 512: '64'} dim_z_dict = {128: 120, 256: 140, 512: 128} # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh. config["resolution"] = resolution config["n_classes"] = utils.nclass_dict[args.dataset] config["G_activation"] = utils.activation_dict["inplace_relu"] config["D_activation"] = utils.activation_dict["inplace_relu"] config["G_attn"] = attn_dict[resolution] config["D_attn"] = "64" config["G_ch"] = 96 config["D_ch"] = 96 config["hier"] = True config["dim_z"] = dim_z_dict[resolution] config["shared_dim"] = 128 config["G_shared"] = True config = utils.update_config_roots(config) config["skip_init"] = True config["no_optim"] = True config["device"] = "cuda" return config
def main(): parser = utils.prepare_parser() parser = utils.add_dgp_parser(parser) config = vars(parser.parse_args()) utils.dgp_update_config(config) #Print parameters cat = [ 'model', 'dgp_mode', 'list_file', 'exp_path', 'root_dir', 'resolution', 'random_G', 'update_G', 'custom_mask' ] for key, val in config.items(): if key in cat: print(key, ":", str(val)) if config['custom_mask']: config['mask_path'] = '../data/input/' + config['mask_path'] print('mask_path :', config['mask_path']) rank = 0 if mp.get_start_method(allow_none=True) != 'spawn': mp.set_start_method('spawn', force=True) if config['dist']: rank, world_size = dist_init(config['port']) # Seed RNG utils.seed_rng(rank + config['seed']) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # train trainer = Trainer(config) trainer.run()
def main(): # parse command line and run parser = utils.prepare_parser() parser = utils.add_sample_parser(parser) config = vars(parser.parse_args()) print(config) run(config)
def _mp_fn(index): torch.set_default_tensor_type('torch.FloatTensor') # parse command line and run parser = utils.prepare_parser() config = vars(parser.parse_args()) print(config) run(config)
def main(): # parse command line and run parser = utils.prepare_parser() config = vars(parser.parse_args()) print(config) args = argparse.Namespace(outdir='results/tmp') myargs = argparse.Namespace(stdout=sys.stdout) run(config, args, myargs)
def main(): # parse command line and run parser = utils.prepare_parser() update_parser_defaults_from_yaml(parser) config = vars(parser.parse_args()) print(config) run(config)
def get_config(): parser = utils.prepare_parser() config = vars(parser.parse_args()) config['resolution'] = utils.imsize_dict[config['dataset']] config['n_classes'] = utils.nclass_dict[config['dataset']] config["experiment_name"] = input("Please input experiment_name: ") config["model_parameter_name"] = find_best_model_parameter( config["experiment_name"]) return config
def main(): parser = utils.prepare_parser() config = vars(parser.parse_args()) time = datetime.now() time = time.strftime("%m-%d-%Y-%H:%M:%S") # state_dict = {'itr': 0, 'epoch': 0, 'best_epoch': 0, 'best_test_loss': 9999999, # 'config': config} if config['exp_name'] == '' and not config['resume']: config['exp_name'] = '{}_model_{}_bs_{}_lr_{}'.format(time, config['model'], config['batch_size'], config['lr']) print('======>new exp name: {}\n\n'.format(config['exp_name'])) elif (not len(config['exp_name']) == 0): pass else: print(config['exp_name']) print(config['resume']) raise Exception('Set up experiment wrong!') if config['model'] =='ours': for _ in range(3): plot_dict={} plot_dict['roc_auc'] = [] plot_dict['auprc'] = [] for feat_mult in [300,320,384,512,768]: config['feature_multiplier'] = feat_mult with open("{}{}_results_log.txt".format(config['ckpts_path'], config['exp_name']), "a+") as file: file.write('\n\n\n\n\n\n\n\n======>>>>>>{}\n\n'.format(config)) file.write('No Maxpool1d\n') print(config) print('\n\n') plot_dict = run(config, plot_dict) with open('{}{}_{}_with_Maxpool_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _), 'a+') as file: file.write(json.dumps(plot_dict)) file.write('\n\n\n') #json.dump(plot_dict, open('{}{}_{}_json_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _), 'a+')) elif config['model'] == 'danq': for _ in range(3): plot_dict = {} plot_dict['roc_auc'] = [] plot_dict['auprc'] = [] for feat_mult in [8,16,48,96,128,192,256,300,320]: #[8, 16, 48, 96, 128, 192, 256, 300, 320, 360, 420]: config['feature_multiplier'] = feat_mult with open("{}{}_results_log.txt".format(config['ckpts_path'], config['exp_name']), "a+") as file: file.write('\n\n\n\n======>>>>>>{}\n\n'.format(config)) file.write('With danq\n') print('\n\n') print(config) plot_dict = run(config, plot_dict) with open('{}{}_{}_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _), 'a+') as file: file.write(json.dumps(plot_dict)) file.write('\n\n\n') else: raise Exception('choose the correct model..')
def main(): # parse command line and run parser = utils.prepare_parser() config = vars(parser.parse_args()) print(config) path_torch_home = os.path.join(config['root_path'], 'torch_cache') os.makedirs(path_torch_home, exist_ok=True) os.environ['TORCH_HOME'] = path_torch_home run(config)
def main(): # parse command line and run parser = utils.prepare_parser() parser = utils.add_sample_parser(parser) config = vars(parser.parse_args()) # print("reached_main") # for item in config: # print(item, config[item]) run(config)
def main(): # parse command line and run parser = utils.prepare_parser() update_parser_defaults_from_yaml(parser) args = parser.parse_args() args.base_root = os.path.join(args.tl_outdir, 'biggan') config = EasyDict(vars(args)) config_str = get_dict_str(config) logger = logging.getLogger('tl') logger.info(config_str) run(config)
def main(): # parse command line and run parser = utils.prepare_parser() parser = utils.add_sample_parser(parser) config = vars(parser.parse_args()) print(config) if config['sample_multiple']: suffixes = config['load_weights'].split(',') for suffix in suffixes: config['load_weights'] = suffix run(config) else: run(config)
def main(): logger = logging.getLogger('tl') # parse command line parser = utils.prepare_parser() update_parser_defaults_from_yaml(parser) args = parser.parse_args() args.base_root = os.path.join(args.tl_outdir, 'biggan') opt = EasyDict(vars(args)) logger.info(f"\nglobal_cfg: \n" + get_dict_str(global_cfg)) global_cfg.dump_to_file_with_command( f"{opt.tl_outdir}/config_command.yaml", command=opt.tl_command) run(opt)
def main(): parser = utils.prepare_parser() parser = utils.add_dgp_parser(parser) config = vars(parser.parse_args()) utils.dgp_update_config(config) print(config) rank = 0 if mp.get_start_method(allow_none=True) != 'spawn': mp.set_start_method('spawn', force=True) if config['dist']: rank, world_size = dist_init(config['port']) # Seed RNG utils.seed_rng(rank + config['seed']) # Setup cudnn.benchmark for free speed torch.backends.cudnn.benchmark = True # train trainer = Trainer(config) trainer.run()
weight_decay=0.0, d_penalty=config['d_penalty'], g_penalty=0, noise_shape=(64, config['z_dim']), gp_weight=config['gp_weight']) # d_penalty: L2 penalty on discriminator; # gp_weight: gradient penalty weight # trainer.load_checkpoint(chkpt_path='') trainer.train_bcgd(is_flag=config['eval_is'], fid_flag=config['eval_fid'], epoch_num=config['epoch_num'], mode='ACGD', collect_info=config['collect_info'], dataname=config['dataset'], logname=config['logdir'], loss_type=config['loss_type']) # Loss type: JSD, WGAN # trainer.train_bcgd(epoch_num=120, mode='ACGD', collect_info=True, dataname='CIFAR10-WGAN', logname='CIFAR10-WGAN', loss_type='WGAN') # trainer.train_gd(epoch_num=600, mode=modes[3], dataname='CIFAR10-JSD', logname='CIFAR10-JSD', loss_type='JSD') if __name__ == '__main__': torch.backends.cudnn.benchmark = True parser = prepare_parser() config = vars(parser.parse_args()) print(config) train_wgan(config) # train_mnist() # train_cifar()
def main(): # parse command line and run parser = utils.prepare_parser() config = vars(parser.parse_args()) #print(config) run(config)
def main(): parser = utils.prepare_parser() config = vars(parser.parse_args()) print(config) run(config)
from common import * import utils import dataset # os.makedirs('sample_dc_gan', exist_ok=True) # Set random seed for reproducibility manualSeed = 999 #manualSeed = random.randint(1, 10000) # use if you want new results print("Random Seed: ", manualSeed) random.seed(manualSeed) torch.manual_seed(manualSeed) parser = utils.prepare_parser() config = vars(parser.parse_args()) loaders = dataset.get_data_loaders( data_root=config['data_root'], label_root=config['label_root'], batch_size=config['batch_size'], num_workers=config['num_workers'], shuffle=config['shuffle'], pin_memory=config['pin_memory'], drop_last=True, load_in_mem=config['load_in_mem'], mask_out=True, ) image_size = IMG_SIZE nc = 3
def main(): parser = utils.prepare_parser() config = vars(parser.parse_args()) # See the configuration. #print(config) run(config)
'alexnet', # 224x224 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', # 224x224 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', # 224x224 'squeezenet_v0', 'squeezenet_v1', #224x224 'inception_v3', # 299x299 ] args = utils.prepare_parser() args.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu) args.ngpu = len(args.gpu) misc.ensure_dir(args.logdir) # ensure or create logdir args.model_root = misc.expand_user(args.model_root) args.data_root = misc.expand_user(args.data_root) args.input_size = 299 if 'inception' in args.type else args.input_size assert args.quant_method in ['linear', 'minmax', 'log', 'tanh', 'scale'] print("=================PARSER==================") for k, v in args.__dict__.items(): print('{}: {}'.format(k, v)) print("========================================")
def main(): # parse command line and run # 外部运行主程序入口时候可以输入参数,控制计算模式 parser = utils.prepare_parser() ## vars() 函数返回对象object的属性和属性值的字典对象。 ## 将parser字典化 ## 数据集dataset 默认为 I128_hdf5 ## 是否采用数据增强augment 默认为 0 ## num_workers 加速数据读取的加载,默认为8, 应小于HDF5 TODO 待查为啥要小于某个数据格式 ## no_pin_memory 作者也不确定,暂时不管,默认为FALSE TODO ## shuffle 数据还是需要shuffle,万一数据有聚集BATCH的情况,模型就疯了,默认为True ## load_in_mem 作者也不确定,估计是用来加速数据加载速度的 TODO ## use_multiepoch_sampler 使用多回合采样方法,默认为True,TODO multi-epoch sampler 这个东西一定要研究一下,采样相当重要,这个不知道是啥 ## model 默认使用 BigGAN ## G_param 生成模型使用图像归一方法,默认使用谱归一化,还可以选择SVD或者None,TODO 利用代码加深一下SN和SVD的区别 ## D_param 判别模型使用图像归一方法,默认使用谱归一化,还可以选择SVD或者None,TODO 利用代码加深一下SN和SVD的区别 ## resolution 表示结构代码中的通道数,默认128可解析为如下结构: # G_ch 表示结构代码中的通道数,即ch的值 # arch[128] = {'in_channels' : [ch * item for item in [16, 16, 8, 4, 2]], # 'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]], # 'upsample' : [True] * 5, # 'resolution' : [8, 16, 32, 64, 128], # 'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')]) # for i in range(3,8)}} ## D_ch 辨别模型的信道 默认64 同上 ## G_depth 每个阶段G的resblocks的数量 TODO 盲猜和ResNet结构有关系 ## D_depth 每个阶段D的resblocks的数量 ## TODO D_thin和D_wide是干啥的?默认都是FALSE ## G_shared TODO Use shared embeddings in G 这个需要细看一下是怎么操作的 ## shared_dim TODO 'G''s shared embedding dimensionality; if 0, will be equal to dim_z. ## dim_z 噪声的维度,默认为128 ## z_var 噪声的标准差,premiere为1 ## hier 这个是使用多层噪声 # 以arch[64]为例子, 插入的层数 num_slots 为 [16, 16, 8, 4] 的len + 1 # 每一块插入的尺寸 z_chunk_size 为 dim_z // num_slots,即把噪声的维度 与 插入层数 取模 # z_chunk_size是直接参与前向传播 forward 和 which_bn 应该是决定哪一层使用BatchNormalization # 重塑噪声维度,让其在每一层插入的维度是相同的 dim_z = z_chunk_size * num_slots #~ 如果不使用多层噪声,那么就把 num_slots 的数量置为1 # 同时把 z_chunk_size 置于 0 ## cross_replica TODO 这个是啥?把G模型的batchnorm再复制一遍吗 ## 使用默认的归一化方法 mybn,看一下为啥 ## G_nl&D_nl的激活函数 ## G_attn和D_attn是否使用attention机制 ## norm_style 使用归一化方法,CNN还是使用BN比较好,四者之间的区别可以百度 ## seed 随机数种子,默认为0,控制初始化参数和数据读取的 ## G_init 生成模型初始化方法 TODO !!! 要知晓一下ortho是什么类型的初始化方法 使初始化的参数矩阵是正交规范化的 # https://zhuanlan.zhihu.com/p/98873800 # 正交规范化的矩阵能够缓解梯度消失和爆炸,具体看BigGAN中的援引文献, # 原为L1范数,BigGAN中使用L2范数进行改进 ## D_init 判别模型初始化方法 ## skip_init 跳过初始化 TODO 为什么跳过初始化,就是ideal的testing? ## G_lr 生成模型的学习率 ## D_lr 辨别模型的学习率 ## G_B1 & D_B1 beta1 % G_B2 & D_B2 beta2 TODO 属于GAN中的什么类型的参数 ## batch_size 批大小 64 ## G_batch_size TODO 为何还要写G的batch size?? ## num_G_accumulations TODO 把G的梯度相加又是为了什么?? ## num_D_steps TODO 每一步G需要跑几次D,D需要多训练训练,如果G跑太前面就太逼近原图了? ## num_D_accumulations TODO 把D的梯度相加又是为了什么?? ## split_D 运行D两次,而不是连接输入 TODO ? ## num_epochs @@ 训练的轮次数量 ## parallel 默认是FALSE 训练时是否使用多GPUs一起 ## G_fp16 & D_fp16 & G_mixed_precision & D_mixed_precision 对于精度的使用,一个是加快计算,一个减小模型大小.有的时候会带一点准确率的损失,所以mix表示预测的时候用更高的精度 ## accumulate_stats 默认FALSE TODO 积累统计数据 ## num_standing_accumulations TODO 这个和楼上那个什么关系?!!! #@@ 接下来是运行模式 ## G_eval_mode TODO 每次采样或者测试的时候,都评估一下G? ## save_every 每2000迭代储存一下模型 ## num_save_copies TODO 储几个备份?默认2个 ## num_best_copies 存几个最好的版本,默认2个 ## which_best 依据IS或者FID作为评价指标去储存模型 ## no_fid 是否只计算IS,不计算FID。默认都计算 ## test_every 没多少论测试一下,默认5000 ## num_inception_images TODO 用于计算初始度量的样本数量50000 ## hashname 是否使用HasName而不是配置中的类别,默认不使用 ## base_root 默认储存所有权重,采样,数据,日志 ## data_root TODO !!!数据默认存在哪里 ## weights_root,logs_root,samples_root 本地存在哪儿 ## pbar 是否使用进度条,默认使用mine ## name_suffix 命名添加后缀 ## experiment_name 自定义存储实验名称 ## config_from_name 是否使用hash实验名 #@@ EMA 对权重进行指数滑动平均处理的工具群 ## ema 是否保存G的ema参数 ## ema_decay EMA的衰减率 ## use_ema 是否在G中使用评估 ## ema_start 什么时候去用EMA方法更新权重 #@@ SV stuff 奇异值的迭代性质 ## adam_eps adam的epsilon value,TODO 看一下它在干嘛 ## BN_eps & SN_eps 批归一化和谱归一化的参数 ## num_G_SVs G的奇异值追踪数量,默认为1 ## num_D_SVs D的奇异值追踪数量,默认为1 ## num_G_SV_itrs G的迭代次数,默认为1 ## num_D_SV_itrs D的迭代次数,默认为1 #@@ Ortho reg stuff 这个又是什么 ## 控制Ortho的迭代性质 ## which_train_fn 默认是GAN ## load_weights 加载哪类参数,是copy的或者历史最好的etc. ## resume TODO Resume training是从历史记录开始训练,可以用于预训练 ## Log stuff 日志系统 ## logstyle 日志类型 ## log_G_spectra 记录G的头3个奇异值 ## log_D_spectra 记录D的头3个奇异值 ## sv_log_interval 每个多少论记录一次 #@@ Arguments for sample.py; not presently used in train.py ## add_sample_parser 增加一些采样的参数 config = vars(parser.parse_args()) print(config) ## 将运行字典传入GAN模型主程序 run(config)