Ejemplo n.º 1
0
def run_training(rank, world_size, model_args, data, load_from, new,
                 num_train_steps, name):
    """Run the training loop for a single worker process.

    Builds a ``Trainer`` from ``model_args``, either loads a checkpoint or
    clears existing state, and then calls ``model.train`` once per remaining
    step. When ``world_size`` is greater than one a process group is
    initialised first; it is destroyed when the function exits, including
    when training raises.

    Args:
        rank: Index of this process. Only rank 0 calls ``model.print_log``.
        world_size: Total number of processes; a value above 1 enables the
            distributed code path.
        model_args: Keyword arguments for ``Trainer``. Updated in place with
            ``is_ddp``, ``rank`` and ``world_size``.
        data: Handed to ``model.set_data_src``; also shown in the progress
            bar label.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        new: When truthy, ``model.clear()`` is called instead of loading.
        num_train_steps: Target total number of steps.
        name: Run name shown in the progress bar label.

    Note:
        ``seed`` is not a parameter of this function, so ``set_seed(seed)``
        resolves ``seed`` from module scope.
        NOTE(review): confirm a module-level ``seed`` exists; without one the
        distributed path raises ``NameError``.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    # Everything after the group is initialised runs under try/finally so the
    # process group is released even if setup or training fails.
    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

        model = Trainer(**model_args)

        if not new:
            model.load(load_from)
        else:
            model.clear()

        model.set_data_src(data)

        # `step` counts iterations of this call starting at 0; it is not the
        # same as model.steps when resuming from a checkpoint.
        for step in tqdm(range(num_train_steps - model.steps),
                         initial=model.steps,
                         total=num_train_steps,
                         mininterval=10.,
                         desc=f'{name}<{data}>'):
            retry_call(model.train, tries=3, exceptions=NanException)
            if is_main and step % 50 == 0:
                model.print_log()
    finally:
        if is_ddp:
            dist.destroy_process_group()
Ejemplo n.º 2
0
def run_training(rank, world_size, model_args, data, load_from, disc_load_from,
                 new, num_train_steps, name, seed):
    """Run the training loop for a single worker process and save at the end.

    Builds a ``Trainer`` from ``model_args``, restores or clears its state,
    and calls ``model.train`` until ``model.steps`` reaches
    ``num_train_steps``. Afterwards ``model.save(model.checkpoint_num)`` is
    called. When ``world_size`` is greater than one a process group is
    initialised first; it is destroyed when the function exits, including
    when training raises.

    Args:
        rank: Index of this process. Only rank 0 calls ``model.print_log``.
        world_size: Total number of processes; a value above 1 enables the
            distributed code path.
        model_args: Keyword arguments for ``Trainer``. Updated in place with
            ``is_ddp``, ``rank`` and ``world_size``.
        data: Handed to ``model.set_data_src``; also shown in the progress
            bar label.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        disc_load_from: When greater than 0 (and ``new`` is falsy), a second
            ``Trainer`` is loaded from this value and its ``GAN.D``,
            ``GAN.D_aug``, ``GAN.D_cl`` and ``GAN.D_opt`` replace the ones on
            the primary model.
        new: When truthy, ``model.clear()`` is called instead of loading.
        num_train_steps: Step count at which training stops.
        name: Run name shown in the progress bar label.
        seed: Passed to ``set_seed`` in the distributed code path only.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    # Everything after the group is initialised runs under try/finally so the
    # process group is released even if setup or training fails.
    try:
        model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

        model = Trainer(**model_args)

        if not new:
            model.load(load_from)

            if disc_load_from > 0:
                # Borrow the discriminator-side attributes from a second
                # trainer loaded from a different checkpoint.
                donor = Trainer(**model_args)
                donor.load(disc_load_from)
                for attr in ('D', 'D_aug', 'D_cl', 'D_opt'):
                    setattr(model.GAN, attr, getattr(donor.GAN, attr))
                del donor
        else:
            model.clear()

        model.set_data_src(data)

        # Using the bar as a context manager guarantees it is closed, which
        # the manual construction never did.
        with tqdm(initial=model.steps,
                  total=num_train_steps,
                  mininterval=10.,
                  desc=f'{name}<{data}>') as progress_bar:
            while model.steps < num_train_steps:
                retry_call(model.train, tries=3, exceptions=NanException)
                # The bar mirrors model.steps directly instead of counting
                # loop iterations.
                progress_bar.n = model.steps
                progress_bar.refresh()
                if is_main and model.steps % 50 == 0:
                    model.print_log()

        model.save(model.checkpoint_num)
    finally:
        if is_ddp:
            dist.destroy_process_group()
