def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed=None):
    """Run the training loop for a single worker process.

    Args:
        rank: Index of this process; rank 0 is treated as the main process
            and is the only one that prints training logs.
        world_size: Total number of processes. Values above 1 enable the
            distributed (DDP) code path.
        model_args: Keyword arguments for ``Trainer``. NOTE: this dict is
            updated in place with ``is_ddp``, ``rank`` and ``world_size``.
        data: Passed to ``model.set_data_src``; also shown in the progress
            bar description.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        new: When truthy, ``model.clear()`` is called instead of loading.
        num_train_steps: Target value for ``model.steps``.
        name: Run name, used only in the progress bar description.
        seed: Optional seed handed to ``set_seed`` on the distributed path.
            The previous version referenced an undefined ``seed`` name here,
            which raised ``NameError`` as soon as ``world_size > 1``. When
            left as ``None`` no seeding is performed.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        if seed is not None:
            set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    # Everything after init_process_group runs under try/finally so the
    # process group is torn down even if training raises.
    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

        model = Trainer(**model_args)

        if not new:
            model.load(load_from)
        else:
            model.clear()

        model.set_data_src(data)

        # Iterate over the remaining steps only; ``initial``/``total`` make
        # the bar display progress relative to the full run.
        remaining_steps = tqdm(
            range(num_train_steps - model.steps),
            initial=model.steps,
            total=num_train_steps,
            mininterval=10.,
            desc=f'{name}<{data}>',
        )
        for step in remaining_steps:
            # Retry the step up to 3 times when it raises NanException.
            retry_call(model.train, tries=3, exceptions=NanException)
            # ``step`` counts from 0 for this invocation, not from model.steps.
            if is_main and step % 50 == 0:
                model.print_log()
    finally:
        if is_ddp:
            dist.destroy_process_group()
def run_training(rank, world_size, model_args, data, load_from, disc_load_from, new, num_train_steps, name, seed):
    """Run the training loop for a single worker process, then save the model.

    Args:
        rank: Index of this process; rank 0 is treated as the main process
            and is the only one that prints training logs.
        world_size: Total number of processes. Values above 1 enable the
            distributed (DDP) code path.
        model_args: Keyword arguments for ``Trainer``. NOTE: this dict is
            updated in place with ``is_ddp``, ``rank`` and ``world_size``.
        data: Passed to ``model.set_data_src``; also shown in the progress
            bar description.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        disc_load_from: When ``new`` is falsy and this is greater than 0, a
            second ``Trainer`` is loaded from it and its ``GAN.D``,
            ``GAN.D_aug``, ``GAN.D_cl`` and ``GAN.D_opt`` replace those of
            the main model (presumably the discriminator side, going by the
            parameter name -- confirm against ``Trainer``).
        new: When truthy, ``model.clear()`` is called instead of loading.
        num_train_steps: Training continues while
            ``model.steps < num_train_steps``.
        name: Run name, used only in the progress bar description.
        seed: Passed to ``set_seed`` on the distributed path.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    # Everything after init_process_group runs under try/finally so the
    # process group is torn down even if training or saving raises.
    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

        model = Trainer(**model_args)

        if not new:
            model.load(load_from)
            if disc_load_from > 0:
                # Borrow the D-side attributes from a second checkpoint.
                donor = Trainer(**model_args)
                donor.load(disc_load_from)
                model.GAN.D = donor.GAN.D
                model.GAN.D_aug = donor.GAN.D_aug
                model.GAN.D_cl = donor.GAN.D_cl
                model.GAN.D_opt = donor.GAN.D_opt
                del donor
        else:
            model.clear()

        model.set_data_src(data)

        # The bar is driven manually from model.steps rather than by wrapping
        # an iterable; the context manager guarantees it is closed.
        with tqdm(initial=model.steps, total=num_train_steps, mininterval=10., desc=f'{name}<{data}>') as progress_bar:
            while model.steps < num_train_steps:
                # Retry the step up to 3 times when it raises NanException.
                retry_call(model.train, tries=3, exceptions=NanException)
                progress_bar.n = model.steps
                progress_bar.refresh()
                if is_main and model.steps % 50 == 0:
                    model.print_log()

        # Runs on every rank, as in the original (not gated on is_main).
        model.save(model.checkpoint_num)
    finally:
        if is_ddp:
            dist.destroy_process_group()
