Example #1
def run_training(rank, world_size, model_args, data, load_from, new,
                 num_train_steps, name, seed=42):  # seed is required by set_seed() below; default added here
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    for _ in tqdm(range(num_train_steps - model.steps),
                  initial=model.steps,
                  total=num_train_steps,
                  mininterval=10.,
                  desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if is_main and _ % 50 == 0:
            model.print_log()

    if is_ddp:
        dist.destroy_process_group()
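run_training is written to be executed once per GPU. A minimal launch sketch, assuming torch.multiprocessing is imported as mp and that model_args, data, load_from, new, num_train_steps and name are defined as in the later examples:

import torch
import torch.multiprocessing as mp

world_size = torch.cuda.device_count()
if world_size <= 1:
    # single process, single (or no) GPU
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name)
else:
    # one worker process per GPU; mp.spawn prepends the rank argument
    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new, num_train_steps, name),
             nprocs=world_size,
             join=True)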
Example #2
def run_training(rank, world_size, model_args, data, load_from, disc_load_from,
                 new, num_train_steps, name, seed):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)

        if disc_load_from > 0:
            model2 = Trainer(**model_args)
            model2.load(disc_load_from)
            model.GAN.D, model.GAN.D_aug, model.GAN.D_cl, model.GAN.D_opt = model2.GAN.D, model2.GAN.D_aug, model2.GAN.D_cl, model2.GAN.D_opt
            del model2
    else:
        model.clear()

    model.set_data_src(data)

    progress_bar = tqdm(initial=model.steps,
                        total=num_train_steps,
                        mininterval=10.,
                        desc=f'{name}<{data}>')
    while model.steps < num_train_steps:
        retry_call(model.train, tries=3, exceptions=NanException)
        progress_bar.n = model.steps
        progress_bar.refresh()
        if is_main and model.steps % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
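This variant can additionally pull the discriminator from a separate checkpoint via disc_load_from. A hedged resume sketch with hypothetical checkpoint numbers: continue from checkpoint 150 while swapping in the discriminator saved at checkpoint 180.

run_training(rank=0, world_size=1, model_args=model_args, data='./data',
             load_from=150, disc_load_from=180, new=False,
             num_train_steps=200000, name='default', seed=42)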
Example #3
def load_latest_model():
    model_args = dict(name='default',
                      results_dir='./results',
                      models_dir='./models',
                      batch_size=8,
                      gradient_accumulate_every=6,
                      image_size=128,
                      network_capacity=4,
                      fmap_max=512,
                      transparent=False,
                      lr=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=16,
                      save_every=1000,
                      evaluate_every=1000,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      aug_types=['translation', 'cutout'],
                      top_k_training=False,
                      generator_top_k_gamma=0.99,
                      generator_top_k_frac=0.5,
                      dataset_aug_prob=0.,
                      calculate_fid_every=None,
                      mixed_prob=0.9,
                      log=False)
    model = Trainer(**model_args)
    model.load(-1)
    model.GAN.train(False)
    return model
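A minimal usage sketch for the loader above; evaluate and timestamped_filename are used the same way in the later examples, and the output lands under the results_dir/name configured in model_args:

model = load_latest_model()             # Trainer in eval mode, latest checkpoint loaded
samples_name = timestamped_filename()   # timestamp-based file name, as in the other examples
model.evaluate(samples_name, 8)         # write an 8-tile sample grid
print(f'sample images generated at ./results/default/{samples_name}')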
Example #4
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      generate=False,
                      generate_interpolation=False,
                      interpolation_num_steps=100,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      aug_types=['translation', 'cutout'],
                      generator_top_k=False,
                      generator_top_k_gamma=0.99,
                      generator_top_k_frac=0.5,
                      dataset_aug_prob=0.,
                      multi_gpus=False,
                      calculate_fid_every=None):
    model_args = dict(name=name,
                      results_dir=results_dir,
                      models_dir=models_dir,
                      batch_size=batch_size,
                      gradient_accumulate_every=gradient_accumulate_every,
                      image_size=image_size,
                      network_capacity=network_capacity,
                      transparent=transparent,
                      lr=learning_rate,
                      lr_mlp=lr_mlp,
                      ttur_mult=ttur_mult,
                      rel_disc_loss=rel_disc_loss,
                      num_workers=num_workers,
                      save_every=save_every,
                      trunc_psi=trunc_psi,
                      fp16=fp16,
                      cl_reg=cl_reg,
                      fq_layers=fq_layers,
                      fq_dict_size=fq_dict_size,
                      attn_layers=attn_layers,
                      no_const=no_const,
                      aug_prob=aug_prob,
                      aug_types=cast_list(aug_types),
                      generator_top_k=generator_top_k,
                      generator_top_k_gamma=generator_top_k_gamma,
                      generator_top_k_frac=generator_top_k_frac,
                      dataset_aug_prob=dataset_aug_prob,
                      calculate_fid_every=calculate_fid_every)

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     num_steps=interpolation_num_steps,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps,
                     name)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new,
                   num_train_steps, name),
             nprocs=world_size,
             join=True)
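Two hedged invocation sketches for this entry point; the folder and run names are placeholders:

# Start a new 128x128 run on a local image folder.
train_from_folder(data='./data/faces', name='faces-128', new=True,
                  image_size=128, batch_size=5, gradient_accumulate_every=6,
                  num_train_steps=150000)

# Later: load the latest checkpoint of that run and only render a sample grid.
train_from_folder(name='faces-128', load_from=-1, generate=True,
                  num_image_tiles=8)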
Example #5
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      fmap_max=512,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      evaluate_every=1000,
                      generate=False,
                      num_generate=1,
                      generate_interpolation=False,
                      interpolation_num_steps=100,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      mixed_prob=0.9,
                      fp16=False,
                      no_pl_reg=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      aug_types=['translation', 'cutout'],
                      top_k_training=False,
                      generator_top_k_gamma=0.99,
                      generator_top_k_frac=0.5,
                      dual_contrast_loss=False,
                      dataset_aug_prob=0.,
                      multi_gpus=False,
                      calculate_fid_every=None,
                      calculate_fid_num_images=12800,
                      clear_fid_cache=False,
                      seed=42,
                      log=False,
                      cat_len=27,
                      softmax_loss=False):
    model_args = dict(name=name,
                      results_dir=results_dir,
                      models_dir=models_dir,
                      batch_size=batch_size,
                      gradient_accumulate_every=gradient_accumulate_every,
                      image_size=image_size,
                      network_capacity=network_capacity,
                      fmap_max=fmap_max,
                      transparent=transparent,
                      lr=learning_rate,
                      lr_mlp=lr_mlp,
                      ttur_mult=ttur_mult,
                      rel_disc_loss=rel_disc_loss,
                      num_workers=num_workers,
                      save_every=save_every,
                      evaluate_every=evaluate_every,
                      num_image_tiles=num_image_tiles,
                      trunc_psi=trunc_psi,
                      fp16=fp16,
                      no_pl_reg=no_pl_reg,
                      cl_reg=cl_reg,
                      fq_layers=fq_layers,
                      fq_dict_size=fq_dict_size,
                      attn_layers=attn_layers,
                      no_const=no_const,
                      aug_prob=aug_prob,
                      aug_types=cast_list(aug_types),
                      top_k_training=top_k_training,
                      generator_top_k_gamma=generator_top_k_gamma,
                      generator_top_k_frac=generator_top_k_frac,
                      dual_contrast_loss=dual_contrast_loss,
                      dataset_aug_prob=dataset_aug_prob,
                      calculate_fid_every=calculate_fid_every,
                      calculate_fid_num_images=calculate_fid_num_images,
                      clear_fid_cache=clear_fid_cache,
                      mixed_prob=mixed_prob,
                      log=log,
                      cat_len=cat_len,
                      softmax_loss=softmax_loss)

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        for num in tqdm(range(num_generate)):
            model.evaluate(f'{samples_name}-{num}', num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     num_steps=interpolation_num_steps,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()
    print("Number of available CUDA devices: ", world_size)

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps,
                     name, seed)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new,
                   num_train_steps, name, seed),
             nprocs=world_size,
             join=True)
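This variant adds seeding, periodic FID evaluation and multi-GPU spawning. A hedged sketch of how those options combine; the data path and FID interval are placeholders:

# Seeded training on all visible GPUs, computing FID every 5000 steps.
train_from_folder(data='./data/faces', name='faces-multi', new=True,
                  multi_gpus=True, seed=42,
                  calculate_fid_every=5000, calculate_fid_num_images=12800)

# Generate four separate sample grids from the latest checkpoint.
train_from_folder(name='faces-multi', generate=True, num_generate=4)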
Example #6
def train_from_folder(
    data='/hydration/ffhq/pose/trainB',
    results_dir='/hydration/results',
    models_dir='./models',
    name='run_testsg2_001',
    new=False,
    load_from=-1,
    image_size=128,  # 256 if comparing against Hydration; 128 if reproducing StyleGAN 2
    network_capacity=16,  # 64 if comparing against Hydration; 16 if reproducing StyleGAN 2
    transparent=False,
    batch_size=3,  # 1 if comparing against Hydration; 3 if reproducing StyleGAN 2
    gradient_accumulate_every=5,  # 1 if comparing against Hydration; 5 if reproducing StyleGAN 2
    num_train_steps=150000,
    learning_rate=2e-4,  # Always 0.0002
    num_workers=None,
    save_every=1000,
    generate=False,
    generate_interpolation=False,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,  # Always 0.75
    fp16=False,
    cl_reg=False,  # Always False
    fq_layers=[],  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
    fq_dict_size=256,  # 256 if comparing against Hydration; 256 if reproducing StyleGAN 2
    attn_layers=[],  # [] if comparing against Hydration; [] if reproducing StyleGAN 2
    no_const=False,  # False if comparing against Hydration; False if reproducing StyleGAN 2
    aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    dataset_aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    use_manual_seed=-1,  # -1 for no seed  # 0 if comparing against Hydration; -1 if reproducing StyleGAN 2
    debug_and_crash_mode=False):
    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob,
                    use_manual_seed=use_manual_seed,
                    debug_and_crash_mode=debug_and_crash_mode)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        now = datetime.now()
        timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        now = datetime.now()
        timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
        samples_name = f'generated-{timestamp}'
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    for _ in tqdm(range(num_train_steps - model.steps),
                  mininterval=10.,
                  desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if _ % 50 == 0:
            model.print_log()
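The inline comments above describe two configurations; a hedged sketch of both calls, with the values taken from those comments and the data path from the defaults:

# Reproducing StyleGAN 2:
train_from_folder(data='/hydration/ffhq/pose/trainB', image_size=128,
                  network_capacity=16, batch_size=3,
                  gradient_accumulate_every=5, use_manual_seed=-1)

# Comparing against Hydration:
train_from_folder(data='/hydration/ffhq/pose/trainB', image_size=256,
                  network_capacity=64, batch_size=1,
                  gradient_accumulate_every=1, use_manual_seed=0)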
Example #7
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      generate=False,
                      generate_interpolation=False,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=[],
                      fq_dict_size=256,
                      attn_layers=[],
                      no_const=False,
                      aug_prob=0.,
                      dataset_aug_prob=0.,
                      gpu_ids=[0]):
    gpu_ids = cast_list(gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))

    from stylegan2_pytorch import Trainer, NanException

    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    lr_mlp=lr_mlp,
                    ttur_mult=ttur_mult,
                    rel_disc_loss=rel_disc_loss,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    for _ in tqdm(range(num_train_steps - model.steps),
                  mininterval=10.,
                  desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if _ % 50 == 0:
            model.print_log()
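This version selects devices by masking CUDA_VISIBLE_DEVICES before importing the Trainer. A hedged sketch pinning a run to the second GPU; the data path and run name are placeholders:

# gpu_ids is passed through cast_list, so a list or a single int both work.
train_from_folder(data='./data/faces', name='faces-gpu1', new=True, gpu_ids=[1])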
Example #8
def synthesize_gwas_data(
    n_causal=100,
    exp_var=0.5,
    img_size=512,
    models_dir='../models',
    checkpoint=-1,
    same_noise=True,
    psi=0.6,
    name='stylegan2_healthy',
    mult_scale=1.,
    subset=None,
    seed=123,
    wdir='.',
):
    '''synthesize images from latent codes with StyleGAN2

    Load latent codes from corresponding LATENT_BOLT_TEMPL and generate synthetic images via StyleGAN2 model

    # Parameters
    n_causal (int): number of causal SNPs
    exp_var (float in [0, 1]): percentage of explained variance by causal SNPs
    img_size (int): size of images to be created (determined by training scheme of StyleGAN2)
    checkpoint (int): which epoch to use; if -1, load latest epoch in model directory
    same_noise (bool): use the same noise vector for all images
    psi (float in [0, 1]): truncation parameter for images, trade-off between image quality and diversity
    name (str): name of StyleGAN2 model in models directory
    mult_scale (float): multiplier for standard-normal style vector (input), to increase/decrease diversity of images
    subset (None or int): only create subset of images, for debugging
    seed (int): random seed
    models_dir (str): directory containing the trained StyleGAN2 models
    wdir (str): working directory containing LATENT_DIR and receiving the output image directory
    '''
    pth = join(wdir, LATENT_DIR, get_latent_bolt(exp_var, n_causal, seed))
    latent = pd.read_csv(pth, sep=' ', index_col=1).drop('FID', axis=1)
    if subset is not None:
        latent = latent.sample(subset, random_state=123)

    T = Trainer(name,
                models_dir=models_dir,
                results_dir=models_dir,
                image_size=img_size,
                network_capacity=16)
    T.load(checkpoint)

    if psi is None:
        psi = T.trunc_psi

    if same_noise:
        N = image_noise(1, img_size)

    out_img_dir = join(wdir,
                       get_img_dir(name, exp_var, n_causal, mult_scale, seed))
    os.makedirs(out_img_dir, exist_ok=True)

    for i, lat in tqdm(latent.iterrows(), total=len(latent)):
        if not same_noise:
            N = image_noise(1, img_size)
        L = [(mult_scale *
              torch.from_numpy(lat.values).view(1, -1).float().cuda(), 8)]
        gen = T.generate_truncated(T.GAN.S, T.GAN.G, L, N, trunc_psi=psi)
        gen = (gen.cpu().double()[0].permute(1, 2, 0).numpy() * 255).astype(
            np.uint8)
        gen = PIL.Image.fromarray(gen)
        path = join(out_img_dir, f'{i}.jpg')
        gen.save(path)

    cwd = os.getcwd()
    paths = [join(cwd, out_img_dir, f'{i}.jpg') for i in latent.index]
    df = pd.DataFrame(np.array([latent.index, paths]).T,
                      columns=['IID', 'path'])
    csv_path = out_img_dir + '.csv'
    df.to_csv(csv_path, index=False)
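A hedged invocation sketch for the synthesis routine above; it assumes the latent CSVs referenced by LATENT_DIR and the trained 'stylegan2_healthy' checkpoints already exist under wdir and models_dir:

# Debug run: only synthesize a 100-image subset from the latest checkpoint.
synthesize_gwas_data(n_causal=100, exp_var=0.5, img_size=512,
                     models_dir='../models', checkpoint=-1,
                     name='stylegan2_healthy', subset=100, seed=123, wdir='.')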