Ejemplo n.º 3
0
def _timestamped_samples_name():
    """Return ``generated-<MM-DD-YYYY_HH-MM-SS>`` for the current time."""
    timestamp = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
    return f'generated-{timestamp}'


def train_from_folder(
    data='/hydration/ffhq/pose/trainB',
    results_dir='/hydration/results',
    models_dir='./models',
    name='run_testsg2_001',
    new=False,
    load_from=-1,
    image_size=128,  # 256 if comparing against Hydration; 128 if reproducing StyleGAN 2
    network_capacity=16,  # 64 if comparing against Hydration; 16 if reproducing StyleGAN 2
    transparent=False,
    batch_size=3,  # 1 if comparing against Hydration; 3 if reproducing StyleGAN 2
    gradient_accumulate_every=5,  # 1 if comparing against Hydration; 5 if reproducing StyleGAN 2
    num_train_steps=150000,
    learning_rate=2e-4,  # Always 0.0002
    num_workers=None,
    save_every=1000,
    generate=False,
    generate_interpolation=False,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,  # Always 0.75
    fp16=False,
    cl_reg=False,  # Always False
    fq_layers=None,  # None means []; [] for both the Hydration comparison and StyleGAN 2
    fq_dict_size=256,  # 256 if comparing against Hydration; 256 if reproducing StyleGAN 2
    attn_layers=None,  # None means []; [] for both the Hydration comparison and StyleGAN 2
    no_const=False,  # False if comparing against Hydration; False if reproducing StyleGAN 2
    aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    dataset_aug_prob=0.,  # 0.0 if comparing against Hydration; 0.0 if reproducing StyleGAN 2
    use_manual_seed=-1,  # -1 for no seed  # 0 if comparing against Hydration; -1 if reproducing StyleGAN 2
    debug_and_crash_mode=False):
    """Train a model on ``data``, or generate samples from a saved one.

    A ``Trainer`` is built from the arguments, then either loaded from
    ``load_from`` or cleared when ``new`` is truthy. If ``generate`` or
    ``generate_interpolation`` is set, samples are written under a
    timestamped name and the function returns without training. Otherwise
    ``model.train`` is called once for each step remaining up to
    ``num_train_steps``, with ``model.print_log`` every 50 iterations.

    Args:
        data: Handed to ``model.set_data_src``; also shown in the progress
            bar label.
        results_dir: Second positional argument to ``Trainer``; also used in
            the printed output location.
        models_dir: Third positional argument to ``Trainer``.
        name: Run name; first positional argument to ``Trainer``.
        new: When truthy, ``model.clear()`` is called instead of loading.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        num_train_steps: Target total number of steps.
        learning_rate: Forwarded to ``Trainer`` as ``lr``.
        generate: Call ``model.evaluate`` and return.
        generate_interpolation: Call ``model.generate_interpolation`` and
            return. Only checked when ``generate`` is falsy.
        save_frames: Forwarded to ``model.generate_interpolation``.
        num_image_tiles: Forwarded to whichever generation call runs.
        fq_layers: Forwarded to ``Trainer``; ``None`` is replaced by a fresh
            empty list.
        attn_layers: Forwarded to ``Trainer``; ``None`` is replaced by a
            fresh empty list.

    All other keyword arguments are forwarded to ``Trainer`` under the same
    name.
    """
    # A fresh list per call avoids sharing one default list object across
    # invocations.
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []

    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob,
                    use_manual_seed=use_manual_seed,
                    debug_and_crash_mode=debug_and_crash_mode)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        samples_name = _timestamped_samples_name()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        samples_name = _timestamped_samples_name()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    # `step` counts iterations of this call starting at 0; it is not the same
    # as model.steps when resuming from a checkpoint.
    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10.,
                     desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()
