def prepare_dataset(dset_params, train_args):
    torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    if is_main(train_args):
        prepare_data(dset_params["name"])
        torch.distributed.barrier()  # barrier will force processes to stop until *all* processes have reached the barrier
    else:
        torch.distributed.barrier()
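
# Usage sketch (an assumption for illustration, not part of the original script):
# every rank calls prepare_dataset(); the main rank presumably downloads and
# pre-processes the data while the remaining ranks wait at the barrier, so no
# rank reads a half-prepared dataset.
#
#     dset_params = params["dataset"]
#     prepare_dataset(dset_params, train_args)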
dim_head=params["dim_head"], loss_fn = loss_function,#torch.nn.CrossEntropyLoss(), num_stages = params.get("pipeline_num_stages", 2) ) model = AutoregressiveWrapper(model) # optimizer ds_model_params = prepare_optimizer_parameters(model) optim = torch.optim.Adam(model.parameters(), lr=params["learning_rate"]) # prepare data dset_params = params["dataset"] assert dset_params is not None if is_main(train_args): prepare_data(dset_params["name"]) torch.distributed.barrier() # barrier will force processes to stop until *all* processes have reached the barrier else: torch.distributed.barrier() # data loading train_dataset = GPT2Dataset(glob_pattern=dset_params["train_path"], seq_len=params["seq_len"], train=True, **dset_params) train_loader = model_engine.deepspeed_io(train_dataset, pin_memory=params.get("pin_memory", False)) eval_dataset = GPT2Dataset(glob_pattern=dset_params["eval_path"], seq_len=params["seq_len"], train=False, **dset_params)