Exemplo n.º 1
0
def main():
    # parse command line and run
    torch.autograd.set_detect_anomaly(True)
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    run(config)
Exemplo n.º 2
0
def get_config(args):
    parser = utils.prepare_parser()
    parser = utils.add_sample_parser(parser)
    config = vars(parser.parse_args())

    resolution = utils.imsize_dict[args.dataset]
    attn_dict = {128: '64', 256: '128', 512: '64'}
    dim_z_dict = {128: 120, 256: 140, 512: 128}

    # See: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/scripts/sample_BigGAN_bs256x8.sh.
    config["resolution"] = resolution
    config["n_classes"] = utils.nclass_dict[args.dataset]
    config["G_activation"] = utils.activation_dict["inplace_relu"]
    config["D_activation"] = utils.activation_dict["inplace_relu"]
    config["G_attn"] = attn_dict[resolution]
    config["D_attn"] = "64"
    config["G_ch"] = 96
    config["D_ch"] = 96
    config["hier"] = True
    config["dim_z"] = dim_z_dict[resolution]
    config["shared_dim"] = 128
    config["G_shared"] = True
    config = utils.update_config_roots(config)
    config["skip_init"] = True
    config["no_optim"] = True
    config["device"] = "cuda"
    return config
Exemplo n.º 3
0
def main():
    parser = utils.prepare_parser()
    parser = utils.add_dgp_parser(parser)

    config = vars(parser.parse_args())
    utils.dgp_update_config(config)

    #Print parameters
    cat = [
        'model', 'dgp_mode', 'list_file', 'exp_path', 'root_dir', 'resolution',
        'random_G', 'update_G', 'custom_mask'
    ]
    for key, val in config.items():
        if key in cat:
            print(key, ":", str(val))

    if config['custom_mask']:
        config['mask_path'] = '../data/input/' + config['mask_path']
        print('mask_path :', config['mask_path'])

    rank = 0
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    if config['dist']:
        rank, world_size = dist_init(config['port'])

    # Seed RNG
    utils.seed_rng(rank + config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # train
    trainer = Trainer(config)
    trainer.run()
Exemplo n.º 4
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()
    parser = utils.add_sample_parser(parser)
    config = vars(parser.parse_args())
    print(config)
    run(config)
Exemplo n.º 5
0
def _mp_fn(index):
    torch.set_default_tensor_type('torch.FloatTensor')
    # parse command line and run
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    run(config)
Exemplo n.º 6
0
def main():
  # parse command line and run
  parser = utils.prepare_parser()
  config = vars(parser.parse_args())
  print(config)
  args = argparse.Namespace(outdir='results/tmp')
  myargs = argparse.Namespace(stdout=sys.stdout)
  run(config, args, myargs)
Exemplo n.º 7
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()

    update_parser_defaults_from_yaml(parser)

    config = vars(parser.parse_args())
    print(config)
    run(config)
Exemplo n.º 8
0
def get_config():
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config["experiment_name"] = input("Please input experiment_name: ")
    config["model_parameter_name"] = find_best_model_parameter(
        config["experiment_name"])
    return config
Exemplo n.º 9
0
def main():

    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    time = datetime.now()
    time = time.strftime("%m-%d-%Y-%H:%M:%S")
    # state_dict = {'itr': 0, 'epoch': 0, 'best_epoch': 0, 'best_test_loss': 9999999,
    #               'config': config}

    if config['exp_name'] == '' and not config['resume']:
        config['exp_name'] = '{}_model_{}_bs_{}_lr_{}'.format(time, config['model'], config['batch_size'], config['lr'])
        print('======>new exp name: {}\n\n'.format(config['exp_name']))
    elif (not len(config['exp_name']) == 0):
        pass
    else:
        print(config['exp_name'])
        print(config['resume'])
        raise Exception('Set up experiment wrong!')

    if config['model'] =='ours':
        for _ in range(3):
            plot_dict={}
            plot_dict['roc_auc'] = []
            plot_dict['auprc'] = []
            for feat_mult in [300,320,384,512,768]:
                config['feature_multiplier'] = feat_mult
                with open("{}{}_results_log.txt".format(config['ckpts_path'], config['exp_name']), "a+") as file:
                    file.write('\n\n\n\n\n\n\n\n======>>>>>>{}\n\n'.format(config))
                    file.write('No Maxpool1d\n')

                print(config)
                print('\n\n')
                plot_dict = run(config, plot_dict)
            with open('{}{}_{}_with_Maxpool_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _), 'a+') as file:
                file.write(json.dumps(plot_dict))
                file.write('\n\n\n')

            #json.dump(plot_dict, open('{}{}_{}_json_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _), 'a+'))
    elif config['model'] == 'danq':
        for _ in range(3):
            plot_dict = {}
            plot_dict['roc_auc'] = []
            plot_dict['auprc'] = []
            for feat_mult in [8,16,48,96,128,192,256,300,320]: #[8, 16, 48, 96, 128, 192, 256, 300, 320, 360, 420]:
                config['feature_multiplier'] = feat_mult
                with open("{}{}_results_log.txt".format(config['ckpts_path'], config['exp_name']), "a+") as file:
                    file.write('\n\n\n\n======>>>>>>{}\n\n'.format(config))
                    file.write('With danq\n')
                print('\n\n')
                print(config)
                plot_dict = run(config, plot_dict)
            with open('{}{}_{}_data_for_plot.txt'.format(config['ckpts_path'], config['exp_name'], _),
                      'a+') as file:
                file.write(json.dumps(plot_dict))
                file.write('\n\n\n')
    else:
        raise Exception('choose the correct model..')
Exemplo n.º 10
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    path_torch_home = os.path.join(config['root_path'], 'torch_cache')
    os.makedirs(path_torch_home, exist_ok=True)
    os.environ['TORCH_HOME'] = path_torch_home
    run(config)
Exemplo n.º 11
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()
    parser = utils.add_sample_parser(parser)
    config = vars(parser.parse_args())
    # print("reached_main")
    # for item in config:
    #	print(item, config[item])
    run(config)
Exemplo n.º 12
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()

    update_parser_defaults_from_yaml(parser)

    args = parser.parse_args()
    args.base_root = os.path.join(args.tl_outdir, 'biggan')
    config = EasyDict(vars(args))
    config_str = get_dict_str(config)
    logger = logging.getLogger('tl')
    logger.info(config_str)
    run(config)
Exemplo n.º 13
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()
    parser = utils.add_sample_parser(parser)
    config = vars(parser.parse_args())
    print(config)
    if config['sample_multiple']:
        suffixes = config['load_weights'].split(',')
        for suffix in suffixes:
            config['load_weights'] = suffix
            run(config)
    else:
        run(config)
Exemplo n.º 14
0
def main():
    logger = logging.getLogger('tl')
    # parse command line
    parser = utils.prepare_parser()
    update_parser_defaults_from_yaml(parser)

    args = parser.parse_args()
    args.base_root = os.path.join(args.tl_outdir, 'biggan')
    opt = EasyDict(vars(args))

    logger.info(f"\nglobal_cfg: \n" + get_dict_str(global_cfg))
    global_cfg.dump_to_file_with_command(
        f"{opt.tl_outdir}/config_command.yaml", command=opt.tl_command)

    run(opt)
Exemplo n.º 15
0
def main():
    parser = utils.prepare_parser()
    parser = utils.add_dgp_parser(parser)
    config = vars(parser.parse_args())
    utils.dgp_update_config(config)
    print(config)

    rank = 0
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    if config['dist']:
        rank, world_size = dist_init(config['port'])

    # Seed RNG
    utils.seed_rng(rank + config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # train
    trainer = Trainer(config)
    trainer.run()
                         weight_decay=0.0,
                         d_penalty=config['d_penalty'],
                         g_penalty=0,
                         noise_shape=(64, config['z_dim']),
                         gp_weight=config['gp_weight'])
    # d_penalty: L2 penalty on discriminator;
    # gp_weight: gradient penalty weight
    # trainer.load_checkpoint(chkpt_path='')
    trainer.train_bcgd(is_flag=config['eval_is'],
                       fid_flag=config['eval_fid'],
                       epoch_num=config['epoch_num'],
                       mode='ACGD',
                       collect_info=config['collect_info'],
                       dataname=config['dataset'],
                       logname=config['logdir'],
                       loss_type=config['loss_type'])
    # Loss type: JSD, WGAN
    # trainer.train_bcgd(epoch_num=120, mode='ACGD', collect_info=True, dataname='CIFAR10-WGAN', logname='CIFAR10-WGAN', loss_type='WGAN')

    # trainer.train_gd(epoch_num=600, mode=modes[3], dataname='CIFAR10-JSD', logname='CIFAR10-JSD', loss_type='JSD')


if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    parser = prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    train_wgan(config)
    # train_mnist()
    # train_cifar()
Exemplo n.º 17
0
def main():
    # parse command line and run
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    #print(config)
    run(config)
Exemplo n.º 18
0
def main():
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    print(config)
    run(config)
Exemplo n.º 19
0
from common import *
import utils
import dataset
#
os.makedirs('sample_dc_gan', exist_ok=True)
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

parser = utils.prepare_parser()
config = vars(parser.parse_args())

loaders = dataset.get_data_loaders(
    data_root=config['data_root'],
    label_root=config['label_root'],
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    shuffle=config['shuffle'],
    pin_memory=config['pin_memory'],
    drop_last=True,
    load_in_mem=config['load_in_mem'],
    mask_out=True,
)

image_size = IMG_SIZE

nc = 3
Exemplo n.º 20
0
def main():
    parser = utils.prepare_parser()
    config = vars(parser.parse_args())
    # See the configuration.
    #print(config)
    run(config)
Exemplo n.º 21
0
    'alexnet',  # 224x224
    'vgg16',
    'vgg16_bn',
    'vgg19',
    'vgg19_bn',  # 224x224
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',  # 224x224
    'squeezenet_v0',
    'squeezenet_v1',  #224x224
    'inception_v3',  # 299x299
]

args = utils.prepare_parser()

args.gpu = misc.auto_select_gpu(utility_bound=0,
                                num_gpu=args.ngpu,
                                selected_gpus=args.gpu)
args.ngpu = len(args.gpu)
misc.ensure_dir(args.logdir)  # ensure or create logdir
args.model_root = misc.expand_user(args.model_root)
args.data_root = misc.expand_user(args.data_root)
args.input_size = 299 if 'inception' in args.type else args.input_size
assert args.quant_method in ['linear', 'minmax', 'log', 'tanh', 'scale']
print("=================PARSER==================")
for k, v in args.__dict__.items():
    print('{}: {}'.format(k, v))
print("========================================")
Exemplo n.º 22
0
def main():
    # parse command line and run
    # 外部运行主程序入口时候可以输入参数,控制计算模式
    parser = utils.prepare_parser()
    ## vars() 函数返回对象object的属性和属性值的字典对象。
    ## 将parser字典化
    ## 数据集dataset 默认为 I128_hdf5
    ## 是否采用数据增强augment 默认为 0
    ## num_workers 加速数据读取的加载,默认为8, 应小于HDF5 TODO 待查为啥要小于某个数据格式
    ## no_pin_memory 作者也不确定,暂时不管,默认为FALSE TODO
    ## shuffle 数据还是需要shuffle,万一数据有聚集BATCH的情况,模型就疯了,默认为True
    ## load_in_mem 作者也不确定,估计是用来加速数据加载速度的 TODO
    ## use_multiepoch_sampler 使用多回合采样方法,默认为True,TODO multi-epoch sampler 这个东西一定要研究一下,采样相当重要,这个不知道是啥
    ## model 默认使用 BigGAN
    ## G_param 生成模型使用图像归一方法,默认使用谱归一化,还可以选择SVD或者None,TODO 利用代码加深一下SN和SVD的区别
    ## D_param 判别模型使用图像归一方法,默认使用谱归一化,还可以选择SVD或者None,TODO 利用代码加深一下SN和SVD的区别

    ## resolution 表示结构代码中的通道数,默认128可解析为如下结构:
    # G_ch 表示结构代码中的通道数,即ch的值
    # arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
    #             'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
    #             'upsample' : [True] * 5,
    #             'resolution' : [8, 16, 32, 64, 128],
    #             'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
    #                             for i in range(3,8)}}
    ## D_ch 辨别模型的信道 默认64 同上

    ## G_depth 每个阶段G的resblocks的数量 TODO 盲猜和ResNet结构有关系
    ## D_depth 每个阶段D的resblocks的数量
    ## TODO D_thin和D_wide是干啥的?默认都是FALSE
    ## G_shared TODO Use shared embeddings in G 这个需要细看一下是怎么操作的
    ## shared_dim TODO 'G''s shared embedding dimensionality; if 0, will be equal to dim_z.
    ## dim_z 噪声的维度,默认为128
    ## z_var 噪声的标准差,premiere为1
    ## hier 这个是使用多层噪声
    # 以arch[64]为例子, 插入的层数 num_slots 为 [16, 16, 8, 4] 的len + 1
    # 每一块插入的尺寸 z_chunk_size 为 dim_z // num_slots,即把噪声的维度 与 插入层数 取模
    # z_chunk_size是直接参与前向传播 forward 和 which_bn 应该是决定哪一层使用BatchNormalization
    # 重塑噪声维度,让其在每一层插入的维度是相同的 dim_z = z_chunk_size * num_slots
    #~ 如果不使用多层噪声,那么就把 num_slots 的数量置为1
    # 同时把 z_chunk_size 置于 0

    ## cross_replica TODO 这个是啥?把G模型的batchnorm再复制一遍吗
    ## 使用默认的归一化方法 mybn,看一下为啥
    ## G_nl&D_nl的激活函数
    ## G_attn和D_attn是否使用attention机制
    ## norm_style 使用归一化方法,CNN还是使用BN比较好,四者之间的区别可以百度
    ## seed 随机数种子,默认为0,控制初始化参数和数据读取的
    ## G_init 生成模型初始化方法 TODO !!! 要知晓一下ortho是什么类型的初始化方法 使初始化的参数矩阵是正交规范化的
    # https://zhuanlan.zhihu.com/p/98873800
    # 正交规范化的矩阵能够缓解梯度消失和爆炸,具体看BigGAN中的援引文献,
    # 原为L1范数,BigGAN中使用L2范数进行改进
    ## D_init 判别模型初始化方法
    ## skip_init 跳过初始化 TODO 为什么跳过初始化,就是ideal的testing?
    ## G_lr 生成模型的学习率
    ## D_lr 辨别模型的学习率
    ## G_B1 & D_B1 beta1 % G_B2 & D_B2 beta2 TODO 属于GAN中的什么类型的参数
    ## batch_size 批大小 64
    ## G_batch_size TODO 为何还要写G的batch size??
    ## num_G_accumulations TODO 把G的梯度相加又是为了什么??
    ## num_D_steps TODO 每一步G需要跑几次D,D需要多训练训练,如果G跑太前面就太逼近原图了?
    ## num_D_accumulations TODO 把D的梯度相加又是为了什么??
    ## split_D 运行D两次,而不是连接输入 TODO ?
    ## num_epochs @@ 训练的轮次数量
    ## parallel 默认是FALSE 训练时是否使用多GPUs一起
    ## G_fp16 & D_fp16 & G_mixed_precision & D_mixed_precision 对于精度的使用,一个是加快计算,一个减小模型大小.有的时候会带一点准确率的损失,所以mix表示预测的时候用更高的精度
    ## accumulate_stats 默认FALSE TODO 积累统计数据
    ## num_standing_accumulations TODO 这个和楼上那个什么关系?!!!

    #@@ 接下来是运行模式
    ## G_eval_mode TODO 每次采样或者测试的时候,都评估一下G?
    ## save_every 每2000迭代储存一下模型
    ## num_save_copies TODO 储几个备份?默认2个
    ## num_best_copies 存几个最好的版本,默认2个
    ## which_best 依据IS或者FID作为评价指标去储存模型
    ## no_fid 是否只计算IS,不计算FID。默认都计算
    ## test_every 没多少论测试一下,默认5000
    ## num_inception_images TODO 用于计算初始度量的样本数量50000
    ## hashname 是否使用HasName而不是配置中的类别,默认不使用
    ## base_root 默认储存所有权重,采样,数据,日志
    ## data_root TODO !!!数据默认存在哪里
    ## weights_root,logs_root,samples_root 本地存在哪儿
    ## pbar 是否使用进度条,默认使用mine
    ## name_suffix 命名添加后缀
    ## experiment_name 自定义存储实验名称
    ## config_from_name 是否使用hash实验名

    #@@ EMA 对权重进行指数滑动平均处理的工具群
    ## ema 是否保存G的ema参数
    ## ema_decay EMA的衰减率
    ## use_ema 是否在G中使用评估
    ## ema_start 什么时候去用EMA方法更新权重

    #@@ SV stuff 奇异值的迭代性质
    ## adam_eps adam的epsilon value,TODO 看一下它在干嘛
    ## BN_eps & SN_eps 批归一化和谱归一化的参数
    ## num_G_SVs G的奇异值追踪数量,默认为1
    ## num_D_SVs D的奇异值追踪数量,默认为1
    ## num_G_SV_itrs G的迭代次数,默认为1
    ## num_D_SV_itrs D的迭代次数,默认为1

    #@@ Ortho reg stuff 这个又是什么
    ## 控制Ortho的迭代性质

    ## which_train_fn 默认是GAN
    ## load_weights 加载哪类参数,是copy的或者历史最好的etc.
    ## resume TODO Resume training是从历史记录开始训练,可以用于预训练

    ## Log stuff 日志系统
    ## logstyle 日志类型
    ## log_G_spectra 记录G的头3个奇异值
    ## log_D_spectra 记录D的头3个奇异值
    ## sv_log_interval 每个多少论记录一次

    #@@ Arguments for sample.py; not presently used in train.py
    ## add_sample_parser 增加一些采样的参数
    config = vars(parser.parse_args())
    print(config)
    ## 将运行字典传入GAN模型主程序
    run(config)