Ejemplo n.º 4
0
def train_from_folder(data='./data',
                      results_dir='./results',
                      models_dir='./models',
                      name='default',
                      new=False,
                      load_from=-1,
                      image_size=128,
                      network_capacity=16,
                      transparent=False,
                      batch_size=5,
                      gradient_accumulate_every=6,
                      num_train_steps=150000,
                      learning_rate=2e-4,
                      lr_mlp=0.1,
                      ttur_mult=1.5,
                      rel_disc_loss=False,
                      num_workers=None,
                      save_every=1000,
                      generate=False,
                      generate_interpolation=False,
                      save_frames=False,
                      num_image_tiles=8,
                      trunc_psi=0.75,
                      fp16=False,
                      cl_reg=False,
                      fq_layers=None,
                      fq_dict_size=256,
                      attn_layers=None,
                      no_const=False,
                      aug_prob=0.,
                      dataset_aug_prob=0.,
                      gpu_ids=None):
    """Train a model on ``data``, or generate samples from a saved one.

    ``CUDA_VISIBLE_DEVICES`` is set from ``gpu_ids`` before
    ``stylegan2_pytorch`` is imported. A ``Trainer`` is then built from the
    arguments and either loaded from ``load_from`` or cleared when ``new`` is
    truthy. If ``generate`` or ``generate_interpolation`` is set, samples are
    written under a name from ``timestamped_filename()`` and the function
    returns without training. Otherwise ``model.train`` is called once for
    each step remaining up to ``num_train_steps``, with ``model.print_log``
    every 50 iterations.

    Args:
        data: Handed to ``model.set_data_src``; also shown in the progress
            bar label.
        results_dir: Second positional argument to ``Trainer``; also used in
            the printed output location.
        models_dir: Third positional argument to ``Trainer``.
        name: Run name; first positional argument to ``Trainer``.
        new: When truthy, ``model.clear()`` is called instead of loading.
        load_from: Passed to ``model.load`` when ``new`` is falsy.
        num_train_steps: Target total number of steps.
        learning_rate: Forwarded to ``Trainer`` as ``lr``.
        generate: Call ``model.evaluate`` and return.
        generate_interpolation: Call ``model.generate_interpolation`` and
            return. Only checked when ``generate`` is falsy.
        save_frames: Forwarded to ``model.generate_interpolation``.
        num_image_tiles: Forwarded to whichever generation call runs.
        fq_layers: Forwarded to ``Trainer``; ``None`` is replaced by a fresh
            empty list.
        attn_layers: Forwarded to ``Trainer``; ``None`` is replaced by a
            fresh empty list.
        gpu_ids: Device ids joined with commas into ``CUDA_VISIBLE_DEVICES``
            after passing through ``cast_list``; ``None`` is replaced by
            ``[0]``.

    All other keyword arguments are forwarded to ``Trainer`` under the same
    name.
    """
    # Fresh lists per call avoid sharing one default list object across
    # invocations.
    if fq_layers is None:
        fq_layers = []
    if attn_layers is None:
        attn_layers = []
    if gpu_ids is None:
        gpu_ids = [0]

    gpu_ids = cast_list(gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpu_ids))

    # Imported here rather than at module level; presumably so the
    # environment variable above is in place before the package loads.
    from stylegan2_pytorch import Trainer, NanException

    model = Trainer(name,
                    results_dir,
                    models_dir,
                    batch_size=batch_size,
                    gradient_accumulate_every=gradient_accumulate_every,
                    image_size=image_size,
                    network_capacity=network_capacity,
                    transparent=transparent,
                    lr=learning_rate,
                    lr_mlp=lr_mlp,
                    ttur_mult=ttur_mult,
                    rel_disc_loss=rel_disc_loss,
                    num_workers=num_workers,
                    save_every=save_every,
                    trunc_psi=trunc_psi,
                    fp16=fp16,
                    cl_reg=cl_reg,
                    fq_layers=fq_layers,
                    fq_dict_size=fq_dict_size,
                    attn_layers=attn_layers,
                    no_const=no_const,
                    aug_prob=aug_prob,
                    dataset_aug_prob=dataset_aug_prob)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    if generate:
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(
            f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    model.set_data_src(data)

    # `step` counts iterations of this call starting at 0; it is not the same
    # as model.steps when resuming from a checkpoint.
    for step in tqdm(range(num_train_steps - model.steps),
                     mininterval=10.,
                     desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if step % 50 == 0:
            model.print_log()