def run_training(rank, world_size, model_args, data, load_from, new,
                 num_train_steps, name, seed=42):
    """Train a ``Trainer`` in a single process, optionally as one DDP rank.

    Args:
        rank: index of this process; rank 0 is the main process and is the
            only one that prints the training log.
        world_size: total number of processes; a value > 1 enables DDP
            (NCCL backend, rendezvous on localhost:12355).
        model_args: keyword arguments forwarded to ``Trainer``. Updated in
            place with ``is_ddp``, ``rank`` and ``world_size``.
        data: data source handed to ``model.set_data_src``.
        load_from: checkpoint identifier passed to ``model.load`` when
            resuming (i.e. when ``new`` is false).
        new: if true, call ``model.clear()`` instead of loading a checkpoint.
        num_train_steps: total step count to reach; steps already stored in
            ``model.steps`` count towards it.
        name: run name, used in the progress-bar description.
        seed: value passed to ``set_seed`` for DDP runs. The previous body
            referenced ``seed`` without defining it, so DDP runs failed with
            ``NameError``; the default matches the CLI default of 42.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        print(f"{rank + 1}/{world_size} process initialized.")

    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)
        model = Trainer(**model_args)

        if not new:
            model.load(load_from)
        else:
            model.clear()

        model.set_data_src(data)

        # `step` counts iterations of this session (from 0), not model.steps;
        # the progress bar is offset by `initial` so it shows the global step.
        for step in tqdm(range(num_train_steps - model.steps),
                         initial=model.steps, total=num_train_steps,
                         mininterval=10., desc=f'{name}<{data}>'):
            retry_call(model.train, tries=3, exceptions=NanException)
            if is_main and step % 50 == 0:
                model.print_log()
    finally:
        # Tear down the process group even if training raised.
        if is_ddp:
            dist.destroy_process_group()
def load_latest_model():
    """Return a ``Trainer`` with fixed settings, its latest checkpoint loaded.

    The checkpoint is selected with ``load(-1)``. Before returning, the GAN is
    put into non-training mode via ``GAN.train(False)``.
    """
    trainer_kwargs = {
        # run identity / paths
        'name': 'default',
        'results_dir': './results',
        'models_dir': './models',
        # batching and architecture
        'batch_size': 8,
        'gradient_accumulate_every': 6,
        'image_size': 128,
        'network_capacity': 4,
        'fmap_max': 512,
        'transparent': False,
        # optimisation
        'lr': 2e-4,
        'lr_mlp': 0.1,
        'ttur_mult': 1.5,
        'rel_disc_loss': False,
        'num_workers': 16,
        # bookkeeping cadence
        'save_every': 1000,
        'evaluate_every': 1000,
        'trunc_psi': 0.75,
        'fp16': False,
        # optional model components
        'cl_reg': False,
        'fq_layers': [],
        'fq_dict_size': 256,
        'attn_layers': [],
        'no_const': False,
        # augmentation and top-k training
        'aug_prob': 0.,
        'aug_types': ['translation', 'cutout'],
        'top_k_training': False,
        'generator_top_k_gamma': 0.99,
        'generator_top_k_frac': 0.5,
        'dataset_aug_prob': 0.,
        # evaluation / logging
        'calculate_fid_every': None,
        'mixed_prob': 0.9,
        'log': False,
    }

    trainer = Trainer(**trainer_kwargs)
    trainer.load(-1)
    trainer.GAN.train(False)
    return trainer
def train_from_folder(
        data='./data',
        results_dir='./results',
        models_dir='./models',
        name='default',
        new=False,
        load_from=-1,
        image_size=128,
        network_capacity=16,
        transparent=False,
        batch_size=5,
        gradient_accumulate_every=6,
        num_train_steps=150000,
        learning_rate=2e-4,
        lr_mlp=0.1,
        ttur_mult=1.5,
        rel_disc_loss=False,
        num_workers=None,
        save_every=1000,
        generate=False,
        generate_interpolation=False,
        interpolation_num_steps=100,
        save_frames=False,
        num_image_tiles=8,
        trunc_psi=0.75,
        fp16=False,
        cl_reg=False,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
        no_const=False,
        aug_prob=0.,
        aug_types=None,
        generator_top_k=False,
        generator_top_k_gamma=0.99,
        generator_top_k_frac=0.5,
        dataset_aug_prob=0.,
        multi_gpus=False,
        calculate_fid_every=None):
    """Train a StyleGAN2 ``Trainer`` on ``data``, or generate samples from it.

    If ``generate`` is true, load checkpoint ``load_from`` and write a sample
    grid. Otherwise, if ``generate_interpolation`` is true, write an
    interpolation. Otherwise, train via ``run_training``. Training is spread
    over all CUDA devices with ``mp.spawn`` when ``multi_gpus`` is set and
    more than one device is visible.

    ``fq_layers`` and ``attn_layers`` default to ``[]`` and ``aug_types`` to
    ``['translation', 'cutout']``. They are spelled ``None`` in the signature
    so that each call gets fresh lists instead of sharing one default object.
    """
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []
    if aug_types is None:
        aug_types = ['translation', 'cutout']

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        image_size=image_size,
        network_capacity=network_capacity,
        transparent=transparent,
        lr=learning_rate,
        lr_mlp=lr_mlp,
        ttur_mult=ttur_mult,
        rel_disc_loss=rel_disc_loss,
        num_workers=num_workers,
        save_every=save_every,
        trunc_psi=trunc_psi,
        fp16=fp16,
        cl_reg=cl_reg,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        no_const=no_const,
        aug_prob=aug_prob,
        aug_types=cast_list(aug_types),
        generator_top_k=generator_top_k,
        generator_top_k_gamma=generator_top_k_gamma,
        generator_top_k_frac=generator_top_k_frac,
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
    )

    # Sampling modes share the same setup; `generate` takes precedence.
    if generate or generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        if generate:
            model.evaluate(samples_name, num_image_tiles)
            print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles,
                                         num_steps=interpolation_num_steps,
                                         save_frames=save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()

    # `<= 1` (not `== 1`): with zero visible devices, mp.spawn(nprocs=0)
    # would start no workers and silently skip training.
    if world_size <= 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new, num_train_steps, name),
             nprocs=world_size,
             join=True)
def train_from_folder(
        data='./data',
        results_dir='./results',
        models_dir='./models',
        name='default',
        new=False,
        load_from=-1,
        image_size=128,
        network_capacity=16,
        fmap_max=512,
        transparent=False,
        batch_size=5,
        gradient_accumulate_every=6,
        num_train_steps=150000,
        learning_rate=2e-4,
        lr_mlp=0.1,
        ttur_mult=1.5,
        rel_disc_loss=False,
        num_workers=None,
        save_every=1000,
        evaluate_every=1000,
        generate=False,
        num_generate=1,
        generate_interpolation=False,
        interpolation_num_steps=100,
        save_frames=False,
        num_image_tiles=8,
        trunc_psi=0.75,
        mixed_prob=0.9,
        fp16=False,
        no_pl_reg=False,
        cl_reg=False,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
        no_const=False,
        aug_prob=0.,
        aug_types=None,
        top_k_training=False,
        generator_top_k_gamma=0.99,
        generator_top_k_frac=0.5,
        dual_contrast_loss=False,
        dataset_aug_prob=0.,
        multi_gpus=False,
        calculate_fid_every=None,
        calculate_fid_num_images=12800,
        clear_fid_cache=False,
        seed=42,
        log=False,
        cat_len=27,
        softmax_loss=False):
    """Train a StyleGAN2 ``Trainer`` on ``data``, or generate samples from it.

    Modes, checked in this order:
      * ``generate``: load ``load_from`` and write ``num_generate`` sample
        grids named ``<timestamp>-<i>``.
      * ``generate_interpolation``: load ``load_from`` and write an
        interpolation with ``interpolation_num_steps`` steps.
      * otherwise train via ``run_training``, passing ``seed`` through. With
        ``multi_gpus`` and more than one CUDA device, one process is spawned
        per device.

    ``fq_layers`` and ``attn_layers`` default to ``[]`` and ``aug_types`` to
    ``['translation', 'cutout']``. They are spelled ``None`` in the signature
    so that each call gets fresh lists instead of sharing one default object.
    """
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []
    if aug_types is None:
        aug_types = ['translation', 'cutout']

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        image_size=image_size,
        network_capacity=network_capacity,
        fmap_max=fmap_max,
        transparent=transparent,
        lr=learning_rate,
        lr_mlp=lr_mlp,
        ttur_mult=ttur_mult,
        rel_disc_loss=rel_disc_loss,
        num_workers=num_workers,
        save_every=save_every,
        evaluate_every=evaluate_every,
        num_image_tiles=num_image_tiles,
        trunc_psi=trunc_psi,
        fp16=fp16,
        no_pl_reg=no_pl_reg,
        cl_reg=cl_reg,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        no_const=no_const,
        aug_prob=aug_prob,
        aug_types=cast_list(aug_types),
        top_k_training=top_k_training,
        generator_top_k_gamma=generator_top_k_gamma,
        generator_top_k_frac=generator_top_k_frac,
        dual_contrast_loss=dual_contrast_loss,
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
        calculate_fid_num_images=calculate_fid_num_images,
        clear_fid_cache=clear_fid_cache,
        mixed_prob=mixed_prob,
        log=log,
        cat_len=cat_len,
        softmax_loss=softmax_loss,
    )

    # Sampling modes share the same setup; `generate` takes precedence.
    if generate or generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        if generate:
            for num in tqdm(range(num_generate)):
                model.evaluate(f'{samples_name}-{num}', num_image_tiles)
            print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles,
                                         num_steps=interpolation_num_steps,
                                         save_frames=save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()
    print("Number of available CUDA devices: ", world_size)

    # `<= 1` (not `== 1`): with zero visible devices, mp.spawn(nprocs=0)
    # would start no workers and silently skip training.
    if world_size <= 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
             nprocs=world_size,
             join=True)
def train_from_folder(
        data='/hydration/ffhq/pose/trainB',
        results_dir='/hydration/results',
        models_dir='./models',
        name='run_testsg2_001',
        new=False,
        load_from=-1,
        image_size=128,  # 256 if comparing against Hydration; 128 if reproducing StyleGAN 2
        network_capacity=16,  # 64 if comparing against Hydration; 16 if reproducing StyleGAN 2
        transparent=False,
        batch_size=3,  # 1 if comparing against Hydration; 3 if reproducing StyleGAN 2
        gradient_accumulate_every=5,  # 1 if comparing against Hydration; 5 if reproducing StyleGAN 2
        num_train_steps=150000,
        learning_rate=2e-4,  # Always 0.0002
        num_workers=None,
        save_every=1000,
        generate=False,
        generate_interpolation=False,
        save_frames=False,
        num_image_tiles=8,
        trunc_psi=0.75,  # Always 0.75
        fp16=False,
        cl_reg=False,  # Always False
        fq_layers=None,  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
        fq_dict_size=256,  # 256 if comparing against Hydration; 256 if reproducing StyleGAN 2
        attn_layers=None,  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
        no_const=False,  # False if comparing against Hydration; False if reproducing StyleGAN 2
        aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
        dataset_aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
        use_manual_seed=-1,  # -1 for no seed; 0 if comparing against Hydration; -1 if reproducing StyleGAN 2
        debug_and_crash_mode=False):
    """Train a StyleGAN2 ``Trainer`` (Hydration comparison setup) or sample from it.

    The trainer is built and then either loaded from ``load_from`` or cleared
    (``new``). After that, one of the following happens:
      * ``generate``: write a sample grid named ``generated-<timestamp>``;
      * ``generate_interpolation``: write an interpolation under that name;
      * otherwise train until ``model.steps`` reaches ``num_train_steps``.

    ``fq_layers`` and ``attn_layers`` default to ``[]``. They are spelled
    ``None`` in the signature so each call gets a fresh list instead of
    sharing one default object.
    """
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []

    model = Trainer(name, results_dir, models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob,
                    use_manual_seed=use_manual_seed,
                    debug_and_crash_mode=debug_and_crash_mode)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    # Both sampling modes share the same output naming; `generate` wins.
    if generate or generate_interpolation:
        timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        if generate:
            model.evaluate(samples_name, num_image_tiles)
            print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles,
                                         save_frames=save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)
    # `step` counts iterations of this session, starting at 0.
    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10., desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()
def train_from_folder(
        data='./data',
        results_dir='./results',
        models_dir='./models',
        name='default',
        new=False,
        load_from=-1,
        image_size=128,
        network_capacity=16,
        transparent=False,
        batch_size=5,
        gradient_accumulate_every=6,
        num_train_steps=150000,
        learning_rate=2e-4,
        lr_mlp=0.1,
        ttur_mult=1.5,
        rel_disc_loss=False,
        num_workers=None,
        save_every=1000,
        generate=False,
        generate_interpolation=False,
        save_frames=False,
        num_image_tiles=8,
        trunc_psi=0.75,
        fp16=False,
        cl_reg=False,
        fq_layers=None,
        fq_dict_size=256,
        attn_layers=None,
        no_const=False,
        aug_prob=0.,
        dataset_aug_prob=0.,
        gpu_ids=None):
    """Train a StyleGAN2 ``Trainer`` on selected GPUs, or sample from it.

    ``CUDA_VISIBLE_DEVICES`` is set from ``gpu_ids`` (default ``[0]``) before
    ``stylegan2_pytorch`` is imported. The trainer is then loaded from
    ``load_from`` or cleared (``new``), and the function does one of:
      * ``generate``: write a sample grid;
      * ``generate_interpolation``: write an interpolation;
      * otherwise train until ``model.steps`` reaches ``num_train_steps``.

    ``fq_layers``/``attn_layers`` default to ``[]`` and ``gpu_ids`` to
    ``[0]``. They are spelled ``None`` in the signature so each call gets
    fresh lists instead of sharing one default object.
    """
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []

    gpu_ids = cast_list([0] if gpu_ids is None else gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))
    # Import deliberately deferred until after CUDA_VISIBLE_DEVICES is set,
    # presumably so CUDA initializes with only the selected GPUs visible.
    from stylegan2_pytorch import Trainer, NanException

    model = Trainer(name, results_dir, models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    lr_mlp=lr_mlp,
                    ttur_mult=ttur_mult,
                    rel_disc_loss=rel_disc_loss,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    # Sampling modes share the same naming; `generate` takes precedence.
    if generate or generate_interpolation:
        samples_name = timestamped_filename()
        if generate:
            model.evaluate(samples_name, num_image_tiles)
            print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles,
                                         save_frames=save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)
    # `step` counts iterations of this session, starting at 0.
    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10., desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()
def run_training(rank, world_size, model_args, data, load_from, disc_load_from,
                 new, num_train_steps, name, seed):
    """Train a ``Trainer`` in one process (optionally one DDP rank), then save.

    Args:
        rank: index of this process; only rank 0 prints the training log.
        world_size: number of processes; > 1 enables DDP (NCCL backend,
            rendezvous on localhost:12355) and seeds via ``set_seed(seed)``.
        model_args: keyword arguments for ``Trainer``. Updated in place with
            ``is_ddp``, ``rank`` and ``world_size``.
        data: data source handed to ``model.set_data_src``.
        load_from: checkpoint passed to ``model.load`` when resuming.
        disc_load_from: if > 0 (and resuming), the discriminator modules
            (``D``, ``D_aug``, ``D_cl``) and ``D_opt`` are taken from a second
            trainer loaded at this checkpoint.
        new: if true, call ``model.clear()`` instead of loading.
        num_train_steps: train while ``model.steps`` is below this value.
        name: run name, used in the progress-bar description.
        seed: seed for ``set_seed``; only applied in DDP mode.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        print(f"{rank + 1}/{world_size} process initialized.")

    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)
        model = Trainer(**model_args)

        if not new:
            model.load(load_from)
            if disc_load_from > 0:
                # Transplant the discriminator side from a different checkpoint.
                donor = Trainer(**model_args)
                donor.load(disc_load_from)
                model.GAN.D, model.GAN.D_aug, model.GAN.D_cl, model.GAN.D_opt = \
                    donor.GAN.D, donor.GAN.D_aug, donor.GAN.D_cl, donor.GAN.D_opt
                del donor
        else:
            model.clear()

        model.set_data_src(data)

        # Context manager closes the bar (previously it was never closed).
        with tqdm(initial=model.steps, total=num_train_steps, mininterval=10.,
                  desc=f'{name}<{data}>') as progress_bar:
            while model.steps < num_train_steps:
                retry_call(model.train, tries=3, exceptions=NanException)
                progress_bar.n = model.steps
                progress_bar.refresh()
                if is_main and model.steps % 50 == 0:
                    model.print_log()

        model.save(model.checkpoint_num)
    finally:
        # Tear down the process group even if training raised.
        if is_ddp:
            dist.destroy_process_group()