def train_from_folder(
    data='/hydration/ffhq/pose/trainB',
    results_dir='/hydration/results',
    models_dir='./models',
    name='run_testsg2_001',
    new=False,
    load_from=-1,
    image_size=128,  # 256 if comparing against Hydration; 128 if reproducing StyleGAN 2
    network_capacity=16,  # 64 if comparing against Hydration; 16 if reproducing StyleGAN 2
    transparent=False,
    batch_size=3,  # 1 if comparing against Hydration; 3 if reproducing StyleGAN 2
    gradient_accumulate_every=5,  # 1 if comparing against Hydration; 5 if reproducing StyleGAN 2
    num_train_steps=150000,
    learning_rate=2e-4,  # Always 0.0002
    num_workers=None,
    save_every=1000,
    generate=False,
    generate_interpolation=False,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,  # Always 0.75
    fp16=False,
    cl_reg=False,  # Always False
    fq_layers=None,  # None means []; [] if comparing against Hydration; [] if reproducing StyleGAN 2
    fq_dict_size=256,  # 256 if comparing against Hydration; 256 if reproducing StyleGAN 2
    attn_layers=None,  # None means []; [] if comparing against Hydration; [] if reproducing StyleGAN 2
    no_const=False,  # False if comparing against Hydration; False if reproducing StyleGAN 2
    aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    dataset_aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    use_manual_seed=-1,  # -1 for no seed; 0 if comparing against Hydration; -1 if reproducing StyleGAN 2
    debug_and_crash_mode=False,
):
    """Build a ``Trainer`` and either generate samples or train from ``data``.

    Exactly one of three things happens after the model is loaded/cleared:

    * ``generate`` truthy: ``model.evaluate`` is called and the function
      returns.
    * otherwise ``generate_interpolation`` truthy:
      ``model.generate_interpolation`` is called and the function returns.
    * otherwise: the training loop runs for
      ``num_train_steps - model.steps`` iterations.

    Args:
        data: Passed to ``model.set_data_src`` before training.
        results_dir, models_dir, name: Forwarded positionally to ``Trainer``
            (as ``name, results_dir, models_dir``); ``results_dir`` and
            ``name`` are also used in the printed output location.
        new: When truthy, ``model.clear()`` is called instead of
            ``model.load(load_from)``.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        num_train_steps: Target step count for the training loop.
        learning_rate: Forwarded to ``Trainer`` as ``lr``.
        generate, generate_interpolation, save_frames, num_image_tiles:
            Control the two generation modes described above.
        fq_layers, attn_layers: ``None`` (the default) is replaced by a
            fresh empty list per call, avoiding a list object shared
            between calls.
        All remaining keyword arguments are forwarded to ``Trainer``
        unchanged under the same name.
    """
    # Resolve the None sentinels to per-call lists.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers

    model = Trainer(
        name,
        results_dir,
        models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        image_size=image_size,
        network_capacity=network_capacity,
        transparent=transparent,
        lr=learning_rate,
        num_workers=num_workers,
        save_every=save_every,
        trunc_psi=trunc_psi,
        fp16=fp16,
        cl_reg=cl_reg,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        no_const=no_const,
        aug_prob=aug_prob,
        dataset_aug_prob=dataset_aug_prob,
        use_manual_seed=use_manual_seed,
        debug_and_crash_mode=debug_and_crash_mode,
    )

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate or generate_interpolation:
        # Both generation modes write under a timestamped name.
        timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        if generate:
            # ``generate`` takes precedence when both flags are set.
            model.evaluate(samples_name, num_image_tiles)
            print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        else:
            model.generate_interpolation(samples_name, num_image_tiles, save_frames=save_frames)
            print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    for step in tqdm(range(num_train_steps - model.steps), mininterval=10., desc=f'{name}<{data}>'):
        # Retry the step up to 3 times when it raises NanException.
        retry_call(model.train, tries=3, exceptions=NanException)
        # ``step`` counts from 0 for this invocation, not from model.steps.
        if step % 50 == 0:
            model.print_log()
