Example #1
def train_from_folder(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 256,
    optimizer = 'adam',
    fmap_max = 512,
    transparent = False,
    batch_size = 10,
    gradient_accumulate_every = 4,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    save_every = 1000,
    evaluate_every = 1000,
    generate = False,
    generate_interpolation = False,
    attn_res_layers = [32],
    sle_spatial = False,
    disc_output_size = 1,
    interpolation_num_steps = 100,
    save_frames = False,
    num_image_tiles = 8,
    trunc_psi = 0.75,
    aug_prob = 0.,
    aug_types = ['cutout', 'translation'],
    dataset_aug_prob = 0.,
    multi_gpus = False,
    calculate_fid_every = None,
    seed = 42
):
    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        attn_res_layers = cast_list(attn_res_layers),
        sle_spatial = sle_spatial,
        disc_output_size = disc_output_size,
        image_size = image_size,
        optimizer = optimizer,
        fmap_max = fmap_max,
        transparent = transparent,
        lr = learning_rate,
        save_every = save_every,
        evaluate_every = evaluate_every,
        trunc_psi = trunc_psi,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every
    )

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.evaluate(samples_name, num_image_tiles)
        print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
        print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
        return

    mp.spawn(run_training,
        args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
        nprocs=world_size,
        join=True)
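
The function above is only the entry point; it still has to be exposed on the command line. A minimal sketch of such a wrapper using the `fire` library is shown below (the module name `cli.py` and the `fire` dependency are assumptions, not part of the example):

# cli.py -- hypothetical wrapper; assumes train_from_folder (defined above) lives in this module
import fire

def main():
    # fire turns keyword arguments into CLI flags, e.g.
    #   python cli.py --data ./images --name my_run --image_size 512 --multi_gpus
    fire.Fire(train_from_folder)

if __name__ == '__main__':
    main()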
Example #2
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed):
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(
        is_ddp = is_ddp,
        rank = rank,
        world_size = world_size
    )

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    for _ in tqdm(range(num_train_steps - model.steps), initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>'):
        retry_call(model.train, tries=3, exceptions=NanException)
        if is_main and _ % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
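
run_training refers to a set_seed helper and a few third-party utilities that the excerpt does not show. A minimal sketch of what they could look like follows (these definitions are assumptions; only the call sites above come from the source):

import random
import numpy as np
import torch

# Hypothetical set_seed: makes each spawned worker deterministic before init_process_group
def set_seed(seed):
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# The remaining names map to common packages:
#   tqdm        -> from tqdm import tqdm
#   retry_call  -> from retry.api import retry_call   (retries model.train on NanException)
#   dist / mp   -> import torch.distributed as dist; import torch.multiprocessing as mp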
Example #3
File: video.py  Project: a3lab/cavernes
def load_model(num=-1, path=None):
    model = Trainer(save_every=1)
    model.load(num, path)
    model.eval()
    return model
def get_trainer(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=256,
    optimizer='adam',
    fmap_max=512,
    transparent=False,
    batch_size=10,
    gradient_accumulate_every=4,
    num_train_steps=150000,
    learning_rate=2e-4,
    save_every=1000,
    evaluate_every=1000,
    generate=False,
    generate_interpolation=False,
    attn_res_layers=[32],
    sle_spatial=False,
    disc_output_size=1,
    antialias=False,
    interpolation_num_steps=100,
    save_frames=False,
    num_image_tiles=8,
    trunc_psi=0.75,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    multi_gpus=False,
    calculate_fid_every=None,
    seed=42,
    amp=False
):
    def cast_list(el):
        return el if isinstance(el, list) else [el]

    model_args = dict(
        name=name,
        results_dir=results_dir,
        models_dir=models_dir,
        batch_size=batch_size,
        gradient_accumulate_every=gradient_accumulate_every,
        attn_res_layers=cast_list(attn_res_layers),
        sle_spatial=sle_spatial,
        disc_output_size=disc_output_size,
        antialias=antialias,
        image_size=image_size,
        optimizer=optimizer,
        fmap_max=fmap_max,
        transparent=transparent,
        lr=learning_rate,
        save_every=save_every,
        evaluate_every=evaluate_every,
        trunc_psi=trunc_psi,
        aug_prob=aug_prob,
        aug_types=cast_list(aug_types),
        dataset_aug_prob=dataset_aug_prob,
        calculate_fid_every=calculate_fid_every,
        amp=amp
    )

    ret = Trainer(**model_args)
    ret.load(load_from)
    return ret
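
get_trainer builds the same model_args dictionary as train_from_folder but returns the loaded Trainer instead of entering the training loop, which makes it convenient for inference scripts. A hedged usage sketch (the evaluate call mirrors Example #1; variable values are illustrative):

# Hypothetical inference-side usage of get_trainer
trainer = get_trainer(name='default', models_dir='./models', load_from=-1)

samples_name = timestamped_filename()   # same helper used in the other examples
trainer.evaluate(samples_name, 8)       # grid of samples, as in Example #1 (num_image_tiles = 8)
print(f'samples written to ./results/default/{samples_name}')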
Example #5
def train_from_folder(
    data='./data',
    results_dir='./results',
    models_dir='./models',
    name='default',
    new=False,
    load_from=-1,
    image_size=256,
    optimizer='adam',
    fmap_max=512,
    transparent=False,
    greyscale=False,
    batch_size=10,
    gradient_accumulate_every=4,
    num_train_steps=150000,
    learning_rate=2e-4,
    save_every=1000,
    evaluate_every=1000,
    generate=False,
    generate_types=['default', 'ema'],
    generate_interpolation=False,
    aug_test=False,
    aug_prob=None,
    aug_types=['cutout', 'translation'],
    dataset_aug_prob=0.,
    attn_res_layers=[32],
    freq_chan_attn=False,
    disc_output_size=1,
    antialias=False,
    interpolation_num_steps=100,
    save_frames=False,
    num_image_tiles=None,
    trunc_psi=0.75,
    multi_gpus=False,
    calculate_fid_every=None,
    calculate_fid_num_images=12800,
    clear_fid_cache=False,
    seed=42,
    amp=False,
    show_progress=False,
):
    num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

    model_args = dict(name=name,
                      results_dir=results_dir,
                      models_dir=models_dir,
                      batch_size=batch_size,
                      gradient_accumulate_every=gradient_accumulate_every,
                      attn_res_layers=cast_list(attn_res_layers),
                      freq_chan_attn=freq_chan_attn,
                      disc_output_size=disc_output_size,
                      antialias=antialias,
                      image_size=image_size,
                      num_image_tiles=num_image_tiles,
                      optimizer=optimizer,
                      fmap_max=fmap_max,
                      transparent=transparent,
                      greyscale=greyscale,
                      lr=learning_rate,
                      save_every=save_every,
                      evaluate_every=evaluate_every,
                      trunc_psi=trunc_psi,
                      aug_prob=aug_prob,
                      aug_types=cast_list(aug_types),
                      dataset_aug_prob=dataset_aug_prob,
                      calculate_fid_every=calculate_fid_every,
                      calculate_fid_num_images=calculate_fid_num_images,
                      clear_fid_cache=clear_fid_cache,
                      amp=amp)

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        checkpoint = model.checkpoint_num
        dir_result = model.generate(samples_name, num_image_tiles, checkpoint,
                                    generate_types)
        print(f'sample images generated at {dir_result}')
        return

    if generate_interpolation:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()
        model.generate_interpolation(samples_name,
                                     num_image_tiles,
                                     num_steps=interpolation_num_steps,
                                     save_frames=save_frames)
        print(
            f'interpolation generated at {results_dir}/{name}/{samples_name}')
        return

    if show_progress:
        model = Trainer(**model_args)
        model.show_progress(num_images=num_image_tiles, types=generate_types)
        return

    if aug_test:
        DiffAugmentTest(data=data,
                        image_size=image_size,
                        batch_size=batch_size,
                        types=aug_types,
                        nrow=num_image_tiles)
        return

    world_size = torch.cuda.device_count()

    if world_size == 1 or not multi_gpus:
        run_training(0, 1, model_args, data, load_from, new, num_train_steps,
                     name, seed)
        return

    mp.spawn(run_training,
             args=(world_size, model_args, data, load_from, new,
                   num_train_steps, name, seed),
             nprocs=world_size,
             join=True)
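
All of these examples lean on a few small helpers (default, cast_list, timestamped_filename) that the page never shows. Minimal sketches consistent with how they are called above (assumptions, not the original definitions):

from datetime import datetime

def exists(val):
    return val is not None

def default(val, d):
    # used above as: num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
    return val if exists(val) else d

def cast_list(el):
    # wrap scalars so attn_res_layers / aug_types always arrive as lists
    return el if isinstance(el, list) else [el]

def timestamped_filename(prefix='generated-'):
    # hypothetical timestamp format; used to name sample grids and interpolations
    now = datetime.now()
    return f"{prefix}{now.strftime('%m-%d-%Y_%H-%M-%S')}"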