params = get_params(train_args.model)

# instantiate GPT-like decoder model
model = GPTNeoX(num_tokens=params["vocab_size"],
                dim=params["hidden_dim"],
                seq_len=params["seq_len"],
                depth=params["n_layers"],
                heads=params["n_heads"],
                dim_head=params["dim_head"])

model = AutoregressiveWrapper(model)
dset_params = params["dataset"]
deepspeed.init_distributed(dist_backend='nccl')
torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
if is_main(train_args):
    prepare_data(dset_params["name"])
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
else:
    torch.distributed.barrier()

# prepare enwik8 data
data_train, data_val = read_enwik8_data(dset_params["path"])
train_dataset = TextSamplerDataset(data_train, params["seq_len"])
val_dataset = TextSamplerDataset(data_val, params["seq_len"])
val_loader = cycle(DataLoader(val_dataset, batch_size=params["batch_size"]))

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])
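
# --- Hedged sketch, not part of the original example ---
# The model, optimizer and dataset above are typically handed to DeepSpeed next.
# This assumes `train_args` carries the DeepSpeed launcher arguments (e.g. a
# --deepspeed_config path); the loop is illustrative only and assumes the
# AutoregressiveWrapper returns the language-model loss when called on a batch.
model_engine, optim, train_loader, _ = deepspeed.initialize(args=train_args,
                                                            model=model,
                                                            optimizer=optim,
                                                            training_data=train_dataset)
for step, batch in enumerate(train_loader):
    loss = model_engine(batch.to(model_engine.local_rank))
    model_engine.backward(loss)
    model_engine.step()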
Example #2
train_args = get_args()
params = get_params(train_args.model)

# instantiate GPT-like decoder model
model = GPTNeoX(num_tokens=params["vocab_size"],
                dim=params["hidden_dim"],
                seq_len=params["seq_len"],
                depth=params["n_layers"],
                heads=params["n_heads"],
                dim_head=params["dim_head"])

model = AutoregressiveWrapper(model)

# prepare data
dset_params = params["dataset"]
assert dset_params is not None

deepspeed.init_distributed(dist_backend='nccl')
torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
if is_main(train_args):
    prepare_data(dset_params["name"])
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
else:
    torch.distributed.barrier()

train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
                            seq_len=params["seq_len"],
                            train=True,
                            **dset_params)

eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
                           seq_len=params["seq_len"],
                           train=False,
                           **dset_params)
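
# --- Hedged sketch, not part of the original example ---
# Mirroring the enwik8 example above, the eval split is commonly wrapped in a
# plain DataLoader and cycled so evaluation batches can be drawn on demand.
# `eval_batch_size` is an assumed config key; fall back to the training batch size.
eval_loader = cycle(DataLoader(eval_dataset,
                               batch_size=params.get("eval_batch_size",
                                                     params["batch_size"])))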
Example #3
train_args = get_args()
params = get_params(train_args.model)

# instantiate GPT-like decoder model
model = GPTNeoX(num_tokens=params["vocab_size"],
                dim=params["hidden_dim"],
                seq_len=params["seq_len"],
                depth=params["n_layers"],
                heads=params["n_heads"],
                dim_head=params["dim_head"])

## wandb
use_wandb = get_wandb_api_key() is not None
if use_wandb:
    # only display system stats from one worker per machine
    wandb_settings = (wandb.Settings() if is_main(train_args)
                      else wandb.Settings(_disable_stats=True))
    name = f'{socket.gethostname()}-{train_args.local_rank}' if train_args.group_name else None

    try:
        wandb.init(project="neox_train_enwik8",
                   group=train_args.group_name,
                   name=name,
                   save_code=True,
                   force=False,
                   entity=params.get('wandb', {}).get('team'),
                   settings=wandb_settings)
    except UsageError as e:
        use_wandb = False
        print(e)
        print(
            'Skipping wandb. Execute `wandb login` on local machine to enable.'
        )
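
# --- Hedged sketch, not part of the original example ---
# Once wandb is initialised, gradients can be watched and training metrics
# reported with the standard wandb API; the metric names below are illustrative.
if use_wandb:
    wandb.watch(model, log="gradients", log_freq=10)
    # inside the training loop one would then report, e.g.:
    # wandb.log({"train/loss": loss.item(), "step": step})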