# Imports required by the functions below. Trainer, NanException and
# DiffAugmentTest are assumed to come from the lightweight_gan package;
# small helpers such as cast_list, default, set_seed and timestamped_filename
# are expected to be defined alongside these functions, as in the upstream CLI.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from retry.api import retry_call
from tqdm import tqdm

from lightweight_gan import Trainer, NanException
from lightweight_gan.diff_augment_test import DiffAugmentTest


def train_from_folder(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_interpolation = False,
    attn_res_layers = [32],
    sle_spatial = False,
    disc_output_size = 1,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = 8,
    trunc_psi = 0.75,
    aug_prob = 0.,
    aug_types = ['cutout', 'translation'],
    dataset_aug_prob = 0.,
    multi_gpus = False,
    calculate_fid_every = None,
    seed = 42
):
    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        sle_spatial = sle_spatial,
        disc_output_size = disc_output_size,
        image_size = image_size,
        optimizer = optimizer,
        fmap_max = fmap_max,
        transparent = transparent,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        trunc_psi = trunc_psi,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every
    )

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
        return

    mp.spawn(run_training, args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed), nprocs=world_size, join=True)
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        # Distributed setup: seed every worker identically, then join the
        # NCCL process group on localhost.
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(
        is_ddp = is_ddp,
        rank = rank,
        world_size = world_size
    )

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    # Resume from model.steps; retry a step up to 3 times if the loss NaNs out,
    # and log from the main process every 50 steps.
    for i in tqdm(range(num_train_steps - model.steps), initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)

        if is_main and i % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
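# Usage sketch: with one visible GPU (or multi_gpus=False) train_from_folder
# calls run_training in-process; with multi_gpus=True and several devices it
# spawns one process per GPU via torch.multiprocessing. The path and run name
# below are placeholders:
#
#   train_from_folder(data = './photos', name = 'my-run', multi_gpus = True)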
def load_model(num=-1, path=None):
    # Convenience loader: restore a checkpoint (by number, or from an explicit
    # path) and switch the model to evaluation mode for sampling.
    model = Trainer(save_every=1)
    model.load(num, path)
    model.eval()
    return model
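# Example (a sketch, assuming a checkpoint already exists under ./models):
# `evaluate` is the same Trainer method train_from_folder uses above to write
# a grid of samples into the results directory.
#
#   model = load_model(num = -1)
#   model.evaluate(timestamped_filename(), num_image_tiles = 8)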
def get_trainer(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=256,
    optimizer='adam',
    fmap_max=512,
    transparent=False,
    batch_size=10,
    gradient_accumulate_every=4,
    num_train_steps=150000,
    learning_rate=2e-4,
    save_every=1000,
    evaluate_every=1000,
    generate=False,
    generate_interpolation=False,
    attn_res_layers=[32],
    sle_spatial=False,
    disc_output_size=1,
    antialias=False,
    interpolation_num_steps=100,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    multi_gpus=False,
    calculate_fid_every=None,
    seed=42,
    amp=False
):
    # Only the keyword arguments forwarded through model_args affect the
    # Trainer; the remaining parameters (data, new, num_train_steps, generate,
    # ...) are accepted for signature parity with train_from_folder but unused.
    def cast_list(el):
        return el if isinstance(el, list) else [el]

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        attn_res_layers=cast_list(attn_res_layers),
        sle_spatial=sle_spatial,
        disc_output_size=disc_output_size,
        antialias=antialias,
        image_size=image_size,
        optimizer=optimizer,
        fmap_max=fmap_max,
        transparent=transparent,
        lr=learning_rate,
        save_every=save_every,
        evaluate_every=evaluate_every,
        trunc_psi=trunc_psi,
        aug_prob=aug_prob,
        aug_types=cast_list(aug_types),
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
        amp=amp
    )

    ret = Trainer(**model_args)
    ret.load(load_from)
    return ret
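# Example (sketch): unlike load_model, get_trainer rebuilds the Trainer with
# explicit hyperparameters, which should match those of the checkpoint being
# loaded:
#
#   trainer = get_trainer(name = 'default', image_size = 256, load_from = -1)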
def train_from_folder(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=256,
    optimizer='adam',
    fmap_max=512,
    transparent=False,
    greyscale=False,
    batch_size=10,
    gradient_accumulate_every=4,
    num_train_steps=150000,
    learning_rate=2e-4,
    save_every=1000,
    evaluate_every=1000,
    generate=False,
    generate_types=['default', 'ema'],
    generate_interpolation=False,
    aug_test=False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    attn_res_layers=[32],
    freq_chan_attn=False,
    disc_output_size=1,
    antialias=False,
    interpolation_num_steps=100,
    save_frames=False,
    num_image_tiles=None,
    trunc_psi=0.75,
    multi_gpus=False,
    calculate_fid_every=None,
    calculate_fid_num_images=12800,
    clear_fid_cache=False,
    seed=42,
    amp=False,
    show_progress=False,
):
    # Larger images get fewer tiles per sample grid by default.
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        attn_res_layers=cast_list(attn_res_layers),
        freq_chan_attn=freq_chan_attn,
        disc_output_size=disc_output_size,
        antialias=antialias,
        image_size=image_size,
        num_image_tiles=num_image_tiles,
        optimizer=optimizer,
        fmap_max=fmap_max,
        transparent=transparent,
        greyscale=greyscale,
        lr=learning_rate,
        save_every=save_every,
        evaluate_every=evaluate_every,
        trunc_psi=trunc_psi,
        aug_prob=aug_prob,
        aug_types=cast_list(aug_types),
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
        calculate_fid_num_images=calculate_fid_num_images,
        clear_fid_cache=clear_fid_cache,
        amp=amp
    )

    # Sampling mode: load a checkpoint and write out a grid of images.
    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
        print(f'sample images generated at {dir_result}')
        return

    # Interpolation mode: walk the latent space between sampled points.
    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps=interpolation_num_steps, save_frames=save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    # Progress mode: render samples from each saved checkpoint in sequence.
    if show_progress:
        model = Trainer(**model_args)
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    # Augmentation test mode: preview DiffAugment transforms on the dataset.
    if aug_test:
        DiffAugmentTest(data=data, image_size=image_size, batch_size=batch_size, types=aug_types, nrow=num_image_tiles)
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
        return

    mp.spawn(run_training, args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed), nprocs=world_size, join=True)
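# Minimal CLI entry point, a sketch assuming the python-fire package is
# installed; the upstream lightweight_gan CLI exposes train_from_folder the
# same way, turning every keyword argument into a command-line flag.
import fire

def main():
    fire.Fire(train_from_folder)

if __name__ == '__main__':
    main()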