Example #1
def prepare_dataset(dset_params, train_args):
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    if is_main(train_args):
        prepare_data(dset_params["name"])
        torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    else:
        torch.distributed.barrier()
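The example relies on an is_main() helper that is not shown. A minimal sketch of such a helper, assuming train_args carries the local_rank value injected by the DeepSpeed/torch.distributed launcher (that attribute is an assumption, not confirmed by the snippet):

# Sketch only (assumption, not from the original script): report whether this
# process is the rank-0 worker that should perform one-off work such as
# downloading or preprocessing the dataset.
def is_main(train_args):
    # -1 covers single-process runs where no launcher set a local rank
    return train_args.local_rank in (-1, 0)

With such a helper, the barrier pattern above lets rank 0 run prepare_data() while every other rank blocks on its own barrier, so no process moves on to data loading before the data actually exists.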
Example #2
    heads=params["n_heads"],
    dim_head=params["dim_head"],
    loss_fn=loss_function,  # torch.nn.CrossEntropyLoss(),
    num_stages=params.get("pipeline_num_stages", 2)
)
model = AutoregressiveWrapper(model)

# optimizer
ds_model_params = prepare_optimizer_parameters(model)
optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])

# prepare data
dset_params = params["dataset"]
assert dset_params is not None

if is_main(train_args):
    prepare_data(dset_params["name"])
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
else:
    torch.distributed.barrier()
    
# data loading
train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"],
                            seq_len=params["seq_len"],
                            train=True,
                            **dset_params)
train_loader = model_engine.deepspeed_io(train_dataset, pin_memory=params.get("pin_memory", False))

eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"],
                           seq_len=params["seq_len"],
                           train=False,
                           **dset_params)
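model_engine is used for deepspeed_io() but is constructed outside the visible snippet. A hedged sketch of where it would typically come from, assuming the standard DeepSpeed setup (passing train_args, the model, the Adam optimizer, and the prepared parameter groups is an assumption based on the surrounding names):

import deepspeed

# Sketch only: deepspeed.initialize() wraps the model and optimizer in a
# DeepSpeedEngine, which exposes deepspeed_io() for building distributed
# data loaders like train_loader above.
model_engine, optim, _, _ = deepspeed.initialize(args=train_args,
                                                 model=model,
                                                 optimizer=optim,
                                                 model_parameters=ds_model_params)

The returned engine then drives the forward/backward/step calls of the training loop as well as the data-loader construction shown in this example.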
Example #3
    # route activation checkpointing through DeepSpeed's implementation so that
    # the model-parallel CUDA RNG state is tracked consistently across ranks
    model_engine.mpu.checkpoint = deepspeed.checkpointing.checkpoint
    model_engine.mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    model_engine.mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
    assert deepspeed.checkpointing.is_configured()

def prepare_dataset(dset_params, train_args):
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    if is_main(train_args):
        prepare_data(dset_params["name"])
        torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    else:
        torch.distributed.barrier()

if __name__ == '__main__':
    # arguments
    IS_MAIN = is_main(train_args)
    deepspeed.init_distributed(dist_backend='nccl')

    # only display system stats from one worker per machine
    wandb_settings = wandb.Settings() if is_main(train_args) else wandb.Settings(_disable_stats=True)
    name = f'{socket.gethostname()}-{train_args.local_rank}' if train_args.group_name else None

    if train_args.mode == 'no_pipeline':
        model = GPTNeoX(
            num_tokens=vocab_size,
            dim=params["hidden_dim"],
            seq_len=params["seq_len"],
            depth=params["n_layers"],
            heads=params["n_heads"],
            dim_head=params["dim_head"]
        )
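The wandb_settings and name computed above are presumably passed to wandb.init() further down in the script. A minimal sketch of that step (the project string and the use of group_name are assumptions, not taken from the original):

import wandb

# Sketch only: one run per worker; system stats are disabled for all but one
# worker per machine via the settings object built above.
wandb.init(project='gpt-neox',        # hypothetical project name
           group=train_args.group_name,
           name=name,
           settings=wandb_settings)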