def train_from_folder(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=128,
    network_capacity=16,
    transparent=False,
    batch_size=5,
    gradient_accumulate_every=6,
    num_train_steps=150000,
    learning_rate=2e-4,
    lr_mlp=0.1,
    ttur_mult=1.5,
    rel_disc_loss=False,
    num_workers=None,
    save_every=1000,
    generate=False,
    generate_interpolation=False,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,
    fp16=False,
    cl_reg=False,
    fq_layers=None,
    fq_dict_size=256,
    attn_layers=None,
    no_const=False,
    aug_prob=0.,
    dataset_aug_prob=0.,
    gpu_ids=None,
):
    """Select GPUs, build a ``Trainer`` and either generate samples or train.

    Exactly one of three things happens after the model is loaded/cleared:

    * ``generate`` truthy: ``model.evaluate`` is called and the function
      returns.
    * otherwise ``generate_interpolation`` truthy:
      ``model.generate_interpolation`` is called and the function returns.
    * otherwise: the training loop runs for
      ``num_train_steps - model.steps`` iterations.

    Args:
        data: Passed to ``model.set_data_src`` before training.
        results_dir, models_dir, name: Forwarded positionally to ``Trainer``
            (as ``name, results_dir, models_dir``); ``results_dir`` and
            ``name`` are also used in the printed output location.
        new: When truthy, ``model.clear()`` is called instead of
            ``model.load(load_from)``.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        num_train_steps: Target step count for the training loop.
        learning_rate: Forwarded to ``Trainer`` as ``lr``.
        generate, generate_interpolation, save_frames, num_image_tiles:
            Control the two generation modes described above.
        fq_layers, attn_layers: ``None`` (the default) is replaced by a
            fresh empty list per call, avoiding a list object shared
            between calls.
        gpu_ids: GPU ids written, comma-joined, to ``CUDA_VISIBLE_DEVICES``
            after going through ``cast_list``. ``None`` (the default) is
            treated as ``[0]``.
        All remaining keyword arguments are forwarded to ``Trainer``
        unchanged under the same name.
    """
    # Resolve the None sentinels to per-call lists.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers
    gpu_ids = cast_list([0] if gpu_ids is None else gpu_ids)

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))

    # Deliberately imported after CUDA_VISIBLE_DEVICES is set -- presumably
    # so the GPU selection is in place before the package touches CUDA.
    # Keep this ordering.
    from stylegan2_pytorch import Trainer, NanException

    model = Trainer(
        name,
        results_dir,
        models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        image_size=image_size,
        network_capacity=network_capacity,
        transparent=transparent,
        lr=learning_rate,
        lr_mlp=lr_mlp,
        ttur_mult=ttur_mult,
        rel_disc_loss=rel_disc_loss,
        num_workers=num_workers,
        save_every=save_every,
        trunc_psi=trunc_psi,
        fp16=fp16,
        cl_reg=cl_reg,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        no_const=no_const,
        aug_prob=aug_prob,
        dataset_aug_prob=dataset_aug_prob,
    )

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, save_frames=save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    for step in tqdm(range(num_train_steps - model.steps), mininterval=10., desc=f'{name}<{data}>'):
        # Retry the step up to 3 times when it raises NanException.
        retry_call(model.train, tries=3, exceptions=NanException)
        # ``step`` counts from 0 for this invocation, not from model.steps.
        if step % 50 == 0:
            model.print_log()