def synthesize_gwas_data(
        n_causal=100,
        exp_var=0.5,
        img_size=512,
        models_dir='../models',
        checkpoint=-1,
        same_noise=True,
        psi=0.6,
        name='stylegan2_healthy',
        mult_scale=1.,
        subset=None,
        seed=123,
        wdir='.',
):
    '''Synthesize images from latent codes with StyleGAN2.

    Read latent codes from ``join(wdir, LATENT_DIR, get_latent_bolt(...))``
    (space-separated, indexed by its second column, ``FID`` column dropped).
    Generate one image per row with a StyleGAN2 ``Trainer``. Each image is
    saved as ``<index>.jpg`` in the directory given by ``get_img_dir``, and a
    CSV with columns ``IID`` and ``path`` is written next to that directory
    (``<dir>.csv``).

    # Parameters
    n_causal (int): number of causal SNPs
    exp_var (float in [0, 1]): percentage of explained variance by causal SNPs
    img_size (int): size of images to be created (determined by training scheme of StyleGAN2)
    models_dir (str): directory holding the StyleGAN2 models (also passed as results_dir)
    checkpoint (int): which epoch to use; if -1, load latest epoch in model directory
    same_noise (bool): use the same noise vector for all images
    psi (float in [0, 1] or None): truncation parameter for images, trade-off between image
        quality and diversity; if None, the trainer's ``trunc_psi`` is used
    name (str): name of StyleGAN2 model in models directory
    mult_scale (float): multiplier for standard-normal style vector (input), to increase/decrease diversity of images
    subset (None or int): only create subset of images, for debugging (rows sampled with random_state=123)
    seed (int): seed identifying which latent-code file is read and which output directory is used
    wdir (str): base directory that the latent file and output directory are joined onto
    '''
    pth = join(wdir, LATENT_DIR, get_latent_bolt(exp_var, n_causal, seed))
    # `drop(columns=...)` instead of `drop('FID', 1)`: passing axis positionally
    # was deprecated in pandas 1.x and raises TypeError in pandas 2.0.
    latent = pd.read_csv(pth, sep=' ', index_col=1).drop(columns='FID')
    if subset is not None:
        latent = latent.sample(subset, random_state=123)

    T = Trainer(name, models_dir=models_dir, results_dir=models_dir,
                image_size=img_size, network_capacity=16)
    T.load(checkpoint)
    if psi is None:
        psi = T.trunc_psi

    if same_noise:
        N = image_noise(1, img_size)

    out_img_dir = join(wdir, get_img_dir(name, exp_var, n_causal, mult_scale, seed))
    os.makedirs(out_img_dir, exist_ok=True)

    for i, lat in tqdm(latent.iterrows(), total=len(latent)):
        if not same_noise:
            N = image_noise(1, img_size)
        # NOTE(review): the style vector is paired with a hard-coded 8, which is
        # presumably the number of generator layers it is applied to (consistent
        # with img_size=512); TODO confirm before using other image sizes.
        L = [(mult_scale * torch.from_numpy(lat.values).view(1, -1).float().cuda(), 8)]
        gen = T.generate_truncated(T.GAN.S, T.GAN.G, L, N, trunc_psi=psi)
        # (C, H, W) -> (H, W, C) and scale to 8-bit for PIL.
        gen = (gen.cpu().double()[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        gen = PIL.Image.fromarray(gen)
        path = join(out_img_dir, f'{i}.jpg')
        gen.save(path)

    cwd = os.getcwd()
    paths = [join(cwd, out_img_dir, f'{i}.jpg') for i in latent.index]
    df = pd.DataFrame({'IID': latent.index, 'path': paths})
    csv_path = out_img_dir + '.csv'
    df.to_csv(csv_path, index=False)