Example #1
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      generate=False,
                      generate_interpolation=False,
                      interpolation_num_steps=100,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      aug_types=['translation', 'cutout'],
                      generator_top_k=False,
                      generator_top_k_gamma=0.99,
                      generator_top_k_frac=0.5,
                      dataset_aug_prob=0.,
                      multi_gpus=False,
                      calculate_fid_every=None):
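    # Shared keyword arguments reused for every Trainer built below.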
    model_args = dict(name=name,
                      results_dir=results_dir,
                      models_dir=models_dir,
                      batch_size=batch_size,
                      gradient_accumulate_every=gradient_accumulate_every,
                      image_size=image_size,
                      network_capacity=network_capacity,
                      transparent=transparent,
                      lr=learning_rate,
                      lr_mlp=lr_mlp,
                      ttur_mult=ttur_mult,
                      rel_disc_loss=rel_disc_loss,
                      num_workers=num_workers,
                      save_every=save_every,
                      trunc_psi=trunc_psi,
                      fp16=fp16,
                      cl_reg=cl_reg,
                      fq_layers=fq_layers,
                      fq_dict_size=fq_dict_size,
                      attn_layers=attn_layers,
                      no_const=no_const,
                      aug_prob=aug_prob,
                      aug_types=cast_list(aug_types),
                      generator_top_k=generator_top_k,
                      generator_top_k_gamma=generator_top_k_gamma,
                      generator_top_k_frac=generator_top_k_frac,
                      dataset_aug_prob=dataset_aug_prob,
                      calculate_fid_every=calculate_fid_every)

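    # Generation-only mode: load a saved checkpoint, write a grid of samples, and exit.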
    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     num_steps=interpolation_num_steps,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

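    # Training mode: spawn one process per visible GPU, or run a single
    # process when only one GPU is available or multi_gpus is disabled.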
    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps,
                     name)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new,
                   num_train_steps, name),
             nprocs=world_size,
             join=True)
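
Variants of train_from_folder like the ones on this page are usually exposed as a command-line entry point. A minimal sketch of such a wrapper, assuming the third-party python-fire package and a main() function that are not shown in the snippet above:

import fire

def main():
    # Hypothetical CLI wrapper: Fire maps command-line flags onto the keyword
    # arguments of the train_from_folder defined above, e.g.
    #   python cli.py --data ./my-images --image_size 256 --multi_gpus
    fire.Fire(train_from_folder)

if __name__ == '__main__':
    main()
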
Example #2
def train_from_folder(
    data='/hydration/ffhq/pose/trainB',
    results_dir='/hydration/results',
    models_dir='./models',
    name='run_testsg2_001',
    new=False,
    load_from=-1,
    image_size=128,  # 256 if comparing against Hydration; 128 if reproducing StyleGAN 2
    network_capacity=16,  # 64 if comparing against Hydration; 16 if reproducing StyleGAN 2
    transparent=False,
    batch_size=3,  # 1 if comparing against Hydration; 3 if reproducing StyleGAN 2
    gradient_accumulate_every=5,  # 1 if comparing against Hydration; 5 if reproducing StyleGAN 2
    num_train_steps=150000,
    learning_rate=2e-4,  # Always 0.0002
    num_workers=None,
    save_every=1000,
    generate=False,
    generate_interpolation=False,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,  # Always 0.75
    fp16=False,
    cl_reg=False,  # Always False
    fq_layers=[],  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
    fq_dict_size=256,  # 256 if comparing against Hydration; 256 if reproducing StyleGAN 2
    attn_layers=[],  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
    no_const=False,  # False if comparing against Hydration; False if reproducing StyleGAN 2
    aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    dataset_aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    use_manual_seed=-1,  # -1 for no seed; 0 if comparing against Hydration, -1 if reproducing StyleGAN 2
    debug_and_crash_mode=False):
    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob,
                    use_manual_seed=use_manual_seed,
                    debug_and_crash_mode=debug_and_crash_mode)

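    # Resume from an existing checkpoint unless a fresh run was requested.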
    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        now = datetime.now()
        timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        now = datetime.now()
        timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

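    # Point the trainer at the image folder, then run the training loop,
    # retrying any step that raises a NanException and logging every 50 steps.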
    model.set_data_src(data)

    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10.,
                     desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()
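
retry_call here is presumably the helper from the third-party retry package; the loop above retries a failed training step up to three times when it raises a NanException. A rough, hypothetical equivalent of that call, shown only to make the control flow explicit rather than to reproduce the package's implementation:

def retry_call_sketch(fn, tries=3, exceptions=Exception):
    # Call fn(); if it raises one of `exceptions`, try again,
    # re-raising once the final attempt also fails.
    for attempt in range(tries):
        try:
            return fn()
        except exceptions:
            if attempt == tries - 1:
                raise

# Mirrors the call in the loop above:
# retry_call_sketch(model.train, tries=3, exceptions=NanException)
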
Example #3
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      fmap_max=512,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      evaluate_every=1000,
                      generate=False,
                      num_generate=1,
                      generate_interpolation=False,
                      interpolation_num_steps=100,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      mixed_prob=0.9,
                      fp16=False,
                      no_pl_reg=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      aug_types=['translation', 'cutout'],
                      top_k_training=False,
                      generator_top_k_gamma=0.99,
                      generator_top_k_frac=0.5,
                      dual_contrast_loss=False,
                      dataset_aug_prob=0.,
                      multi_gpus=False,
                      calculate_fid_every=None,
                      calculate_fid_num_images=12800,
                      clear_fid_cache=False,
                      seed=42,
                      log=False,
                      cat_len=27,
                      softmax_loss=False):
    model_args = dict(name=name,
                      results_dir=results_dir,
                      models_dir=models_dir,
                      batch_size=batch_size,
                      gradient_accumulate_every=gradient_accumulate_every,
                      image_size=image_size,
                      network_capacity=network_capacity,
                      fmap_max=fmap_max,
                      transparent=transparent,
                      lr=learning_rate,
                      lr_mlp=lr_mlp,
                      ttur_mult=ttur_mult,
                      rel_disc_loss=rel_disc_loss,
                      num_workers=num_workers,
                      save_every=save_every,
                      evaluate_every=evaluate_every,
                      num_image_tiles=num_image_tiles,
                      trunc_psi=trunc_psi,
                      fp16=fp16,
                      no_pl_reg=no_pl_reg,
                      cl_reg=cl_reg,
                      fq_layers=fq_layers,
                      fq_dict_size=fq_dict_size,
                      attn_layers=attn_layers,
                      no_const=no_const,
                      aug_prob=aug_prob,
                      aug_types=cast_list(aug_types),
                      top_k_training=top_k_training,
                      generator_top_k_gamma=generator_top_k_gamma,
                      generator_top_k_frac=generator_top_k_frac,
                      dual_contrast_loss=dual_contrast_loss,
                      dataset_aug_prob=dataset_aug_prob,
                      calculate_fid_every=calculate_fid_every,
                      calculate_fid_num_images=calculate_fid_num_images,
                      clear_fid_cache=clear_fid_cache,
                      mixed_prob=mixed_prob,
                      log=log,
                      cat_len=cat_len,
                      softmax_loss=softmax_loss)

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        for num in tqdm(range(num_generate)):
            model.evaluate(f'{samples_name}-{num}', num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     num_steps=interpolation_num_steps,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()
    print("Number of available CUDA devices: ", world_size)

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps,
                     name, seed)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new,
                   num_train_steps, name, seed),
             nprocs=world_size,
             join=True)
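
mp.spawn starts world_size copies of its target and prepends each process's index (rank) to the argument tuple, so run_training must accept (rank, world_size, ...) in that order. Its real body is defined elsewhere in the module; the outline below is only a hypothetical sketch of a worker compatible with the call above, assuming an NCCL process group for the multi-GPU case:

import torch
import torch.distributed as dist

def run_training_sketch(rank, world_size, model_args, data, load_from, new,
                        num_train_steps, name, seed):
    # One copy of this worker runs in each spawned process.
    is_ddp = world_size > 1
    if is_ddp:
        torch.cuda.set_device(rank)  # bind this process to its own GPU
        dist.init_process_group('nccl',
                                init_method='tcp://localhost:12355',
                                rank=rank,
                                world_size=world_size)
    # ... build the Trainer from model_args, set the data source, and train ...
    if is_ddp:
        dist.destroy_process_group()
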
Example #4
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      generate=False,
                      generate_interpolation=False,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      dataset_aug_prob=0.,
                      gpu_ids=[0]):
    gpu_ids = cast_list(gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))

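    # CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialized, hence the deferred import of stylegan2_pytorch below.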
    from stylegan2_pytorch import Trainer, NanException

    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    lr_mlp=lr_mlp,
                    ttur_mult=ttur_mult,
                    rel_disc_loss=rel_disc_loss,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10.,
                     desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()
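
All four examples lean on small helpers defined alongside train_from_folder: cast_list normalizes a scalar flag into a list, and timestamped_filename names the output grids. Plausible implementations, consistent with how the snippets use them (Example #2 inlines the same timestamp format), written here as sketches rather than the repositories' exact code:

from datetime import datetime

def cast_list(el):
    # Wrap a single value in a list so flags like --gpu-ids 0 and
    # --gpu-ids [0,1] can be handled uniformly.
    return el if isinstance(el, list) else [el]

def timestamped_filename(prefix='generated-'):
    # e.g. 'generated-05-17-2021_13-45-02'
    timestamp = datetime.now().strftime('%m-%d-%Y_%H-%M-%S')
    return f'{prefix}{timestamp}'
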