Esempio n. 1
0
def run_training(rank, world_size, model_args, data, load_from, disc_load_from,
                 new, num_train_steps, name, seed):
    """Train a ``Trainer`` in this process, optionally as one rank of a DDP job.

    Args:
        rank: index of this process; only rank 0 prints training logs.
        world_size: total number of processes; a value above 1 enables DDP
            (NCCL backend, rendezvous on localhost:12355).
        model_args: keyword arguments for ``Trainer``. Updated in place with
            ``is_ddp``, ``rank`` and ``world_size`` before construction.
        data: data source handed to ``model.set_data_src``; also shown in the
            progress-bar description.
        load_from: checkpoint identifier passed to ``model.load`` when resuming.
        disc_load_from: when resuming and greater than 0, a second checkpoint
            identifier whose discriminator (``D``, ``D_aug``, ``D_cl``) and
            discriminator optimizer (``D_opt``) replace those loaded from
            ``load_from``.
        new: if true, start from scratch via ``model.clear()`` instead of
            loading a checkpoint.
        num_train_steps: keep training until ``model.steps`` reaches this value.
        name: run name, shown in the progress-bar description.
        seed: forwarded to ``set_seed``; only applied on the DDP path.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        # Seeding is applied only on the multi-process path, before the
        # process group is created.
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)

        if disc_load_from > 0:
            # Load a second trainer only to borrow its discriminator networks
            # and optimizer; everything else stays from `load_from`.
            model2 = Trainer(**model_args)
            model2.load(disc_load_from)
            model.GAN.D, model.GAN.D_aug, model.GAN.D_cl, model.GAN.D_opt = (
                model2.GAN.D, model2.GAN.D_aug, model2.GAN.D_cl, model2.GAN.D_opt)
            # Drop the donor so the parts of it we did not take can be freed.
            del model2
    else:
        model.clear()

    model.set_data_src(data)

    # Using tqdm as a context manager guarantees the bar is closed (and its
    # line finalised) both on normal completion and if training raises.
    with tqdm(initial=model.steps,
              total=num_train_steps,
              mininterval=10.,
              desc=f'{name}<{data}>') as progress_bar:
        while model.steps < num_train_steps:
            # NOTE(review): presumably the `retry` package's retry_call, i.e.
            # up to 3 attempts when model.train raises NanException — confirm.
            retry_call(model.train, tries=3, exceptions=NanException)
            # model.steps is the source of truth; mirror it onto the bar.
            progress_bar.n = model.steps
            progress_bar.refresh()
            if is_main and model.steps % 50 == 0:
                model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
Esempio n. 2
0
def run_training(rank, world_size, model_args, data, load_from, new,
                 num_train_steps, name, seed=None):
    """Train a ``Trainer`` in this process, optionally as one rank of a DDP job.

    Args:
        rank: index of this process; only rank 0 prints training logs.
        world_size: total number of processes; a value above 1 enables DDP
            (NCCL backend, rendezvous on localhost:12355).
        model_args: keyword arguments for ``Trainer``. Updated in place with
            ``is_ddp``, ``rank`` and ``world_size`` before construction.
        data: data source handed to ``model.set_data_src``; also shown in the
            progress-bar description.
        load_from: checkpoint identifier passed to ``model.load`` when resuming.
        new: if true, start from scratch via ``model.clear()`` instead of
            loading a checkpoint.
        num_train_steps: total step budget; the loop runs
            ``num_train_steps - model.steps`` iterations.
        name: run name, shown in the progress-bar description.
        seed: optional seed forwarded to ``set_seed`` on the DDP path. When
            omitted (``None``), no seeding is performed.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        # `seed` is an optional trailing argument so existing callers that
        # pass only eight positional arguments keep working; referencing it
        # without it being a parameter would raise NameError here.
        if seed is not None:
            set_seed(seed)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)

    model = Trainer(**model_args)

    if not new:
        model.load(load_from)
    else:
        model.clear()

    model.set_data_src(data)

    for iteration in tqdm(
            range(num_train_steps - model.steps),
            initial=model.steps,
            total=num_train_steps,
            mininterval=10.0,
            desc=f"{name}<{data}>",
    ):
        # This variant calls model.train() directly, without NaN retries, so
        # an exception raised by a training step propagates immediately.
        model.train()
        # `iteration` counts loop passes in this run, not the absolute
        # model.steps, so after a resume the log cadence is relative to the
        # starting step.
        if is_main and iteration % 50 == 0:
            model.print_log()

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()
Esempio n. 3
0
def run_training(rank, world_size, model_args, data, attn_data, load_from, new,
                 num_train_steps, name, seed):
    """Run the training loop in this process (one DDP rank when world_size > 1).

    Args:
        rank: index of this process; rank 0 prints the generator log.
        world_size: number of processes; above 1 turns on DDP with the NCCL
            backend and a localhost:12355 rendezvous.
        model_args: ``Trainer`` keyword arguments, updated in place with
            ``is_ddp``, ``rank`` and ``world_size``.
        data: primary data source for ``model.set_data_src``; also appears in
            the progress-bar label.
        attn_data: second argument forwarded to ``model.set_data_src``.
        load_from: checkpoint identifier for ``model.load`` when not ``new``.
        new: when true, ``model.clear()`` is used instead of loading.
        num_train_steps: total step budget; the loop covers the gap between
            ``model.steps`` and this value.
        name: run name used in the progress-bar label.
        seed: handed to ``set_seed``, only on the DDP path.
    """
    is_main = rank == 0
    is_ddp = world_size > 1

    if is_ddp:
        set_seed(seed)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        print(f"{rank + 1}/{world_size} process initialized.")

    model_args.update(is_ddp=is_ddp, rank=rank, world_size=world_size)
    model = Trainer(**model_args)

    # Either start fresh or resume from a stored checkpoint.
    if new:
        model.clear()
    else:
        model.load(load_from)

    model.set_data_src(data, attn_data)

    # Only the steps still missing from the budget are iterated; the bar
    # starts at the current step so it shows overall progress.
    steps_left = range(num_train_steps - model.steps)
    progress = tqdm(steps_left,
                    initial=model.steps,
                    total=num_train_steps,
                    mininterval=10.,
                    desc=f'{name}<{data}>')
    for iteration in progress:
        retry_call(model.train, tries=3, exceptions=NanException)
        # Cadence is counted from this run's first iteration, not model.steps.
        if is_main and iteration % 200 == 0:
            print(model.gen_log())

    model.save(model.checkpoint_num)

    if is_ddp:
        dist.destroy_process_group()